diff --git a/.gitignore b/.gitignore
index ecdba127..b23663b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,6 +57,7 @@ nosetests.xml
 coverage.xml
 *,cover
 .hypothesis/
+*.swp
 
 # Translations
 *.mo
diff --git a/doc/api.rst b/doc/api.rst
index 86b8d333..c2b8215f 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -4,6 +4,17 @@ scikit-learn-extra API
 
 .. currentmodule:: sklearn_extra
 
+
+Feature weighting
+=================
+
+.. autosummary::
+   :toctree: generated/
+   :template: class.rst
+
+   feature_weighting.TfigmTransformer
+
+
 Kernel approximation
 ====================
 
diff --git a/examples/feature_weighting/plot_tfigm_text.py b/examples/feature_weighting/plot_tfigm_text.py
new file mode 100644
index 00000000..5d7c9c44
--- /dev/null
+++ b/examples/feature_weighting/plot_tfigm_text.py
@@ -0,0 +1,66 @@
+# License: BSD 3 clause
+#
+# Authors: Roman Yurchak <rth.yurchak@gmail.com>
+import os
+
+import pandas as pd
+
+from sklearn.svm import LinearSVC
+from sklearn.preprocessing import Normalizer, FunctionTransformer
+from sklearn.pipeline import make_pipeline
+from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.model_selection import cross_validate
+from sklearn.metrics import f1_score
+
+from sklearn_extra.feature_weighting import TfigmTransformer
+
+if "CI" in os.environ:
+    # make this example run faster in CI
+    categories = ["sci.crypt", "comp.graphics", "comp.sys.mac.hardware"]
+else:
+    categories = None
+
+docs, y = fetch_20newsgroups(return_X_y=True, categories=categories)
+
+
+vect = CountVectorizer(min_df=5, stop_words="english", ngram_range=(1, 1))
+X = vect.fit_transform(docs)
+
+res = []
+
+for scaler_label, scaler in [
+    ("TF", FunctionTransformer(lambda x: x)),
+    ("TF-IDF(sublinear_tf=False)", TfidfTransformer()),
+    ("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)),
+    ("TF-IGM(tf_scale=None)", TfigmTransformer()),
+    ("TF-IGM(tf_scale='sqrt')", TfigmTransformer(tf_scale="sqrt")),
+    ("TF-IGM(tf_scale='log1p')", TfigmTransformer(tf_scale="log1p")),
+]:
+    pipe = make_pipeline(scaler, Normalizer())
+    X_tr = pipe.fit_transform(X, y)
+    est = LinearSVC()
+    scoring = {
+        "F1-macro": lambda est, X, y: f1_score(
+            y, est.predict(X), average="macro"
+        ),
+        "balanced_accuracy": "balanced_accuracy",
+    }
+    scores = cross_validate(est, X_tr, y, scoring=scoring)
+    for key, val in scores.items():
+        if not key.endswith("_time"):
+            res.append(
+                {
+                    "metric": "_".join(key.split("_")[1:]),
+                    "subset": key.split("_")[0],
+                    "preprocessing": scaler_label,
+                    "score": "{:.3f}±{:.3f}".format(val.mean(), val.std()),
+                }
+            )
+scores = (
+    pd.DataFrame(res)
+    .set_index(["preprocessing", "metric", "subset"])["score"]
+    .unstack(-1)
+)
+scores = scores["test"].unstack(-1).sort_values("F1-macro", ascending=False)
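+# The table below lists cross-validated test scores (mean±std) for each
+# weighting scheme, sorted by macro-averaged F1 in decreasing order.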
+print(scores)
diff --git a/sklearn_extra/feature_weighting/__init__.py b/sklearn_extra/feature_weighting/__init__.py
new file mode 100644
index 00000000..a87491c5
--- /dev/null
+++ b/sklearn_extra/feature_weighting/__init__.py
@@ -0,0 +1,5 @@
+# License: BSD 3 clause
+
+from ._text import TfigmTransformer
+
+__all__ = ["TfigmTransformer"]
diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py
new file mode 100644
index 00000000..58a019e4
--- /dev/null
+++ b/sklearn_extra/feature_weighting/_text.py
@@ -0,0 +1,203 @@
+# License: BSD 3 clause
+#
+# Authors: Roman Yurchak <rth.yurchak@gmail.com>
+
+import numpy as np
+import scipy.sparse as sp
+
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.utils.validation import check_array, check_X_y
+from sklearn.preprocessing import LabelEncoder
+
+
+class TfigmTransformer(BaseEstimator, TransformerMixin):
+    """TF-IGM feature weighting.
+
+    TF-IGM (Inverse Gravity Moment) is a supervised feature weighting
+    scheme for classification tasks that measures the class-distinguishing
+    power of each term.
+
+    See the User Guide for more details.
+
+    Parameters
+    ----------
+    alpha : float, default=0.15
+        Regularization parameter. Known good default values lie in the
+        0.14 - 0.20 range.
+    tf_scale : {"sqrt", "log1p"}, default=None
+        If not None, the scaling applied to term frequencies. Possible
+        values are:
+
+        - "sqrt": square root scaling
+        - "log1p": ``log(1 + tf)`` scaling. This option corresponds to the
+          ``sublinear_tf=True`` parameter in
+          :class:`~sklearn.feature_extraction.text.TfidfTransformer`.
+
+    Attributes
+    ----------
+    igm_ : array of shape (n_features,)
+        The Inverse Gravity Moment (IGM) weight.
+    coef_ : array of shape (n_features,)
+        Regularized IGM weight corresponding to ``alpha + igm_``.
+
+    Examples
+    --------
+    >>> from sklearn.feature_extraction.text import CountVectorizer
+    >>> from sklearn.pipeline import Pipeline
+    >>> from sklearn_extra.feature_weighting import TfigmTransformer
+    >>> corpus = ['this is the first document',
+    ...           'this document is the second document',
+    ...           'and this is the third one',
+    ...           'is this the first document']
+    >>> y = ['1', '2', '1', '2']
+    >>> pipe = Pipeline([('count', CountVectorizer()),
+    ...                  ('tfigm', TfigmTransformer())]).fit(corpus, y)
+    >>> pipe['count'].transform(corpus).toarray()
+    array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
+           [0, 2, 0, 1, 0, 1, 1, 0, 1],
+           [1, 0, 0, 1, 1, 0, 1, 1, 1],
+           [0, 1, 1, 1, 0, 0, 1, 0, 1]])
+    >>> pipe['tfigm'].igm_
+    array([1.  , 0.25, 0.  , 0.  , 1.  , 1.  , 0.  , 1.  , 0.  ])
+    >>> pipe['tfigm'].coef_
+    array([1.15, 0.4 , 0.15, 0.15, 1.15, 1.15, 0.15, 1.15, 0.15])
+    >>> pipe.transform(corpus).shape
+    (4, 9)
+
+    References
+    ----------
+    Chen, Kewen, et al. "Turning from TF-IDF to TF-IGM for term weighting
+    in text classification." Expert Systems with Applications 66 (2016):
+    245-260.
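+
+    Notes
+    -----
+    For a given feature, let ``f_1 >= f_2 >= ... >= f_K`` denote the
+    fraction of documents containing it in each of the ``K`` classes,
+    sorted in decreasing order. The raw IGM score is
+    ``f_1 / (1*f_1 + 2*f_2 + ... + K*f_K)``, which (when more than one
+    class is present) is linearly rescaled to ``[0, 1]`` and stored in
+    ``igm_``, so that ``coef_ = alpha + igm_``.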
+ """ + + def __init__(self, alpha=0.15, tf_scale=None): + self.alpha = alpha + self.tf_scale = tf_scale + + def _fit(self, X, y): + """Learn the igm vector (global term weights) + + Parameters + ---------- + X : {array-like, sparse matrix} of (n_samples, n_features) + a matrix of term/token counts + y : array-like of shape (n_samples,) + target classes + """ + tf_scale_map = {None: None, "sqrt": np.sqrt, "log1p": np.log1p} + + if self.tf_scale not in tf_scale_map: + raise ValueError( + "tf_scale={} should be one of {}.".format( + self.tf_scale, list(tf_scale_map) + ) + ) + self._tf_scale_func = tf_scale_map[self.tf_scale] + + if not isinstance(self.alpha, float) or self.alpha < 0: + raise ValueError( + "alpha={} must be a positive number.".format(self.alpha) + ) + + self._le = LabelEncoder().fit(y) + n_class = len(self._le.classes_) + class_freq = np.zeros((n_class, X.shape[1])) + + X_nz = X != 0 + if sp.issparse(X_nz) and X_nz.getformat() != "csr": + X_nz = X_nz.asformat("csr") + + for idx, class_label in enumerate(self._le.classes_): + y_mask = y == class_label + n_samples = y_mask.sum() + class_freq[idx, :] = X_nz[y_mask].sum(axis=0) / n_samples + + self._class_freq = class_freq + class_freq_sort = np.sort(self._class_freq, axis=0) + f1 = class_freq_sort[-1, :] + + fk = (class_freq_sort * np.arange(n_class, 0, -1)[:, None]).sum(axis=0) + # avoid division by zero + igm = np.divide(f1, fk, out=np.zeros_like(f1), where=(fk != 0)) + if n_class > 1: + # although Chen et al. paper states that it is not mandatory, we + # allways re-scale weights to [0, 1], otherwise with 2 classes + # we would get a minimal IGM value of 1/3. + self.igm_ = ((1 + n_class) * n_class * igm - 2) / ( + (1 + n_class) * n_class - 2 + ) + else: + self.igm_ = igm + # In the Chen et al. paper the regularization parameter is defined + # as 1/alpha. 
+        self.coef_ = self.alpha + self.igm_
+        return self
+
+    def fit(self, X, y):
+        """Learn the igm vector (global term weights)
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            a matrix of term/token counts
+        y : array-like of shape (n_samples,)
+            target classes
+        """
+        X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
+        self._fit(X, y)
+        return self
+
+    def _transform(self, X):
+        """Transform a count matrix to a TF-IGM representation
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            a matrix of term/token counts
+
+        Returns
+        -------
+        vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            transformed matrix
+        """
+        if self._tf_scale_func is not None:
+            X = self._tf_scale_func(X)
+
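+        # Scale each feature (column) by its weight; for sparse input,
+        # multiplying by a diagonal weight matrix preserves sparsity.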
+        if sp.issparse(X):
+            X_tr = X @ sp.diags(self.coef_)
+        else:
+            X_tr = X * self.coef_[None, :]
+        return X_tr
+
+    def transform(self, X):
+        """Transform a count matrix to a TF-IGM representation
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            a matrix of term/token counts
+
+        Returns
+        -------
+        vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            transformed matrix
+        """
+        X = check_array(X, accept_sparse=["csr", "csc"])
+        X_tr = self._transform(X)
+        return X_tr
+
+    def fit_transform(self, X, y):
+        """Transform a count matrix to a TF-IGM representation
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            a matrix of term/token counts
+        y : array-like of shape (n_samples,)
+            target classes
+
+        Returns
+        -------
+        vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            transformed matrix
+        """
+        X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
+        self._fit(X, y)
+        X_tr = self._transform(X)
+        return X_tr
diff --git a/sklearn_extra/feature_weighting/tests/test_text.py b/sklearn_extra/feature_weighting/tests/test_text.py
new file mode 100644
index 00000000..412a67c1
--- /dev/null
+++ b/sklearn_extra/feature_weighting/tests/test_text.py
@@ -0,0 +1,97 @@
+# License: BSD 3 clause
+
+import numpy as np
+from numpy.testing import assert_allclose, assert_array_less
+import scipy.sparse as sp
+
+import pytest
+
+from sklearn_extra.feature_weighting import TfigmTransformer
+from sklearn.datasets import make_classification
+
+
+@pytest.mark.parametrize("array_format", ["dense", "csr", "csc", "coo"])
+def test_tfigm_transform(array_format):
+    X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
+    if array_format != "dense":
+        X = sp.csr_matrix(X).asformat(array_format)
+    y = np.array(["a", "b", "a", "c"])
+
+    alpha = 0.2
+    est = TfigmTransformer(alpha=alpha)
+    X_tr = est.fit_transform(X, y)
+
+    assert_allclose(est.igm_, [0.20, 0.40, 0.0])
+    assert_allclose(est.igm_ + alpha, est.coef_)
+
+    assert X_tr.shape == X.shape
+    assert sp.issparse(X_tr) is (array_format != "dense")
+
+    if array_format == "dense":
+        assert_allclose(X * est.coef_[None, :], X_tr)
+    else:
+        assert_allclose(X.A * est.coef_[None, :], X_tr.A)
+
+
+def test_tfigm_synthetic():
+    X, y = make_classification(
+        n_samples=100,
+        n_features=10,
+        n_informative=5,
+        n_redundant=0,
+        random_state=0,
+        n_classes=5,
+        shuffle=False,
+    )
+    X = (X > 0).astype(np.float64)
+
+    est = TfigmTransformer()
+    est.fit(X, y)
+    # informative features have higher IGM weights than noisy ones
+    # (although here we lose a lot of information due to thresholding of X).
+    assert est.igm_[:5].mean() / est.igm_[5:].mean() > 3
+
+
+@pytest.mark.parametrize("n_class", [2, 5])
+def test_tfigm_random_distribution(n_class):
+    rng = np.random.RandomState(0)
+    n_samples, n_features = 500, 4
+    X = rng.randint(2, size=(n_samples, n_features))
+    y = rng.randint(n_class, size=(n_samples,))
+
+    est = TfigmTransformer()
+    X_tr = est.fit_transform(X, y)
+
+    # all weights are strictly positive
+    assert_array_less(0, est.igm_)
+    # and close to zero, since none of the features is discriminative
+    assert_array_less(est.igm_, 0.05)
+
+
+def test_tfigm_valid_target():
+    X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
+    y = None
+
+    est = TfigmTransformer()
+    with pytest.raises(ValueError, match="y cannot be None"):
+        est.fit(X, y)
+
+    # check the degenerate behaviour with a single class
+    y = [1, 1, 1, 1]
+    est = TfigmTransformer()
+    est.fit(X, y)
+    assert_allclose(est.igm_, np.ones(3))
+
+
+def test_tfigm_invalid_parameters():
+    X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
+    y = [1, 1, 2, 2]
+
+    est = TfigmTransformer(alpha=-1)
+    with pytest.raises(ValueError, match="alpha=-1 must be a positive number"):
+        est.fit(X, y)
+
+    est = TfigmTransformer(tf_scale="unknown")
+    msg = r"tf_scale=unknown should be one of \[None, 'sqrt'"
+    with pytest.raises(ValueError, match=msg):
+        est.fit(X, y)
diff --git a/sklearn_extra/tests/test_common.py b/sklearn_extra/tests/test_common.py
index 587b8249..7cd67610 100644
--- a/sklearn_extra/tests/test_common.py
+++ b/sklearn_extra/tests/test_common.py
@@ -4,8 +4,15 @@ from sklearn_extra.kernel_approximation import Fastfood
 from sklearn_extra.kernel_methods import EigenProClassifier, EigenProRegressor
 from sklearn_extra.cluster import KMedoids
+from sklearn_extra.feature_weighting import TfigmTransformer
 
-ALL_ESTIMATORS = [Fastfood, KMedoids, EigenProClassifier, EigenProRegressor]
+ALL_ESTIMATORS = [
+    Fastfood,
+    KMedoids,
+    EigenProClassifier,
+    EigenProRegressor,
+    TfigmTransformer,
+]
 
 if hasattr(estimator_checks, "parametrize_with_checks"):
     # Common tests are only run on scikit-learn 0.22+