Skip to content

Commit 6de55b3

Browse files
MAINT|API Clean up deprecations for 1.6: SAMME.R in AdaBoost and deprecate algorithm (scikit-learn#29997)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 35f106c commit 6de55b3

File tree

10 files changed

+95
-273
lines changed

10 files changed

+95
-273
lines changed

benchmarks/bench_20newsgroups.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"extra_trees": ExtraTreesClassifier(max_features="sqrt", min_samples_split=10),
2222
"logistic_regression": LogisticRegression(),
2323
"naive_bayes": MultinomialNB(),
24-
"adaboost": AdaBoostClassifier(n_estimators=10, algorithm="SAMME"),
24+
"adaboost": AdaBoostClassifier(n_estimators=10),
2525
}
2626

2727

doc/modules/ensemble.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,7 @@ learners::
17091709
>>> from sklearn.ensemble import AdaBoostClassifier
17101710

17111711
>>> X, y = load_iris(return_X_y=True)
1712-
>>> clf = AdaBoostClassifier(n_estimators=100, algorithm="SAMME",)
1712+
>>> clf = AdaBoostClassifier(n_estimators=100)
17131713
>>> scores = cross_val_score(clf, X, y, cv=5)
17141714
>>> scores.mean()
17151715
0.9...

doc/whats_new/v1.6.rst

+4
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,10 @@ Changelog
251251
right child node as the tree is traversed.
252252
:pr:`28268` by :user:`Adam Li <adam2392>`.
253253

254+
- |API| The parameter `algorithm` of :class:`ensemble.AdaBoostClassifier` is deprecated
255+
and will be removed in 1.8.
256+
:pr:`29997` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
257+
254258
:mod:`sklearn.impute`
255259
.....................
256260

examples/classification/plot_classifier_comparison.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
max_depth=5, n_estimators=10, max_features=1, random_state=42
6565
),
6666
MLPClassifier(alpha=1, max_iter=1000, random_state=42),
67-
AdaBoostClassifier(algorithm="SAMME", random_state=42),
67+
AdaBoostClassifier(random_state=42),
6868
GaussianNB(),
6969
QuadraticDiscriminantAnalysis(),
7070
]

examples/ensemble/plot_adaboost_multiclass.py

-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@
8080
adaboost_clf = AdaBoostClassifier(
8181
estimator=weak_learner,
8282
n_estimators=n_estimators,
83-
algorithm="SAMME",
8483
random_state=42,
8584
).fit(X_train, y_train)
8685

examples/ensemble/plot_adaboost_twoclass.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,7 @@
3939
y = np.concatenate((y1, -y2 + 1))
4040

4141
# Create and fit an AdaBoosted decision tree
42-
bdt = AdaBoostClassifier(
43-
DecisionTreeClassifier(max_depth=1), algorithm="SAMME", n_estimators=200
44-
)
45-
42+
bdt = AdaBoostClassifier(DecisionTreeClassifier(max_depth=1), n_estimators=200)
4643
bdt.fit(X, y)
4744

4845
plot_colors = "br"

examples/ensemble/plot_forest_iris.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,7 @@
7474
DecisionTreeClassifier(max_depth=None),
7575
RandomForestClassifier(n_estimators=n_estimators),
7676
ExtraTreesClassifier(n_estimators=n_estimators),
77-
AdaBoostClassifier(
78-
DecisionTreeClassifier(max_depth=3),
79-
n_estimators=n_estimators,
80-
algorithm="SAMME",
81-
),
77+
AdaBoostClassifier(DecisionTreeClassifier(max_depth=3), n_estimators=n_estimators),
8278
]
8379

8480
for pair in ([0, 1], [0, 2], [2, 3]):

sklearn/ensemble/_weight_boosting.py

+31-139
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from numbers import Integral, Real
2525

2626
import numpy as np
27-
from scipy.special import xlogy
2827

