"""Class to perform over-sampling using SMOTE and cleaning using ENN."""
# Authors: Guillaume Lemaitre <[email protected]>
# Christos Aridas
# License: MIT
import numbers
from sklearn.base import clone
from sklearn.utils import check_X_y
from ..base import BaseSampler
from ..over_sampling import SMOTE
from ..over_sampling.base import BaseOverSampler
from ..under_sampling import EditedNearestNeighbours
from ..utils import Substitution, check_target_type
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring


@Substitution(
    sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
    n_jobs=_n_jobs_docstring,
    random_state=_random_state_docstring,
)
class SMOTEENN(BaseSampler):
    """Over-sampling using SMOTE and cleaning using ENN.

    Combine over- and under-sampling using SMOTE and Edited Nearest Neighbours.

    Read more in the :ref:`User Guide <combine>`.

    Parameters
    ----------
    {sampling_strategy}

    {random_state}

    smote : sampler object, default=None
        The :class:`~imblearn.over_sampling.SMOTE` object to use. If not given,
        a :class:`~imblearn.over_sampling.SMOTE` object with default parameters
        will be used.

    enn : sampler object, default=None
        The :class:`~imblearn.under_sampling.EditedNearestNeighbours` object
        to use. If not given, an
        :class:`~imblearn.under_sampling.EditedNearestNeighbours` object with
        ``sampling_strategy='all'`` will be used.

    {n_jobs}

    Attributes
    ----------
    sampling_strategy_ : dict
        Dictionary containing the information to sample the dataset. The keys
        correspond to the class labels from which to sample and the values
        are the number of samples to sample.

    smote_ : sampler object
        The validated :class:`~imblearn.over_sampling.SMOTE` instance.

    enn_ : sampler object
        The validated :class:`~imblearn.under_sampling.EditedNearestNeighbours`
        instance.

    n_features_in_ : int
        Number of features in the input dataset.

        .. versionadded:: 0.9

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during `fit`. Defined only when `X` has feature
        names that are all strings.

        .. versionadded:: 0.10

    See Also
    --------
    SMOTETomek : Over-sample using SMOTE followed by under-sampling removing
        the Tomek's links.

    Notes
    -----
    The method is presented in [1]_.

    Supports multi-class resampling. Refer to SMOTE and ENN regarding the
    scheme which is used.

    References
    ----------
    .. [1] G. Batista, R. C. Prati, M. C. Monard. "A study of the behavior of
       several methods for balancing machine learning training data," ACM
       Sigkdd Explorations Newsletter 6 (1), 20-29, 2004.

    Examples
    --------
    >>> from collections import Counter
    >>> from sklearn.datasets import make_classification
    >>> from imblearn.combine import SMOTEENN
    >>> X, y = make_classification(n_classes=2, class_sep=2,
    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
    >>> print('Original dataset shape %s' % Counter(y))
    Original dataset shape Counter({{1: 900, 0: 100}})
    >>> sme = SMOTEENN(random_state=42)
    >>> X_res, y_res = sme.fit_resample(X, y)
    >>> print('Resampled dataset shape %s' % Counter(y_res))
    Resampled dataset shape Counter({{0: 900, 1: 881}})
    """

    _sampling_type = "over-sampling"

    _parameter_constraints: dict = {
        **BaseOverSampler._parameter_constraints,
        "smote": [SMOTE, None],
        "enn": [EditedNearestNeighbours, None],
        "n_jobs": [numbers.Integral, None],
    }

    def __init__(
        self,
        *,
        sampling_strategy="auto",
        random_state=None,
        smote=None,
        enn=None,
        n_jobs=None,
    ):
        super().__init__()
        self.sampling_strategy = sampling_strategy
        self.random_state = random_state
        self.smote = smote
        self.enn = enn
        self.n_jobs = n_jobs

    def _validate_estimator(self):
        """Private function to validate SMOTE and ENN objects."""
        if self.smote is not None:
            self.smote_ = clone(self.smote)
        else:
            # No SMOTE given: build one that follows this sampler's
            # sampling strategy and random state.
            self.smote_ = SMOTE(
                sampling_strategy=self.sampling_strategy,
                random_state=self.random_state,
            )

        if self.enn is not None:
            self.enn_ = clone(self.enn)
        else:
            # No ENN given: clean samples from all classes after over-sampling.
            self.enn_ = EditedNearestNeighbours(
                sampling_strategy="all", n_jobs=self.n_jobs
            )

    def _fit_resample(self, X, y):
        self._validate_estimator()
        y = check_target_type(y)
        X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
        self.sampling_strategy_ = self.sampling_strategy

        # Over-sample with SMOTE first, then clean the result with ENN.
        X_res, y_res = self.smote_.fit_resample(X, y)
        return self.enn_.fit_resample(X_res, y_res)
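

if __name__ == "__main__":
    # Illustrative sketch only, not part of the library API: it mirrors what
    # ``_validate_estimator`` builds by default (SMOTE over-sampling followed
    # by ENN cleaning on all classes), but with the two components passed in
    # explicitly. Assumes scikit-learn is installed; it can be exercised with
    # ``python -m imblearn.combine._smote_enn``.
    from collections import Counter

    from sklearn.datasets import make_classification

    X, y = make_classification(
        n_classes=2,
        weights=[0.1, 0.9],
        n_informative=3,
        n_redundant=1,
        n_features=20,
        n_clusters_per_class=1,
        n_samples=1000,
        random_state=10,
    )

    # Explicit components equivalent to the defaults of SMOTEENN(random_state=42).
    sampler = SMOTEENN(
        smote=SMOTE(sampling_strategy="auto", random_state=42),
        enn=EditedNearestNeighbours(sampling_strategy="all"),
        random_state=42,
    )
    X_res, y_res = sampler.fit_resample(X, y)
    print("Original dataset shape %s" % Counter(y))
    print("Resampled dataset shape %s" % Counter(y_res))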