24
24
from numbers import Integral , Real
25
25
26
26
import numpy as np
27
- from scipy .special import xlogy
28
27
29
28
from ..base import (
30
29
ClassifierMixin ,
36
35
from ..metrics import accuracy_score , r2_score
37
36
from ..tree import DecisionTreeClassifier , DecisionTreeRegressor
38
37
from ..utils import _safe_indexing , check_random_state
39
- from ..utils ._param_validation import HasMethods , Interval , StrOptions
38
+ from ..utils ._param_validation import HasMethods , Hidden , Interval , StrOptions
40
39
from ..utils .extmath import softmax , stable_cumsum
41
40
from ..utils .metadata_routing import (
42
41
_raise_for_unsupported_routing ,
@@ -375,16 +374,12 @@ class AdaBoostClassifier(
375
374
a trade-off between the `learning_rate` and `n_estimators` parameters.
376
375
Values must be in the range `(0.0, inf)`.
377
376
378
- algorithm : {'SAMME', 'SAMME.R'}, default='SAMME.R'
379
- If 'SAMME.R' then use the SAMME.R real boosting algorithm.
380
- ``estimator`` must support calculation of class probabilities.
381
- If 'SAMME' then use the SAMME discrete boosting algorithm.
382
- The SAMME.R algorithm typically converges faster than SAMME,
383
- achieving a lower test error with fewer boosting iterations.
377
+ algorithm : {'SAMME'}, default='SAMME'
378
+ Use the SAMME discrete boosting algorithm.
384
379
385
- .. deprecated:: 1.4
386
- `"SAMME.R"` is deprecated and will be removed in version 1.6.
387
- '"SAMME"' will become the default.
380
+ .. deprecated:: 1.6
381
+ `algorithm` is deprecated and will be removed in version 1.8. This
382
+ estimator only implements the 'SAMME' algorithm.
388
383
389
384
random_state : int, RandomState instance or None, default=None
390
385
Controls the random seed given at each `estimator` at each
@@ -470,9 +465,9 @@ class AdaBoostClassifier(
470
465
>>> X, y = make_classification(n_samples=1000, n_features=4,
471
466
... n_informative=2, n_redundant=0,
472
467
... random_state=0, shuffle=False)
473
- >>> clf = AdaBoostClassifier(n_estimators=100, algorithm="SAMME", random_state=0)
468
+ >>> clf = AdaBoostClassifier(n_estimators=100, random_state=0)
474
469
>>> clf.fit(X, y)
475
- AdaBoostClassifier(algorithm='SAMME', n_estimators=100, random_state=0)
470
+ AdaBoostClassifier(n_estimators=100, random_state=0)
476
471
>>> clf.predict([[0, 0, 0, 0]])
477
472
array([1])
478
473
>>> clf.score(X, y)
@@ -487,23 +482,19 @@ class AdaBoostClassifier(
487
482
refer to :ref:`sphx_glr_auto_examples_ensemble_plot_adaboost_twoclass.py`.
488
483
"""
489
484
490
- # TODO(1.6): Modify _parameter_constraints for "algorithm" to only check
491
- # for "SAMME"
485
+ # TODO(1.8): remove "algorithm" entry
492
486
_parameter_constraints : dict = {
493
487
** BaseWeightBoosting ._parameter_constraints ,
494
- "algorithm" : [
495
- StrOptions ({"SAMME" , "SAMME.R" }),
496
- ],
488
+ "algorithm" : [StrOptions ({"SAMME" }), Hidden (StrOptions ({"deprecated" }))],
497
489
}
498
490
499
- # TODO(1.6): Change default "algorithm" value to "SAMME"
500
491
def __init__ (
501
492
self ,
502
493
estimator = None ,
503
494
* ,
504
495
n_estimators = 50 ,
505
496
learning_rate = 1.0 ,
506
- algorithm = "SAMME.R " ,
497
+ algorithm = "deprecated " ,
507
498
random_state = None ,
508
499
):
509
500
super ().__init__ (
@@ -519,43 +510,23 @@ def _validate_estimator(self):
519
510
"""Check the estimator and set the estimator_ attribute."""
520
511
super ()._validate_estimator (default = DecisionTreeClassifier (max_depth = 1 ))
521
512
522
- # TODO(1.6): Remove, as "SAMME.R" value for "algorithm" param will be
523
- # removed in 1.6
524
- # SAMME-R requires predict_proba-enabled base estimators
525
- if self .algorithm != "SAMME" :
513
+ if self .algorithm != "deprecated" :
526
514
warnings .warn (
527
- (
528
- "The SAMME.R algorithm (the default) is deprecated and will be"
529
- " removed in 1.6. Use the SAMME algorithm to circumvent this"
530
- " warning."
531
- ),
515
+ "The parameter 'algorithm' is deprecated in 1.6 and has no effect. "
516
+ "It will be removed in version 1.8." ,
532
517
FutureWarning ,
533
518
)
534
- if not hasattr (self .estimator_ , "predict_proba" ):
535
- raise TypeError (
536
- "AdaBoostClassifier with algorithm='SAMME.R' requires "
537
- "that the weak learner supports the calculation of class "
538
- "probabilities with a predict_proba method.\n "
539
- "Please change the base estimator or set "
540
- "algorithm='SAMME' instead."
541
- )
542
519
543
520
if not has_fit_parameter (self .estimator_ , "sample_weight" ):
544
521
raise ValueError (
545
522
f"{ self .estimator .__class__ .__name__ } doesn't support sample_weight."
546
523
)
547
524
548
- # TODO(1.6): Redefine the scope of the `_boost` and `_boost_discrete`
549
- # functions to be the same since SAMME will be the default value for the
550
- # "algorithm" parameter in version 1.6. Thus, a distinguishing function is
551
- # no longer needed. (Or adjust code here, if another algorithm, shall be
552
- # used instead of SAMME.R.)
553
525
def _boost (self , iboost , X , y , sample_weight , random_state ):
554
526
"""Implement a single boost.
555
527
556
- Perform a single boost according to the real multi-class SAMME.R
557
- algorithm or to the discrete SAMME algorithm and return the updated
558
- sample weights.
528
+ Perform a single boost according to the discrete SAMME algorithm and return the
529
+ updated sample weights.
559
530
560
531
Parameters
561
532
----------
@@ -589,75 +560,6 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
589
560
The classification error for the current boost.
590
561
If None then boosting has terminated early.
591
562
"""
592
- if self .algorithm == "SAMME.R" :
593
- return self ._boost_real (iboost , X , y , sample_weight , random_state )
594
-
595
- else : # elif self.algorithm == "SAMME":
596
- return self ._boost_discrete (iboost , X , y , sample_weight , random_state )
597
-
598
- # TODO(1.6): Remove function. The `_boost_real` function won't be used any
599
- # longer, because the SAMME.R algorithm will be deprecated in 1.6.
600
- def _boost_real (self , iboost , X , y , sample_weight , random_state ):
601
- """Implement a single boost using the SAMME.R real algorithm."""
602
- estimator = self ._make_estimator (random_state = random_state )
603
-
604
- estimator .fit (X , y , sample_weight = sample_weight )
605
-
606
- y_predict_proba = estimator .predict_proba (X )
607
-
608
- if iboost == 0 :
609
- self .classes_ = getattr (estimator , "classes_" , None )
610
- self .n_classes_ = len (self .classes_ )
611
-
612
- y_predict = self .classes_ .take (np .argmax (y_predict_proba , axis = 1 ), axis = 0 )
613
-
614
- # Instances incorrectly classified
615
- incorrect = y_predict != y
616
-
617
- # Error fraction
618
- estimator_error = np .mean (np .average (incorrect , weights = sample_weight , axis = 0 ))
619
-
620
- # Stop if classification is perfect
621
- if estimator_error <= 0 :
622
- return sample_weight , 1.0 , 0.0
623
-
624
- # Construct y coding as described in Zhu et al [2]:
625
- #
626
- # y_k = 1 if c == k else -1 / (K - 1)
627
- #
628
- # where K == n_classes_ and c, k in [0, K) are indices along the second
629
- # axis of the y coding with c being the index corresponding to the true
630
- # class label.
631
- n_classes = self .n_classes_
632
- classes = self .classes_
633
- y_codes = np .array ([- 1.0 / (n_classes - 1 ), 1.0 ])
634
- y_coding = y_codes .take (classes == y [:, np .newaxis ])
635
-
636
- # Displace zero probabilities so the log is defined.
637
- # Also fix negative elements which may occur with
638
- # negative sample weights.
639
- proba = y_predict_proba # alias for readability
640
- np .clip (proba , np .finfo (proba .dtype ).eps , None , out = proba )
641
-
642
- # Boost weight using multi-class AdaBoost SAMME.R alg
643
- estimator_weight = (
644
- - 1.0
645
- * self .learning_rate
646
- * ((n_classes - 1.0 ) / n_classes )
647
- * xlogy (y_coding , y_predict_proba ).sum (axis = 1 )
648
- )
649
-
650
- # Only boost the weights if it will fit again
651
- if not iboost == self .n_estimators - 1 :
652
- # Only boost positive weights
653
- sample_weight *= np .exp (
654
- estimator_weight * ((sample_weight > 0 ) | (estimator_weight < 0 ))
655
- )
656
-
657
- return sample_weight , 1.0 , estimator_error
658
-
659
- def _boost_discrete (self , iboost , X , y , sample_weight , random_state ):
660
- """Implement a single boost using the SAMME discrete algorithm."""
661
563
estimator = self ._make_estimator (random_state = random_state )
662
564
663
565
estimator .fit (X , y , sample_weight = sample_weight )
@@ -789,21 +691,17 @@ class in ``classes_``, respectively.
789
691
n_classes = self .n_classes_
790
692
classes = self .classes_ [:, np .newaxis ]
791
693
792
- # TODO(1.6): Remove, because "algorithm" param will be deprecated in 1.6
793
- if self .algorithm == "SAMME.R" :
794
- # The weights are all 1. for SAMME.R
795
- pred = sum (
796
- _samme_proba (estimator , n_classes , X ) for estimator in self .estimators_
797
- )
798
- else : # self.algorithm == "SAMME"
799
- pred = sum (
800
- np .where (
801
- (estimator .predict (X ) == classes ).T ,
802
- w ,
803
- - 1 / (n_classes - 1 ) * w ,
804
- )
805
- for estimator , w in zip (self .estimators_ , self .estimator_weights_ )
694
+ if n_classes == 1 :
695
+ return np .zeros_like (X , shape = (X .shape [0 ], 1 ))
696
+
697
+ pred = sum (
698
+ np .where (
699
+ (estimator .predict (X ) == classes ).T ,
700
+ w ,
701
+ - 1 / (n_classes - 1 ) * w ,
806
702
)
703
+ for estimator , w in zip (self .estimators_ , self .estimator_weights_ )
704
+ )
807
705
808
706
pred /= self .estimator_weights_ .sum ()
809
707
if n_classes == 2 :
@@ -844,17 +742,11 @@ class in ``classes_``, respectively.
844
742
for weight , estimator in zip (self .estimator_weights_ , self .estimators_ ):
845
743
norm += weight
846
744
847
- # TODO(1.6): Remove, because "algorithm" param will be deprecated in
848
- # 1.6
849
- if self .algorithm == "SAMME.R" :
850
- # The weights are all 1. for SAMME.R
851
- current_pred = _samme_proba (estimator , n_classes , X )
852
- else : # elif self.algorithm == "SAMME":
853
- current_pred = np .where (
854
- (estimator .predict (X ) == classes ).T ,
855
- weight ,
856
- - 1 / (n_classes - 1 ) * weight ,
857
- )
745
+ current_pred = np .where (
746
+ (estimator .predict (X ) == classes ).T ,
747
+ weight ,
748
+ - 1 / (n_classes - 1 ) * weight ,
749
+ )
858
750
859
751
if pred is None :
860
752
pred = current_pred
0 commit comments