diff --git a/sklearn_extra/tests/test_common.py b/sklearn_extra/tests/test_common.py index 6563d42b..587b8249 100644 --- a/sklearn_extra/tests/test_common.py +++ b/sklearn_extra/tests/test_common.py @@ -1,20 +1,23 @@ import pytest - -from sklearn.utils.estimator_checks import check_estimator +from sklearn.utils import estimator_checks from sklearn_extra.kernel_approximation import Fastfood -from sklearn_extra.kernel_methods import _eigenpro +from sklearn_extra.kernel_methods import EigenProClassifier, EigenProRegressor from sklearn_extra.cluster import KMedoids +ALL_ESTIMATORS = [Fastfood, KMedoids, EigenProClassifier, EigenProRegressor] + +if hasattr(estimator_checks, "parametrize_with_checks"): + # Common tests are only run on scikit-learn 0.22+ + + @estimator_checks.parametrize_with_checks(ALL_ESTIMATORS) + def test_all_estimators(estimator, check, request): + # TODO: fix this common test failure cf #41 + if isinstance( + estimator, EigenProClassifier + ) and "function check_classifier_multioutput" in str(check): + request.applymarker( + pytest.mark.xfail(run=False, reason="See issue #41") + ) -@pytest.mark.parametrize( - "Estimator", - [ - Fastfood, - KMedoids, - _eigenpro.EigenProClassifier, - _eigenpro.EigenProRegressor, - ], -) -def test_all_estimators(Estimator, request): - return check_estimator(Estimator) + return check(estimator)