forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpipeline.py
1732 lines (1414 loc) · 63.1 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
The :mod:`sklearn.pipeline` module implements utilities to build a composite
estimator, as a chain of transforms and estimators.
"""
# Author: Edouard Duchesnay
# Gael Varoquaux
# Virgile Fritsch
# Alexandre Gramfort
# Lars Buitinck
# License: BSD
from collections import defaultdict
from itertools import islice
import numpy as np
from scipy import sparse
from .base import TransformerMixin, _fit_context, clone
from .exceptions import NotFittedError
from .preprocessing import FunctionTransformer
from .utils import Bunch, _print_elapsed_time, check_pandas_support
from .utils._estimator_html_repr import _VisualBlock
from .utils._metadata_requests import METHODS
from .utils._param_validation import HasMethods, Hidden
from .utils._set_output import _get_output_config, _safe_set_output
from .utils._tags import _safe_tags
from .utils.metadata_routing import (
MetadataRouter,
MethodMapping,
_raise_for_params,
_routing_enabled,
process_routing,
)
from .utils.metaestimators import _BaseComposition, available_if
from .utils.parallel import Parallel, delayed
from .utils.validation import check_is_fitted, check_memory
# Public names of this module, i.e. what `from <module> import *` exports.
__all__ = ["Pipeline", "FeatureUnion", "make_pipeline", "make_union"]
def _final_estimator_has(attr):
"""Check that final_estimator has `attr`.
Used together with `available_if` in `Pipeline`."""
def check(self):
# raise original `AttributeError` if `attr` does not exist
getattr(self._final_estimator, attr)
return True
return check
class Pipeline(_BaseComposition):
"""
Pipeline of transforms with a final estimator.
Sequentially apply a list of transforms and a final estimator.
Intermediate steps of the pipeline must be 'transforms', that is, they
must implement `fit` and `transform` methods.
The final estimator only needs to implement `fit`.
The transformers in the pipeline can be cached using ``memory`` argument.
The purpose of the pipeline is to assemble several steps that can be
cross-validated together while setting different parameters. For this, it
enables setting parameters of the various steps using their names and the
parameter name separated by a `'__'`, as in the example below. A step's
estimator may be replaced entirely by setting the parameter with its name
to another estimator, or a transformer removed by setting it to
`'passthrough'` or `None`.
Read more in the :ref:`User Guide <pipeline>`.
.. versionadded:: 0.5
Parameters
----------
steps : list of tuple
List of (name, transform) tuples (implementing `fit`/`transform`) that
are chained in sequential order. The last transform must be an
estimator.
memory : str or object with the joblib.Memory interface, default=None
Used to cache the fitted transformers of the pipeline. The last step
will never be cached, even if it is a transformer. By default, no
caching is performed. If a string is given, it is the path to the
caching directory. Enabling caching triggers a clone of the transformers
before fitting. Therefore, the transformer instance given to the
pipeline cannot be inspected directly. Use the attribute ``named_steps``
or ``steps`` to inspect estimators within the pipeline. Caching the
transformers is advantageous when fitting is time consuming.
verbose : bool, default=False
If True, the time elapsed while fitting each step will be printed as it
is completed.
Attributes
----------
named_steps : :class:`~sklearn.utils.Bunch`
Dictionary-like object, with the following attributes.
Read-only attribute to access any step parameter by user given name.
Keys are step names and values are steps parameters.
classes_ : ndarray of shape (n_classes,)
The classes labels. Only exist if the last step of the pipeline is a
classifier.
n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying first estimator in `steps` exposes such an attribute
when fit.
.. versionadded:: 0.24
feature_names_in_ : ndarray of shape (`n_features_in_`,)
Names of features seen during :term:`fit`. Only defined if the
underlying estimator exposes such an attribute when fit.
.. versionadded:: 1.0
See Also
--------
make_pipeline : Convenience function for simplified pipeline construction.
Examples
--------
>>> from sklearn.svm import SVC
>>> from sklearn.preprocessing import StandardScaler
>>> from sklearn.datasets import make_classification
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.pipeline import Pipeline
>>> X, y = make_classification(random_state=0)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
... random_state=0)
>>> pipe = Pipeline([('scaler', StandardScaler()), ('svc', SVC())])
>>> # The pipeline can be used as any other estimator
>>> # and avoids leaking the test set into the train set
>>> pipe.fit(X_train, y_train).score(X_test, y_test)
0.88
>>> # An estimator's parameter can be set using '__' syntax
>>> pipe.set_params(svc__C=10).fit(X_train, y_train).score(X_test, y_test)
0.76
"""
# BaseEstimator interface
# `steps` is the only `__init__` argument declared without a default value.
_required_parameters = ["steps"]
# Accepted types/values for each constructor parameter.
# NOTE(review): presumably consumed by the parameter validation that
# `_fit_context` runs around the fit methods -- confirm in `sklearn.base`.
_parameter_constraints: dict = {
    "steps": [list, Hidden(tuple)],
    "memory": [None, str, HasMethods(["cache"])],
    "verbose": ["boolean"],
}
def __init__(self, steps, *, memory=None, verbose=False):
self.steps = steps
self.memory = memory
self.verbose = verbose
def set_output(self, *, transform=None):
    """Configure the output container of every step of the pipeline.

    The setting is forwarded to each estimator in `steps` that is not
    `None` or `'passthrough'`.

    Parameters
    ----------
    transform : {"default", "pandas"}, default=None
        Configure output of `transform` and `fit_transform`.

        - `"default"`: Default output format of a transformer
        - `"pandas"`: DataFrame output
        - `None`: Transform configuration is unchanged

    Returns
    -------
    self : estimator instance
        Estimator instance.
    """
    for _idx, _name, estimator in self._iter():
        _safe_set_output(estimator, transform=transform)
    return self
def get_params(self, deep=True):
    """Return the parameters of this pipeline.

    The result covers the constructor parameters as well as the
    estimators held in the `steps` of the `Pipeline`.

    Parameters
    ----------
    deep : bool, default=True
        If True, will return the parameters for this estimator and
        contained subobjects that are estimators.

    Returns
    -------
    params : mapping of string to any
        Parameter names mapped to their values.
    """
    params = self._get_params("steps", deep=deep)
    return params
def set_params(self, **kwargs):
    """Update parameters of the pipeline or of its steps.

    Valid parameter keys can be listed with ``get_params()``. Parameters
    of the estimators contained in `steps` can be set directly.

    Parameters
    ----------
    **kwargs : dict
        Parameters of this estimator or parameters of estimators contained
        in `steps`. Parameters of the steps may be set using its name and
        the parameter name separated by a '__'.

    Returns
    -------
    self : object
        Pipeline class instance.
    """
    steps_attr = "steps"
    self._set_params(steps_attr, **kwargs)
    return self
def _validate_steps(self):
names, estimators = zip(*self.steps)
# validate names
self._validate_names(names)
# validate estimators
transformers = estimators[:-1]
estimator = estimators[-1]
for t in transformers:
if t is None or t == "passthrough":
continue
if not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not hasattr(
t, "transform"
):
raise TypeError(
"All intermediate steps should be "
"transformers and implement fit and transform "
"or be the string 'passthrough' "
"'%s' (type %s) doesn't" % (t, type(t))
)
# We allow last estimator to be None as an identity transformation
if (
estimator is not None
and estimator != "passthrough"
and not hasattr(estimator, "fit")
):
raise TypeError(
"Last step of Pipeline should implement fit "
"or be the string 'passthrough'. "
"'%s' (type %s) doesn't" % (estimator, type(estimator))
)
def _iter(self, with_final=True, filter_passthrough=True):
"""
Generate (idx, (name, trans)) tuples from self.steps
When filter_passthrough is True, 'passthrough' and None transformers
are filtered out.
"""
stop = len(self.steps)
if not with_final:
stop -= 1
for idx, (name, trans) in enumerate(islice(self.steps, 0, stop)):
if not filter_passthrough:
yield idx, name, trans
elif trans is not None and trans != "passthrough":
yield idx, name, trans
def __len__(self):
"""
Returns the length of the Pipeline
"""
return len(self.steps)
def __getitem__(self, ind):
"""Returns a sub-pipeline or a single estimator in the pipeline
Indexing with an integer will return an estimator; using a slice
returns another Pipeline instance which copies a slice of this
Pipeline. This copy is shallow: modifying (or fitting) estimators in
the sub-pipeline will affect the larger pipeline and vice-versa.
However, replacing a value in `step` will not affect a copy.
"""
if isinstance(ind, slice):
if ind.step not in (1, None):
raise ValueError("Pipeline slicing only supports a step of 1")
return self.__class__(
self.steps[ind], memory=self.memory, verbose=self.verbose
)
try:
name, est = self.steps[ind]
except TypeError:
# Not an int, try get step by name
return self.named_steps[ind]
return est
@property
def _estimator_type(self):
return self.steps[-1][1]._estimator_type
@property
def named_steps(self):
    """Access the steps by name.

    Read-only attribute mapping each step name to the object stored for
    that step."""
    # Use Bunch object to improve autocomplete
    steps_by_name = dict(self.steps)
    return Bunch(**steps_by_name)
@property
def _final_estimator(self):
try:
estimator = self.steps[-1][1]
return "passthrough" if estimator is None else estimator
except (ValueError, AttributeError, TypeError):
# This condition happens when a call to a method is first calling
# `_available_if` and `fit` did not validate `steps` yet. We
# return `None` and an `InvalidParameterError` will be raised
# right after.
return None
def _log_message(self, step_idx):
if not self.verbose:
return None
name, _ = self.steps[step_idx]
return "(step %d of %d) Processing %s" % (step_idx + 1, len(self.steps), name)
def _check_method_params(self, method, props, **kwargs):
    # Returns, for every step, the parameters to hand to each of its methods.
    #
    # With metadata routing enabled this is delegated to `process_routing`.
    # Otherwise `props` keys must follow the `stepname__parameter` format and
    # each value is registered for the `fit`, `fit_transform` and
    # `fit_predict` methods of the named step.
    if _routing_enabled():
        return process_routing(self, method=method, other_params=props, **kwargs)

    params_by_step = Bunch(
        **{
            step_name: Bunch(**{method_name: {} for method_name in METHODS})
            for step_name, estimator in self.steps
            if estimator is not None
        }
    )
    for key, value in props.items():
        if "__" not in key:
            raise ValueError(
                "Pipeline.fit does not accept the {} parameter. "
                "You can pass parameters to specific steps of your "
                "pipeline using the stepname__parameter format, e.g. "
                "`Pipeline.fit(X, y, logisticregression__sample_weight"
                "=sample_weight)`.".format(key)
            )
        step_name, param_name = key.split("__", 1)
        step_params = params_by_step[step_name]
        step_params["fit"][param_name] = value
        # without metadata routing, fit_transform and fit_predict
        # get all the same params and pass it to the last fit.
        step_params["fit_transform"][param_name] = value
        step_params["fit_predict"][param_name] = value
    return params_by_step
# Estimator interface
def _fit(self, X, y=None, routed_params=None):
    """Fit every step but the last one and return the transformed data.

    Each intermediate transformer is fitted (or loaded from the cache) via
    `_fit_transform_one`, its output becomes the input of the next step,
    and the fitted transformer replaces the original entry in `self.steps`.
    Steps set to `None` or `'passthrough'` are skipped.

    NOTE(review): `routed_params` defaults to `None` but is indexed by step
    name below, so callers have to pass a mapping whenever at least one
    intermediate step is a real transformer.
    """
    # shallow copy of steps - this should really be steps_
    self.steps = list(self.steps)
    self._validate_steps()
    # Setup the memory
    memory = check_memory(self.memory)

    # Wrap the fitting helper so that its results go through `memory`.
    fit_transform_one_cached = memory.cache(_fit_transform_one)

    for step_idx, name, transformer in self._iter(
        with_final=False, filter_passthrough=False
    ):
        if transformer is None or transformer == "passthrough":
            # Nothing to fit for this step; the context manager is still
            # entered and exited before moving on to the next step.
            with _print_elapsed_time("Pipeline", self._log_message(step_idx)):
                continue

        if hasattr(memory, "location") and memory.location is None:
            # we do not clone when caching is disabled to
            # preserve backward compatibility
            cloned_transformer = transformer
        else:
            cloned_transformer = clone(transformer)
        # Fit or load from cache the current transformer
        X, fitted_transformer = fit_transform_one_cached(
            cloned_transformer,
            X,
            y,
            None,
            message_clsname="Pipeline",
            message=self._log_message(step_idx),
            params=routed_params[name],
        )
        # Replace the transformer of the step with the fitted
        # transformer. This is necessary when loading the transformer
        # from the cache.
        self.steps[step_idx] = (name, fitted_transformer)
    return X
@_fit_context(
    # estimators in Pipeline.steps are not validated yet
    prefer_skip_nested_validation=False
)
def fit(self, X, y=None, **params):
    """Fit the model.

    The intermediate steps are fitted in order, each one transforming the
    data for the next, and the final estimator is then fitted on the
    transformed data (unless it is `'passthrough'`).

    Parameters
    ----------
    X : iterable
        Training data. Must fulfill input requirements of first step of the
        pipeline.

    y : iterable, default=None
        Training targets. Must fulfill label requirements for all steps of
        the pipeline.

    **params : dict of str -> object
        - If `enable_metadata_routing=False` (default):

            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        - If `enable_metadata_routing=True`:

            Parameters requested and accepted by steps. Each step must have
            requested certain metadata for these parameters to be forwarded to
            them.

        .. versionchanged:: 1.4
            Parameters are now passed to the ``transform`` method of the
            intermediate steps as well, if requested, and if
            `enable_metadata_routing=True` is set via
            :func:`~sklearn.set_config`.

        See :ref:`Metadata Routing User Guide <metadata_routing>` for more
        details.

    Returns
    -------
    self : object
        Pipeline with fitted steps.
    """
    routed_params = self._check_method_params(method="fit", props=params)
    Xt = self._fit(X, y, routed_params)

    final_idx = len(self.steps) - 1
    with _print_elapsed_time("Pipeline", self._log_message(final_idx)):
        final_estimator = self._final_estimator
        if final_estimator != "passthrough":
            final_name = self.steps[-1][0]
            final_estimator.fit(Xt, y, **routed_params[final_name]["fit"])

    return self
def _can_fit_transform(self):
return (
self._final_estimator == "passthrough"
or hasattr(self._final_estimator, "transform")
or hasattr(self._final_estimator, "fit_transform")
)
@available_if(_can_fit_transform)
@_fit_context(
    # estimators in Pipeline.steps are not validated yet
    prefer_skip_nested_validation=False
)
def fit_transform(self, X, y=None, **params):
    """Fit the model and transform with the final estimator.

    The intermediate steps are fitted in order, each one transforming the
    data for the next. The final estimator then produces the result through
    its `fit_transform`, or through `fit` followed by `transform` when it
    has no `fit_transform`. A `'passthrough'` final step returns the data
    as transformed by the intermediate steps.

    Parameters
    ----------
    X : iterable
        Training data. Must fulfill input requirements of first step of the
        pipeline.

    y : iterable, default=None
        Training targets. Must fulfill label requirements for all steps of
        the pipeline.

    **params : dict of str -> object
        - If `enable_metadata_routing=False` (default):

            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        - If `enable_metadata_routing=True`:

            Parameters requested and accepted by steps. Each step must have
            requested certain metadata for these parameters to be forwarded to
            them.

        .. versionchanged:: 1.4
            Parameters are now passed to the ``transform`` method of the
            intermediate steps as well, if requested, and if
            `enable_metadata_routing=True`.

        See :ref:`Metadata Routing User Guide <metadata_routing>` for more
        details.

    Returns
    -------
    Xt : ndarray of shape (n_samples, n_transformed_features)
        Transformed samples.
    """
    routed_params = self._check_method_params(method="fit_transform", props=params)
    Xt = self._fit(X, y, routed_params)

    final_estimator = self._final_estimator
    final_idx = len(self.steps) - 1
    with _print_elapsed_time("Pipeline", self._log_message(final_idx)):
        if final_estimator == "passthrough":
            return Xt

        final_params = routed_params[self.steps[-1][0]]
        if hasattr(final_estimator, "fit_transform"):
            return final_estimator.fit_transform(
                Xt, y, **final_params["fit_transform"]
            )

        # No `fit_transform` on the final step: fit first, then transform.
        fitted = final_estimator.fit(Xt, y, **final_params["fit"])
        return fitted.transform(Xt, **final_params["transform"])
@available_if(_final_estimator_has("predict"))
def predict(self, X, **params):
    """Transform the data, and apply `predict` with the final estimator.

    `transform` of every intermediate step is applied in order and the
    result is handed to `predict` of the final estimator. Only valid if
    the final estimator implements `predict`.

    Parameters
    ----------
    X : iterable
        Data to predict on. Must fulfill input requirements of first step
        of the pipeline.

    **params : dict of str -> object
        - If `enable_metadata_routing=False` (default):

            Parameters to the ``predict`` called at the end of all
            transformations in the pipeline.

        - If `enable_metadata_routing=True`:

            Parameters requested and accepted by steps. Each step must have
            requested certain metadata for these parameters to be forwarded to
            them.

        .. versionadded:: 0.20

        .. versionchanged:: 1.4
            Parameters are now passed to the ``transform`` method of the
            intermediate steps as well, if requested, and if
            `enable_metadata_routing=True` is set via
            :func:`~sklearn.set_config`.

        See :ref:`Metadata Routing User Guide <metadata_routing>` for more
        details.

        Note that while this may be used to return uncertainties from some
        models with ``return_std`` or ``return_cov``, uncertainties that are
        generated by the transformations in the pipeline are not propagated
        to the final estimator.

    Returns
    -------
    y_pred : ndarray
        Result of calling `predict` on the final estimator.
    """
    Xt = X

    if _routing_enabled():
        routed_params = process_routing(self, "predict", other_params=params)
        for _, name, transformer in self._iter(with_final=False):
            Xt = transformer.transform(Xt, **routed_params[name].transform)
        final_name = self.steps[-1][0]
        return self.steps[-1][1].predict(Xt, **routed_params[final_name].predict)

    # Metadata routing disabled: `params` only reach the final `predict`.
    for _, _, transformer in self._iter(with_final=False):
        Xt = transformer.transform(Xt)
    return self.steps[-1][1].predict(Xt, **params)
@available_if(_final_estimator_has("fit_predict"))
@_fit_context(
    # estimators in Pipeline.steps are not validated yet
    prefer_skip_nested_validation=False
)
def fit_predict(self, X, y=None, **params):
    """Transform the data, and apply `fit_predict` with the final estimator.

    The intermediate steps are fitted in order, each one transforming the
    data for the next, and the transformed data is handed to `fit_predict`
    of the final estimator. Only valid if the final estimator implements
    `fit_predict`.

    Parameters
    ----------
    X : iterable
        Training data. Must fulfill input requirements of first step of
        the pipeline.

    y : iterable, default=None
        Training targets. Must fulfill label requirements for all steps
        of the pipeline.

    **params : dict of str -> object
        - If `enable_metadata_routing=False` (default):

            Parameters for individual steps, where each parameter name is
            prefixed such that parameter ``p`` for step ``s`` has key
            ``s__p``. Those addressed to the last step are passed to its
            ``fit_predict``.

        - If `enable_metadata_routing=True`:

            Parameters requested and accepted by steps. Each step must have
            requested certain metadata for these parameters to be forwarded to
            them.

        .. versionadded:: 0.20

        .. versionchanged:: 1.4
            Parameters are now passed to the ``transform`` method of the
            intermediate steps as well, if requested, and if
            `enable_metadata_routing=True`.

        See :ref:`Metadata Routing User Guide <metadata_routing>` for more
        details.

        Note that while this may be used to return uncertainties from some
        models with ``return_std`` or ``return_cov``, uncertainties that are
        generated by the transformations in the pipeline are not propagated
        to the final estimator.

    Returns
    -------
    y_pred : ndarray
        Result of calling `fit_predict` on the final estimator.
    """
    routed_params = self._check_method_params(method="fit_predict", props=params)
    Xt = self._fit(X, y, routed_params)

    final_params = routed_params[self.steps[-1][0]]
    final_idx = len(self.steps) - 1
    with _print_elapsed_time("Pipeline", self._log_message(final_idx)):
        final_estimator = self.steps[-1][1]
        fit_predict_params = final_params.get("fit_predict", {})
        y_pred = final_estimator.fit_predict(Xt, y, **fit_predict_params)
    return y_pred
@available_if(_final_estimator_has("predict_proba"))
def predict_proba(self, X, **params):
    """Transform the data, and apply `predict_proba` with the final estimator.

    `transform` of every intermediate step is applied in order and the
    result is handed to `predict_proba` of the final estimator. Only valid
    if the final estimator implements `predict_proba`.

    Parameters
    ----------
    X : iterable
        Data to predict on. Must fulfill input requirements of first step
        of the pipeline.

    **params : dict of str -> object
        - If `enable_metadata_routing=False` (default):

            Parameters to the `predict_proba` called at the end of all
            transformations in the pipeline.

        - If `enable_metadata_routing=True`:

            Parameters requested and accepted by steps. Each step must have
            requested certain metadata for these parameters to be forwarded to
            them.

        .. versionadded:: 0.20

        .. versionchanged:: 1.4
            Parameters are now passed to the ``transform`` method of the
            intermediate steps as well, if requested, and if
            `enable_metadata_routing=True`.

        See :ref:`Metadata Routing User Guide <metadata_routing>` for more
        details.

    Returns
    -------
    y_proba : ndarray of shape (n_samples, n_classes)
        Result of calling `predict_proba` on the final estimator.
    """
    Xt = X

    if _routing_enabled():
        routed_params = process_routing(self, "predict_proba", other_params=params)
        for _, name, transformer in self._iter(with_final=False):
            Xt = transformer.transform(Xt, **routed_params[name].transform)
        final_name = self.steps[-1][0]
        return self.steps[-1][1].predict_proba(
            Xt, **routed_params[final_name].predict_proba
        )

    # Metadata routing disabled: `params` only reach the final call.
    for _, _, transformer in self._iter(with_final=False):
        Xt = transformer.transform(Xt)
    return self.steps[-1][1].predict_proba(Xt, **params)
@available_if(_final_estimator_has("decision_function"))
def decision_function(self, X, **params):
    """Transform the data, and apply `decision_function` with the final estimator.

    `transform` of every intermediate step is applied in order and the
    result is handed to `decision_function` of the final estimator. Only
    valid if the final estimator implements `decision_function`.

    Parameters
    ----------
    X : iterable
        Data to predict on. Must fulfill input requirements of first step
        of the pipeline.

    **params : dict of string -> object
        Parameters requested and accepted by steps. Each step must have
        requested certain metadata for these parameters to be forwarded to
        them.

        .. versionadded:: 1.4
            Only available if `enable_metadata_routing=True`. See
            :ref:`Metadata Routing User Guide <metadata_routing>` for more
            details.

    Returns
    -------
    y_score : ndarray of shape (n_samples, n_classes)
        Result of calling `decision_function` on the final estimator.
    """
    _raise_for_params(params, self, "decision_function")

    # not branching here since params is only available if
    # enable_metadata_routing=True
    routed_params = process_routing(self, "decision_function", other_params=params)

    Xt = X
    for _, name, transformer in self._iter(with_final=False):
        transform_params = routed_params.get(name, {}).get("transform", {})
        Xt = transformer.transform(Xt, **transform_params)

    final_name = self.steps[-1][0]
    final_params = routed_params.get(final_name, {}).get("decision_function", {})
    return self.steps[-1][1].decision_function(Xt, **final_params)
@available_if(_final_estimator_has("score_samples"))
def score_samples(self, X):
    """Transform the data, and apply `score_samples` with the final estimator.

    `transform` of every intermediate step is applied in order and the
    result is handed to `score_samples` of the final estimator. Only valid
    if the final estimator implements `score_samples`.

    Parameters
    ----------
    X : iterable
        Data to predict on. Must fulfill input requirements of first step
        of the pipeline.

    Returns
    -------
    y_score : ndarray of shape (n_samples,)
        Result of calling `score_samples` on the final estimator.
    """
    transformed = X
    for _idx, _name, step in self._iter(with_final=False):
        transformed = step.transform(transformed)
    final_estimator = self.steps[-1][1]
    return final_estimator.score_samples(transformed)
@available_if(_final_estimator_has("predict_log_proba"))
def predict_log_proba(self, X, **params):
    """Transform the data, and apply `predict_log_proba` with the final estimator.

    `transform` of every intermediate step is applied in order and the
    result is handed to `predict_log_proba` of the final estimator. Only
    valid if the final estimator implements `predict_log_proba`.

    Parameters
    ----------
    X : iterable
        Data to predict on. Must fulfill input requirements of first step
        of the pipeline.

    **params : dict of str -> object
        - If `enable_metadata_routing=False` (default):

            Parameters to the `predict_log_proba` called at the end of all
            transformations in the pipeline.

        - If `enable_metadata_routing=True`:

            Parameters requested and accepted by steps. Each step must have
            requested certain metadata for these parameters to be forwarded to
            them.

        .. versionadded:: 0.20

        .. versionchanged:: 1.4
            Parameters are now passed to the ``transform`` method of the
            intermediate steps as well, if requested, and if
            `enable_metadata_routing=True`.

        See :ref:`Metadata Routing User Guide <metadata_routing>` for more
        details.

    Returns
    -------
    y_log_proba : ndarray of shape (n_samples, n_classes)
        Result of calling `predict_log_proba` on the final estimator.
    """
    Xt = X

    if _routing_enabled():
        routed_params = process_routing(
            self, "predict_log_proba", other_params=params
        )
        for _, name, transformer in self._iter(with_final=False):
            Xt = transformer.transform(Xt, **routed_params[name].transform)
        final_name = self.steps[-1][0]
        return self.steps[-1][1].predict_log_proba(
            Xt, **routed_params[final_name].predict_log_proba
        )

    # Metadata routing disabled: `params` only reach the final call.
    for _, _, transformer in self._iter(with_final=False):
        Xt = transformer.transform(Xt)
    return self.steps[-1][1].predict_log_proba(Xt, **params)
def _can_transform(self):
return self._final_estimator == "passthrough" or hasattr(
self._final_estimator, "transform"
)
@available_if(_can_transform)
def transform(self, X, **params):
    """Transform the data, and apply `transform` with the final estimator.

    `transform` of every step, the final one included, is applied in
    order. Only valid if the final estimator implements `transform`.

    This also works where final estimator is `None` in which case all prior
    transformations are applied.

    Parameters
    ----------
    X : iterable
        Data to transform. Must fulfill input requirements of first step
        of the pipeline.

    **params : dict of str -> object
        Parameters requested and accepted by steps. Each step must have
        requested certain metadata for these parameters to be forwarded to
        them.

        .. versionadded:: 1.4
            Only available if `enable_metadata_routing=True`. See
            :ref:`Metadata Routing User Guide <metadata_routing>` for more
            details.

    Returns
    -------
    Xt : ndarray of shape (n_samples, n_transformed_features)
        Transformed data.
    """
    _raise_for_params(params, self, "transform")

    # not branching here since params is only available if
    # enable_metadata_routing=True
    routed_params = process_routing(self, "transform", other_params=params)

    transformed = X
    for _, name, step in self._iter():
        step_params = routed_params[name].transform
        transformed = step.transform(transformed, **step_params)
    return transformed
def _can_inverse_transform(self):
return all(hasattr(t, "inverse_transform") for _, _, t in self._iter())
@available_if(_can_inverse_transform)
def inverse_transform(self, Xt, **params):
    """Apply `inverse_transform` for each step in a reverse order.

    All estimators in the pipeline must support `inverse_transform`.

    Parameters
    ----------
    Xt : array-like of shape (n_samples, n_transformed_features)
        Data samples, where ``n_samples`` is the number of samples and
        ``n_features`` is the number of features. Must fulfill
        input requirements of last step of pipeline's
        ``inverse_transform`` method.

    **params : dict of str -> object
        Parameters requested and accepted by steps. Each step must have
        requested certain metadata for these parameters to be forwarded to
        them.

        .. versionadded:: 1.4
            Only available if `enable_metadata_routing=True`. See
            :ref:`Metadata Routing User Guide <metadata_routing>` for more
            details.

    Returns
    -------
    Xt : ndarray of shape (n_samples, n_features)
        Inverse transformed data, that is, data in the original feature
        space.
    """
    _raise_for_params(params, self, "inverse_transform")

    # we don't have to branch here, since params is only non-empty if
    # enable_metadata_routing=True.
    routed_params = process_routing(self, "inverse_transform", other_params=params)

    # Walk the steps from last to first.
    steps_last_to_first = list(self._iter())[::-1]
    for _, name, step in steps_last_to_first:
        step_params = routed_params[name].inverse_transform
        Xt = step.inverse_transform(Xt, **step_params)
    return Xt
@available_if(_final_estimator_has("score"))
def score(self, X, y=None, sample_weight=None, **params):
    """Transform the data, and apply `score` with the final estimator.

    `transform` of every intermediate step is applied in order and the
    result is handed to `score` of the final estimator. Only valid if the
    final estimator implements `score`.

    Parameters
    ----------
    X : iterable
        Data to predict on. Must fulfill input requirements of first step
        of the pipeline.

    y : iterable, default=None
        Targets used for scoring. Must fulfill label requirements for all
        steps of the pipeline.

    sample_weight : array-like, default=None
        If not None, this argument is passed as ``sample_weight`` keyword
        argument to the ``score`` method of the final estimator.

    **params : dict of str -> object
        Parameters requested and accepted by steps. Each step must have
        requested certain metadata for these parameters to be forwarded to
        them. They are not used when metadata routing is disabled.

        .. versionadded:: 1.4
            Only available if `enable_metadata_routing=True`. See
            :ref:`Metadata Routing User Guide <metadata_routing>` for more
            details.

    Returns
    -------
    score : float
        Result of calling `score` on the final estimator.
    """
    Xt = X

    if _routing_enabled():
        routed_params = process_routing(
            self, "score", sample_weight=sample_weight, other_params=params
        )
        for _, name, transformer in self._iter(with_final=False):
            Xt = transformer.transform(Xt, **routed_params[name].transform)
        final_name = self.steps[-1][0]
        return self.steps[-1][1].score(Xt, y, **routed_params[final_name].score)

    # Metadata routing disabled: only `sample_weight` (when given) is
    # forwarded, and only to the final `score`.
    for _, _, transformer in self._iter(with_final=False):
        Xt = transformer.transform(Xt)
    score_params = {} if sample_weight is None else {"sample_weight": sample_weight}
    return self.steps[-1][1].score(Xt, y, **score_params)
@property
def classes_(self):
    """The classes labels. Only exist if the last step is a classifier."""
    final_estimator = self.steps[-1][1]
    return final_estimator.classes_
def _more_tags(self):
tags = {
"_xfail_checks": {
"check_dont_overwrite_parameters": (