2928
from ..base import (
3029
ClassifierMixin,
@@ -36,7 +35,7 @@
3635
from ..metrics import accuracy_score, r2_score
3736
from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
3837
from ..utils import _safe_indexing, check_random_state
39-
from ..utils._param_validation import HasMethods, Interval, StrOptions
38+
from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions
4039
from ..utils.extmath import softmax, stable_cumsum
4140
from ..utils.metadata_routing import (
4241
_raise_for_unsupported_routing,
@@ -375,16 +374,12 @@ class AdaBoostClassifier(
375374
a trade-off between the `learning_rate` and `n_estimators` parameters.
376375
Values must be in the range `(0.0, inf)`.
377376
378-
algorithm : {'SAMME', 'SAMME.R'}, default='SAMME.R'
379-
If 'SAMME.R' then use the SAMME.R real boosting algorithm.
380-
``estimator`` must support calculation of class probabilities.
381-
If 'SAMME' then use the SAMME discrete boosting algorithm.
382-
The SAMME.R algorithm typically converges faster than SAMME,
383-
achieving a lower test error with fewer boosting iterations.
377+
algorithm : {'SAMME'}, default='SAMME'
378+
Use the SAMME discrete boosting algorithm.
384379
385-
.. deprecated:: 1.4
386-
`"SAMME.R"` is deprecated and will be removed in version 1.6.
387-
'"SAMME"' will become the default.
380+
.. deprecated:: 1.6
381+
`algorithm` is deprecated and will be removed in version 1.8. This
382+
estimator only implements the 'SAMME' algorithm.
388383
389384
random_state : int, RandomState instance or None, default=None
390385
Controls the random seed given at each `estimator` at each
@@ -470,9 +465,9 @@ class AdaBoostClassifier(
470465
>>> X, y = make_classification(n_samples=1000, n_features=4,
471466
... n_informative=2, n_redundant=0,
472467
... random_state=0, shuffle=False)
473-
>>> clf = AdaBoostClassifier(n_estimators=100, algorithm="SAMME", random_state=0)
468+
>>> clf = AdaBoostClassifier(n_estimators=100, random_state=0)
474469
>>> clf.fit(X, y)
475-
AdaBoostClassifier(algorithm='SAMME', n_estimators=100, random_state=0)
470+
AdaBoostClassifier(n_estimators=100, random_state=0)
476471
>>> clf.predict([[0, 0, 0, 0]])
477472
array([1])
478473
>>> clf.score(X, y)
@@ -487,23 +482,19 @@ class AdaBoostClassifier(
487482
refer to :ref:`sphx_glr_auto_examples_ensemble_plot_adaboost_twoclass.py`.
488483
"""
489484

490-
# TODO(1.6): Modify _parameter_constraints for "algorithm" to only check
491-
# for "SAMME"
485+
# TODO(1.8): remove "algorithm" entry
492486
_parameter_constraints: dict = {
493487
**BaseWeightBoosting._parameter_constraints,
494-
"algorithm": [
495-
StrOptions({"SAMME", "SAMME.R"}),
496-
],
488+
"algorithm": [StrOptions({"SAMME"}), Hidden(StrOptions({"deprecated"}))],
497489
}
498490

499-
# TODO(1.6): Change default "algorithm" value to "SAMME"
500491
def __init__(
501492
self,
502493
estimator=None,
503494
*,
504495
n_estimators=50,
505496
learning_rate=1.0,
506-
algorithm="SAMME.R",
497+
algorithm="deprecated",
507498
random_state=None,
508499
):
509500
super().__init__(
@@ -519,43 +510,23 @@ def _validate_estimator(self):
519510
"""Check the estimator and set the estimator_ attribute."""
520511
super()._validate_estimator(default=DecisionTreeClassifier(max_depth=1))
521512

522-
# TODO(1.6): Remove, as "SAMME.R" value for "algorithm" param will be
523-
# removed in 1.6
524-
# SAMME-R requires predict_proba-enabled base estimators
525-
if self.algorithm != "SAMME":
513+
if self.algorithm != "deprecated":
526514
warnings.warn(
527-
(
528-
"The SAMME.R algorithm (the default) is deprecated and will be"
529-
" removed in 1.6. Use the SAMME algorithm to circumvent this"
530-
" warning."
531-
),
515+
"The parameter 'algorithm' is deprecated in 1.6 and has no effect. "
516+
"It will be removed in version 1.8.",
532517
FutureWarning,
533518
)
534-
if not hasattr(self.estimator_, "predict_proba"):
535-
raise TypeError(
536-
"AdaBoostClassifier with algorithm='SAMME.R' requires "
537-
"that the weak learner supports the calculation of class "
538-
"probabilities with a predict_proba method.\n"
539-
"Please change the base estimator or set "
540-
"algorithm='SAMME' instead."
541-
)
542519

543520
if not has_fit_parameter(self.estimator_, "sample_weight"):
544521
raise ValueError(
545522
f"{self.estimator.__class__.__name__} doesn't support sample_weight."
546523
)
547524

548-
# TODO(1.6): Redefine the scope of the `_boost` and `_boost_discrete`
549-
# functions to be the same since SAMME will be the default value for the
550-
# "algorithm" parameter in version 1.6. Thus, a distinguishing function is
551-
# no longer needed. (Or adjust code here, if another algorithm, shall be
552-
# used instead of SAMME.R.)
553525
def _boost(self, iboost, X, y, sample_weight, random_state):
554526
"""Implement a single boost.
555527
556-
Perform a single boost according to the real multi-class SAMME.R
557-
algorithm or to the discrete SAMME algorithm and return the updated
558-
sample weights.
528+
Perform a single boost according to the discrete SAMME algorithm and return the
529+
updated sample weights.
559530
560531
Parameters
561532
----------
@@ -589,75 +560,6 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
589560
The classification error for the current boost.
590561
If None then boosting has terminated early.
591562
"""
592-
if self.algorithm == "SAMME.R":
593-
return self._boost_real(iboost, X, y, sample_weight, random_state)
594-
595-
else: # elif self.algorithm == "SAMME":
596-
return self._boost_discrete(iboost, X, y, sample_weight, random_state)
597-
598-
# TODO(1.6): Remove function. The `_boost_real` function won't be used any
599-
# longer, because the SAMME.R algorithm will be deprecated in 1.6.
600-
def _boost_real(self, iboost, X, y, sample_weight, random_state):
601-
"""Implement a single boost using the SAMME.R real algorithm."""
602-
estimator = self._make_estimator(random_state=random_state)
603-
604-
estimator.fit(X, y, sample_weight=sample_weight)
605-
606-
y_predict_proba = estimator.predict_proba(X)
607-
608-
if iboost == 0:
609-
self.classes_ = getattr(estimator, "classes_", None)
610-
self.n_classes_ = len(self.classes_)
611-
612-
y_predict = self.classes_.take(np.argmax(y_predict_proba, axis=1), axis=0)
613-
614-
# Instances incorrectly classified
615-
incorrect = y_predict != y
616-
617-
# Error fraction
618-
estimator_error = np.mean(np.average(incorrect, weights=sample_weight, axis=0))
619-
620-
# Stop if classification is perfect
621-
if estimator_error <= 0:
622-
return sample_weight, 1.0, 0.0
623-
624-
# Construct y coding as described in Zhu et al [2]:
625-
#
626-
# y_k = 1 if c == k else -1 / (K - 1)
627-
#
628-
# where K == n_classes_ and c, k in [0, K) are indices along the second
629-
# axis of the y coding with c being the index corresponding to the true
630-
# class label.
631-
n_classes = self.n_classes_
632-
classes = self.classes_
633-
y_codes = np.array([-1.0 / (n_classes - 1), 1.0])
634-
y_coding = y_codes.take(classes == y[:, np.newaxis])
635-
636-
# Displace zero probabilities so the log is defined.
637-
# Also fix negative elements which may occur with
638-
# negative sample weights.
639-
proba = y_predict_proba # alias for readability
640-
np.clip(proba, np.finfo(proba.dtype).eps, None, out=proba)
641-
642-
# Boost weight using multi-class AdaBoost SAMME.R alg
643-
estimator_weight = (
644-
-1.0
645-
* self.learning_rate
646-
* ((n_classes - 1.0) / n_classes)
647-
* xlogy(y_coding, y_predict_proba).sum(axis=1)
648-
)
649-
650-
# Only boost the weights if it will fit again
651-
if not iboost == self.n_estimators - 1:
652-
# Only boost positive weights
653-
sample_weight *= np.exp(
654-
estimator_weight * ((sample_weight > 0) | (estimator_weight < 0))
655-
)
656-
657-
return sample_weight, 1.0, estimator_error
658-
659-
def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
660-
"""Implement a single boost using the SAMME discrete algorithm."""
661563
estimator = self._make_estimator(random_state=random_state)
662564

663565
estimator.fit(X, y, sample_weight=sample_weight)
@@ -789,21 +691,17 @@ class in ``classes_``, respectively.
789691
n_classes = self.n_classes_
790692
classes = self.classes_[:, np.newaxis]
791693

792-
# TODO(1.6): Remove, because "algorithm" param will be deprecated in 1.6
793-
if self.algorithm == "SAMME.R":
794-
# The weights are all 1. for SAMME.R
795-
pred = sum(
796-
_samme_proba(estimator, n_classes, X) for estimator in self.estimators_
797-
)
798-
else: # self.algorithm == "SAMME"
799-
pred = sum(
800-
np.where(
801-
(estimator.predict(X) == classes).T,
802-
w,
803-
-1 / (n_classes - 1) * w,
804-
)
805-
for estimator, w in zip(self.estimators_, self.estimator_weights_)
694+
if n_classes == 1:
695+
return np.zeros_like(X, shape=(X.shape[0], 1))
696+
697+
pred = sum(
698+
np.where(
699+
(estimator.predict(X) == classes).T,
700+
w,
701+
-1 / (n_classes - 1) * w,
806702
)
703+
for estimator, w in zip(self.estimators_, self.estimator_weights_)
704+
)
807705

808706
pred /= self.estimator_weights_.sum()
809707
if n_classes == 2:
@@ -844,17 +742,11 @@ class in ``classes_``, respectively.
844742
for weight, estimator in zip(self.estimator_weights_, self.estimators_):
845743
norm += weight
846744

847-
# TODO(1.6): Remove, because "algorithm" param will be deprecated in
848-
# 1.6
849-
if self.algorithm == "SAMME.R":
850-
# The weights are all 1. for SAMME.R
851-
current_pred = _samme_proba(estimator, n_classes, X)
852-
else: # elif self.algorithm == "SAMME":
853-
current_pred = np.where(
854-
(estimator.predict(X) == classes).T,
855-
weight,
856-
-1 / (n_classes - 1) * weight,
857-
)
745+
current_pred = np.where(
746+
(estimator.predict(X) == classes).T,
747+
weight,
748+
-1 / (n_classes - 1) * weight,
749+
)
858750

859751
if pred is None:
860752
pred = current_pred

sklearn/ensemble/tests/test_bagging.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def test_bagging_with_metadata_routing(model):
965965
"model",
966966
[
967967
BaggingClassifier(
968-
estimator=AdaBoostClassifier(n_estimators=1, algorithm="SAMME"),
968+
estimator=AdaBoostClassifier(n_estimators=1),
969969
n_estimators=1,
970970
),
971971
BaggingRegressor(estimator=AdaBoostRegressor(n_estimators=1), n_estimators=1),

0 commit comments

Comments
 (0)