Robust IRLS for regression #130

Open
wants to merge 8 commits into base: main
124 changes: 79 additions & 45 deletions sklearn_extra/robust/robust_weighted_estimator.py
@@ -21,7 +21,7 @@
check_consistent_length,
)
from sklearn.utils.validation import check_is_fitted
from sklearn.linear_model import SGDRegressor, SGDClassifier
from sklearn.linear_model import SGDRegressor, SGDClassifier, LinearRegression
from sklearn.multiclass import OneVsRestClassifier, OneVsOneClassifier
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics.pairwise import euclidean_distances
@@ -120,6 +120,9 @@ class _RobustWeightedEstimator(BaseEstimator):
If callable, the function is used as loss function to construct
the weights.

solver : {"IRLS", "SGD"}, default="SGD"
Algorithm used for the optimization. For now, "IRLS" is only available for regression.

weighting : string, default="huber"
Weighting scheme used to make the estimator robust.
Can be 'huber' for huber-type weights or 'mom' for median-of-means
@@ -146,15 +149,14 @@ class _RobustWeightedEstimator(BaseEstimator):
If None, c is estimated at each step using half the Inter-quartile
range, this tends to be conservative (robust).

k : int < sample_size/2, default=1
k : int < sample_size/2, default=None
Parameter used for mom weighting procedure, used only if weightings
is 'mom'. 2k+1 is the number of blocks used for median-of-means
estimation, higher value of k means a more robust estimator.
Can have a big effect on efficiency.
If None, k is estimated using the number of points distant from the
median of means by more than 2 times a robust estimate of the scale
(using the inter-quartile range), this tends to be conservative
(robust).
(using the inter-quartile range), this can be unstable.

tol : float or None, (default = 1e-3)
The stopping criterion. If it is not None, training will stop when
@@ -219,24 +221,27 @@ def __init__(
self,
base_estimator,
loss,
solver="SGD",
weighting="huber",
max_iter=100,
burn_in=10,
eta0=0.1,
c=None,
k=0,
k=None,
tol=1e-5,
n_iter_no_change=10,
verbose=0,
random_state=None,
):
self.base_estimator = base_estimator
self.weighting = weighting
self.solver = solver
self.eta0 = eta0
self.burn_in = burn_in
self.c = c
self.k = k
self.loss = loss
self.max_iter = max_iter
self.tol = tol
self.n_iter_no_change = n_iter_no_change
@@ -287,9 +292,9 @@ def fit(self, X, y=None):

if "n_iter_no_change" in parameters:
base_estimator.set_params(n_iter_no_change=self.n_iter_no_change)

base_estimator.set_params(random_state=random_state)
if self.burn_in > 0:
if "random_state" in parameters:
base_estimator.set_params(random_state=random_state)
if (self.burn_in > 0) and self.solver != "IRLS":
learning_rate = base_estimator.learning_rate
base_estimator.set_params(learning_rate="constant", eta0=self.eta0)

@@ -311,8 +316,11 @@
# Initialization of the estimator
# Partial fit for the estimator to be set to "fitted" to be able
# to predict.
base_estimator.partial_fit(X, y)
# As the partial fit is here non-robust, override the
if self.solver == "SGD":
base_estimator.partial_fit(X, y)
else:
base_estimator.fit(X, y)
Contributor

Why this difference? If it's OK to call fit, then can't we call fit for all estimators? For SGD it's mostly equivalent.

Contributor Author

partial_fit only does one iteration, while fit does not; I could use fit with max_iter=1 instead, would that be better?
It is important to use only one iteration because there may be outliers in the data: training on the whole dataset would mean that many steps are non-robust, and SGD with a decreasing step size may never recover from them. This is different for IRLS.
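
(For context, a minimal sketch of the two options discussed here, using only the public SGDRegressor API; the data and parameter values are made up for illustration.)

```python
import numpy as np
from sklearn.linear_model import SGDRegressor

rng = np.random.RandomState(0)
X = rng.uniform(-1, 1, size=(100, 1))
y = X.ravel() + 0.1 * rng.normal(size=100)

# Current approach: a single non-robust epoch, only to mark the estimator
# as "fitted" (its coefficients are overridden right after this call).
est = SGDRegressor(random_state=0)
est.partial_fit(X, y)

# Alternative raised above: a capped fit, also roughly one epoch,
# but it reinitializes the estimator's state on every call.
est_alt = SGDRegressor(max_iter=1, tol=None, random_state=0)
est_alt.fit(X, y)
```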

# As the fit is here non-robust, override the
# learned coefs.
base_estimator.coef_ = np.zeros([len(X[0])])
base_estimator.intercept_ = np.array([0])
@@ -329,7 +337,11 @@ def fit(self, X, y=None):
# Optimization algorithm
for epoch in range(self.max_iter):

if epoch > self.burn_in and self.burn_in > 0:
if (
(epoch > self.burn_in)
and (self.burn_in > 0)
and (self.solver == "SGD")
):
# If not in the burn_in phase anymore, change the learning_rate
# calibration to the one dictated by self.base_estimator.
base_estimator.set_params(learning_rate=learning_rate)
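
(Aside: a small sketch of the burn-in switch above, assuming the standard SGDRegressor parameters; the variable names are illustrative, not part of this PR.)

```python
from sklearn.linear_model import SGDRegressor

base = SGDRegressor(learning_rate="invscaling", eta0=0.01)
original_schedule = base.learning_rate           # remembered before burn-in

# Burn-in phase: pin the step size to a constant eta0.
base.set_params(learning_rate="constant", eta0=0.1)
# ... run the first `burn_in` epochs with partial_fit ...

# After burn-in (SGD solver only): restore the estimator's own schedule.
base.set_params(learning_rate=original_schedule)
```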
@@ -361,8 +373,6 @@ def fit(self, X, y=None):
# epoch using the previously computed weights. Also shuffle the data.
perm = random_state.permutation(len(X))

base_estimator.partial_fit(X, y, sample_weight=weights)

if (self.tol is not None) and (
current_loss > best_loss - self.tol
):
@@ -383,9 +393,14 @@ def fit(self, X, y=None):
X[perm], y, sample_weight=weights[perm]
)
else:
base_estimator.partial_fit(
X[perm], y[perm], sample_weight=weights[perm]
)
if self.solver == "SGD":
base_estimator.partial_fit(
X[perm], y[perm], sample_weight=weights[perm]
)
else:
# Do one IRLS step.
base_estimator.fit(X, y, sample_weight=weights)

if (self.tol is not None) and (
current_loss > best_loss - self.tol
):
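
(To make the IRLS branch concrete, here is a toy sketch of one reweighted least-squares fit per epoch with Huber-type weights; it mirrors the idea, not this estimator's exact weight computation, and `irls_sketch`, `n_steps`, and `c` are illustrative names.)

```python
import numpy as np
from sklearn.linear_model import LinearRegression

def irls_sketch(X, y, n_steps=10, c=1.35):
    """Toy IRLS loop: refit a weighted least-squares model at each step."""
    est = LinearRegression()
    weights = np.ones(len(y))
    for _ in range(n_steps):
        est.fit(X, y, sample_weight=weights)        # one weighted LS fit
        residuals = np.abs(y - est.predict(X))
        scale = np.median(residuals) + 1e-12        # crude robust scale
        z = residuals / scale
        weights = np.where(z <= c, 1.0, c / z)      # Huber-type downweighting
    return est
```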
@@ -492,10 +507,11 @@ def psisx(x):
elif self.weighting == "mom":
if self.k is None:
med = np.median(loss_values)
# scale estimator using iqr, rescaled by what would be if the
# loss was Gaussian.
scale = iqr(np.abs(loss_values - med)) / 1.37
# scale estimator using iqr
scale = iqr(np.abs(loss_values - med))
k = np.sum(np.abs(loss_values - med) > 2 * scale)
# For safety, cap the estimated k at 3.
k = min(k, 3)
else:
k = self.k
# Choose (randomly) 2k+1 (almost-)equal blocks of data.
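
(A small numeric illustration of the k heuristic above; it reimplements the logic of this hunk on synthetic loss values rather than calling the estimator.)

```python
import numpy as np
from scipy.stats import iqr

rng = np.random.RandomState(0)
loss_values = np.concatenate([rng.normal(1.0, 0.1, 97), [10.0, 12.0, 15.0]])

med = np.median(loss_values)
scale = iqr(np.abs(loss_values - med))              # robust scale of the losses
k = np.sum(np.abs(loss_values - med) > 2 * scale)   # points far from the median
k = min(k, 3)                                       # capped, as in the diff above
print("k =", k, "-> 2k+1 =", 2 * k + 1, "blocks for median-of-means")
```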
@@ -638,15 +654,14 @@ class RobustWeightedClassifier(BaseEstimator, ClassifierMixin):
If None, c is estimated at each step using half the Inter-quartile
range, this tends to be conservative (robust).

k : int < sample_size/2, default=1
k : int < sample_size/2, default=None
Parameter used for mom weighting procedure, used only if weightings
is 'mom'. 2k+1 is the number of blocks used for median-of-means
estimation, higher value of k means a more robust estimator.
Can have a big effect on efficiency.
If None, k is estimated using the number of points distant from the
median of means by more than 2 times a robust estimate of the scale
(using the inter-quartile range), this tends to be conservative
(robust).
(using the inter-quartile range), this can be unstable.

loss : string, None or callable, default="log"
Classification losses supported : 'log', 'hinge', 'modified_huber'.
@@ -751,7 +766,7 @@ def __init__(
burn_in=10,
eta0=0.01,
c=None,
k=0,
k=None,
loss="log",
sgd_args=None,
multi_class="ovr",
@@ -804,6 +819,7 @@ def fit(self, X, y):
base_robust_estimator_ = _RobustWeightedEstimator(
SGDClassifier(**sgd_args, eta0=self.eta0),
weighting=self.weighting,
solver="SGD",
loss=self.loss,
burn_in=self.burn_in,
c=self.c,
@@ -949,6 +965,11 @@ class RobustWeightedRegressor(BaseEstimator, RegressorMixin):
Can be 'huber' for huber-type weights or 'mom' for median-of-means
type weights.

solver : {"SGD", "IRLS"}
Algorithm used for optimization. If "SGD" then, use SGDRegressor as
base estimator and reweight at each optimization step. If "IRLS" then
use multiple fit of reweighted LinearRegression with robust weights.

max_iter : int, default=100
Maximum number of iterations.
For more information, see the optimization scheme of base_estimator
@@ -970,7 +991,7 @@ class RobustWeightedRegressor(BaseEstimator, RegressorMixin):
If None, c is estimated at each step using half the Inter-quartile
range, this tends to be conservative (robust).

k : int < sample_size/2, default=1
k : int < sample_size/2 or None, default=None
Parameter used for mom weighting procedure, used only if weightings
is 'mom'. 2k+1 is the number of blocks used for median-of-means
estimation, higher value of k means a more robust estimator.
@@ -1061,11 +1082,12 @@ class RobustWeightedRegressor(BaseEstimator, RegressorMixin):
def __init__(
self,
weighting="huber",
solver="SGD",
max_iter=100,
burn_in=10,
eta0=0.01,
c=None,
k=0,
k=None,
loss=SQ_LOSS,
sgd_args=None,
tol=1e-3,
Expand All @@ -1075,6 +1097,7 @@ def __init__(
):

self.weighting = weighting
self.solver = solver
self.max_iter = max_iter
self.burn_in = burn_in
self.eta0 = eta0
@@ -1111,21 +1134,33 @@ def fit(self, X, y):
# Define the base estimator

X, y = self._validate_data(X, y, y_numeric=True)

self.base_estimator_ = _RobustWeightedEstimator(
SGDRegressor(**sgd_args, eta0=self.eta0),
weighting=self.weighting,
loss=self.loss,
burn_in=self.burn_in,
c=self.c,
k=self.k,
eta0=self.eta0,
max_iter=self.max_iter,
tol=self.tol,
n_iter_no_change=self.n_iter_no_change,
verbose=self.verbose,
random_state=self.random_state,
)
kwargs = {
"weighting": self.weighting,
"loss": self.loss,
"burn_in": self.burn_in,
"c": self.c,
"k": self.k,
"eta0": self.eta0,
"max_iter": self.max_iter,
"tol": self.tol,
"n_iter_no_change": self.n_iter_no_change,
"verbose": self.verbose,
"random_state": self.random_state,
}
if self.solver == "SGD":
self.base_estimator_ = _RobustWeightedEstimator(
SGDRegressor(**sgd_args, eta0=self.eta0),
solver="SGD",
**kwargs,
)
elif self.solver == "IRLS":
self.base_estimator_ = _RobustWeightedEstimator(
LinearRegression(),
solver="IRLS",
**kwargs,
)
else:
raise ValueError("No such solver.")
self.base_estimator_.fit(X, y)

self.weights_ = self.base_estimator_.weights_
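
(For reference, a short usage sketch of the new option; parameter values are arbitrary, and the import assumes the existing sklearn_extra.robust module.)

```python
import numpy as np
from sklearn_extra.robust import RobustWeightedRegressor

rng = np.random.RandomState(42)
X = rng.uniform(-1, 1, size=(300, 1))
y = X.ravel() + 0.1 * rng.normal(size=300)
X[0], y[0] = 10, -1                 # one corrupted sample, as in the tests

# IRLS solver: repeated robustly-reweighted LinearRegression fits.
reg = RobustWeightedRegressor(solver="IRLS", weighting="mom", random_state=0)
reg.fit(X, y)
print(reg.coef_, reg.intercept_)    # should be close to slope 1, intercept 0
```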
@@ -1217,15 +1252,14 @@ class RobustWeightedKMeans(BaseEstimator, ClusterMixin):
If None, c is estimated at each step using half the Inter-quartile
range, this tends to be conservative (robust).

k : int < sample_size/2, default=1
k : int < sample_size/2 or None, default=None
Parameter used for mom weighting procedure, used only if weightings
is 'mom'. 2k+1 is the number of blocks used for median-of-means
estimation, higher value of k means a more robust estimator.
Can have a big effect on efficiency.
If None, k is estimated using the number of points distant from the
median of means by more than 2 times a robust estimate of the scale
(using the inter-quartile range), this tends to be conservative
(robust).
(using the inter-quartile range), this can be unstable.

kmeans_args : dict, default={}
arguments of the MiniBatchKMeans base estimator. Must not contain
@@ -1316,7 +1350,7 @@ def __init__(
max_iter=100,
eta0=0.01,
c=None,
k=0,
k=None,
kmeans_args=None,
tol=1e-3,
n_iter_no_change=10,
@@ -1367,7 +1401,7 @@ def fit(self, X, y=None):
self.n_clusters,
batch_size=X.shape[0],
random_state=self.random_state,
**kmeans_args
**kmeans_args,
),
burn_in=0, # Important because it does not mean anything to
# have burn-in
25 changes: 18 additions & 7 deletions sklearn_extra/robust/tests/test_robust_weighted_estimator.py
@@ -41,6 +41,7 @@
classif_losses = ["log", "hinge"]
weightings = ["huber", "mom"]
multi_class = ["ovr", "ovo"]
solvers = ["SGD", "IRLS"]


def test_robust_estimator_max_iter():
@@ -240,8 +241,8 @@ def test_robust_no_proba():


# Regression test with outliers
X_rc = rng.uniform(-1, 1, size=[200])
y_rc = X_rc + 0.1 * rng.normal(size=200)
X_rc = rng.uniform(-1, 1, size=[300])
y_rc = X_rc + 0.1 * rng.normal(size=300)
X_rc[0] = 10
X_rc = X_rc.reshape(-1, 1)
y_rc[0] = -1
@@ -253,19 +254,21 @@ def test_robust_no_proba():
@pytest.mark.parametrize("weighting", weightings)
@pytest.mark.parametrize("k", k_values)
@pytest.mark.parametrize("c", c_values)
def test_corrupted_regression(loss, weighting, k, c):
@pytest.mark.parametrize("solver", solvers)
def test_corrupted_regression(loss, weighting, k, c, solver):
reg = RobustWeightedRegressor(
loss=loss,
max_iter=50,
max_iter=100,
solver=solver,
weighting=weighting,
k=k,
c=c,
random_state=rng,
n_iter_no_change=20,
)
reg.fit(X_rc, y_rc)
assert np.abs(reg.coef_[0] - 1) < 0.1
assert np.abs(reg.intercept_[0]) < 0.1
assert np.abs(reg.coef_[0] - 1) < 0.2
assert np.abs(reg.intercept_) < 0.2


# Check that weights_ parameter can be used as outlier score.
Expand All @@ -283,6 +286,14 @@ def test_regression_corrupted_weights(weighting):
assert reg.weights_[0] < np.mean(reg.weights_[1:])


def test_robust_regression_estimator_unsupported_solver():
"""Test that warning message is thrown when unsupported loss."""
model = RobustWeightedRegressor(solver="invalid")
msg = "No such solver."
with pytest.raises(ValueError, match=msg):
model.fit(X_rc, y_rc)


X_r = rng.uniform(-1, 1, size=[1000])
y_r = X_r + 0.1 * rng.normal(size=1000)
X_r = X_r.reshape(-1, 1)
@@ -394,7 +405,7 @@ def test_not_robust_cluster(weighting):
difference = [
np.linalg.norm(pred1[i] - pred2[i]) for i in range(len(pred1))
]
assert np.mean(difference) < 1
assert np.mean(difference) < 2


def test_transform():