-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathplot_shrinkage_effect.py
124 lines (104 loc) · 3.86 KB
/
plot_shrinkage_effect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
======================================================
Effect of the shrinkage factor in random over-sampling
======================================================
This example shows the effect of the shrinkage factor used to generate the
smoothed bootstrap using the
:class:`~imblearn.over_sampling.RandomOverSampler`.
"""
# Authors: Guillaume Lemaitre <[email protected]>
# License: MIT
# %%
print(__doc__)

import seaborn as sns

sns.set_context("poster")

# %%
# First, we will generate a toy classification dataset with only a few
# samples. The ratio between the classes will be imbalanced.
from collections import Counter

from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=100,
    n_features=2,
    n_redundant=0,
    weights=[0.1, 0.9],
    random_state=0,
)
Counter(y)

# %%
# Let's have a look at the dataset. We will draw the same kind of figure after
# each resampling, so we define a small plotting helper once and reuse it.
import matplotlib.pyplot as plt


def plot_samples(X, y):
    """Draw a scatter plot of a two-feature dataset coloured by class.

    A new 7x7 inch figure is created on every call. Markers are drawn
    semi-transparent (``alpha=0.4``), so samples stacked on top of each other
    appear more opaque than isolated ones.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data to plot. Only the first two columns are used, as the x and y
        coordinates respectively.
    y : array-like of shape (n_samples,)
        Class label of each sample, used to colour the markers and to build
        the legend entries.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    scatter = ax.scatter(X[:, 0], X[:, 1], c=y, alpha=0.4)
    class_legend = ax.legend(
        *scatter.legend_elements(), loc="lower left", title="Classes"
    )
    ax.add_artist(class_legend)
    ax.set_xlabel("Feature #1")
    ax.set_ylabel("Feature #2")
    fig.tight_layout()


plot_samples(X, y)

# %%
# Now, we will use a :class:`~imblearn.over_sampling.RandomOverSampler` to
# generate a bootstrap for the minority class with as many samples as in the
# majority class.
from imblearn.over_sampling import RandomOverSampler

sampler = RandomOverSampler(random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
Counter(y_res)

# %%
plot_samples(X_res, y_res)

# %%
# We observe that the minority samples are less transparent than the samples
# from the majority class. This is because the samples of the minority class
# are repeated during the bootstrap generation.
#
# We can set `shrinkage` to a floating value to add a small perturbation to the
# samples created and therefore create a smoothed bootstrap.
sampler = RandomOverSampler(shrinkage=1, random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
Counter(y_res)

# %%
plot_samples(X_res, y_res)

# %%
# In this case, we see that the samples in the minority class are not
# overlapping anymore due to the added noise.
#
# The `shrinkage` parameter controls the amount of perturbation. Let's add
# more perturbation when generating the smoothed bootstrap.
sampler = RandomOverSampler(shrinkage=3, random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
Counter(y_res)

# %%
plot_samples(X_res, y_res)

# %%
# Increasing the value of `shrinkage` will disperse the new samples. Forcing
# the shrinkage to 0 will be equivalent to generating a normal bootstrap.
sampler = RandomOverSampler(shrinkage=0, random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
Counter(y_res)

# %%
plot_samples(X_res, y_res)

# %%
# Therefore, the `shrinkage` is handy to manually tune the dispersion of the
# new samples.