diff --git a/.gitignore b/.gitignore
index ecdba127..b23663b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,6 +57,7 @@ nosetests.xml
 coverage.xml
 *,cover
 .hypothesis/
+*.swp
 
 # Translations
 *.mo
diff --git a/doc/api.rst b/doc/api.rst
index 86b8d333..c2b8215f 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -4,6 +4,17 @@ scikit-learn-extra API
 
 .. currentmodule:: sklearn_extra
 
+
+Feature weighting
+=================
+
+.. autosummary::
+   :toctree: generated/
+   :template: class.rst
+
+   feature_weighting.TfigmTransformer
+
+
 Kernel approximation
 ====================
 
diff --git a/examples/feature_weighting/plot_tfigm_text.py b/examples/feature_weighting/plot_tfigm_text.py
new file mode 100644
index 00000000..5d7c9c44
--- /dev/null
+++ b/examples/feature_weighting/plot_tfigm_text.py
@@ -0,0 +1,66 @@
+# License: BSD 3 clause
+#
+# Authors: Roman Yurchak <rth.yurchak@gmail.com>
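+"""
+==============================================
+TF-IGM term weighting for text classification
+==============================================
+
+This example compares unsupervised TF and TF-IDF term weighting with
+the supervised TF-IGM scheme on the 20 newsgroups dataset, using a
+linear SVM and cross-validated F1-macro and balanced accuracy scores.
+"""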
+import os
+
+import pandas as pd
+
+from sklearn.svm import LinearSVC
+from sklearn.preprocessing import Normalizer, FunctionTransformer
+from sklearn.pipeline import make_pipeline
+from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.model_selection import cross_validate
+from sklearn.metrics import f1_score
+
+from sklearn_extra.feature_weighting import TfigmTransformer
+
+if "CI" in os.environ:
+    # make this example run faster in CI
+    categories = ["sci.crypt", "comp.graphics", "comp.sys.mac.hardware"]
+else:
+    categories = None
+
+docs, y = fetch_20newsgroups(return_X_y=True, categories=categories)
+
+
+vect = CountVectorizer(min_df=5, stop_words="english", ngram_range=(1, 1))
+X = vect.fit_transform(docs)
+
+res = []
+
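+# Compare unsupervised TF and TF-IDF weightings with the supervised
+# TF-IGM scheme; each weighting is followed by L2 normalization and
+# evaluated with a linear SVM under cross-validation.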
+for scaler_label, scaler in [
+    ("TF", FunctionTransformer(lambda x: x)),
+    ("TF-IDF(sublinear_tf=False)", TfidfTransformer()),
+    ("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)),
+    ("TF-IGM(tf_scale=None)", TfigmTransformer()),
+    ("TF-IGM(tf_scale='sqrt')", TfigmTransformer(tf_scale="sqrt"),),
+    ("TF-IGM(tf_scale='log1p')", TfigmTransformer(tf_scale="log1p"),),
+]:
+    pipe = make_pipeline(scaler, Normalizer())
+    X_tr = pipe.fit_transform(X, y)
+    est = LinearSVC()
+    scoring = {
+        "F1-macro": lambda est, X, y: f1_score(
+            y, est.predict(X), average="macro"
+        ),
+        "balanced_accuracy": "balanced_accuracy",
+    }
+    scores = cross_validate(est, X_tr, y, scoring=scoring)
+    for key, val in scores.items():
+        if not key.endswith("_time"):
+            res.append(
+                {
+                    "metric": "_".join(key.split("_")[1:]),
+                    "subset": key.split("_")[0],
+                    "preprocessing": scaler_label,
+                    "score": "{:.3f}±{:.3f}".format(val.mean(), val.std()),
+                }
+            )
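+
+# pivot the collected cross-validation scores into a
+# (preprocessing x metric) table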
+scores = (
+    pd.DataFrame(res)
+    .set_index(["preprocessing", "metric", "subset"])["score"]
+    .unstack(-1)
+)
+scores = scores["test"].unstack(-1).sort_values("F1-macro", ascending=False)
+print(scores)
diff --git a/sklearn_extra/feature_weighting/__init__.py b/sklearn_extra/feature_weighting/__init__.py
new file mode 100644
index 00000000..a87491c5
--- /dev/null
+++ b/sklearn_extra/feature_weighting/__init__.py
@@ -0,0 +1,5 @@
+# License: BSD 3 clause
+
+from ._text import TfigmTransformer
+
+__all__ = ["TfigmTransformer"]
diff --git a/sklearn_extra/feature_weighting/_text.py b/sklearn_extra/feature_weighting/_text.py
new file mode 100644
index 00000000..58a019e4
--- /dev/null
+++ b/sklearn_extra/feature_weighting/_text.py
@@ -0,0 +1,203 @@
+# License: BSD 3 clause
+#
+# Authors: Roman Yurchak <rth.yurchak@gmail.com>
+
+import numpy as np
+import scipy.sparse as sp
+
+from sklearn.base import BaseEstimator, TransformerMixin
+from sklearn.utils.validation import check_array, check_X_y
+from sklearn.preprocessing import LabelEncoder
+
+
+class TfigmTransformer(BaseEstimator, TransformerMixin):
+    """TF-IGM feature weighting.
+
+    TF-IGM (Inverse Gravity Moment) is a supervised feature weighting
+    scheme for classification tasks that measures the class
+    distinguishing power of each term.
+
+    See the User Guide for more details.
+
+    Parameters
+    ----------
+    alpha : float, default=0.15
+        Regularization parameter. Known good values lie in the
+        0.14 - 0.20 range.
+    tf_scale : {"sqrt", "log1p"}, default=None
+        Scaling applied to the term frequency. Possible values are:
+
+        - "sqrt": square root scaling
+        - "log1p": ``log(1 + tf)`` scaling, which corresponds to the
+          ``sublinear_tf=True`` parameter of
+          :class:`~sklearn.feature_extraction.text.TfidfTransformer`.
+
+        If None (default), no scaling is applied.
+
+    Attributes
+    ----------
+    igm_ : array of shape (n_features,)
+        The Inverse Gravity Moment (IGM) weight.
+    coef_ : array of shape (n_features,)
+        Regularized IGM weight, equal to ``alpha + igm_``.
+
+    Examples
+    --------
+    >>> from sklearn.feature_extraction.text import CountVectorizer
+    >>> from sklearn.pipeline import Pipeline
+    >>> from sklearn_extra.feature_weighting import TfigmTransformer
+    >>> corpus = ['this is the first document',
+    ...           'this document is the second document',
+    ...           'and this is the third one',
+    ...           'is this the first document']
+    >>> y = ['1', '2', '1', '2']
+    >>> pipe = Pipeline([('count', CountVectorizer()),
+    ...                  ('tfigm', TfigmTransformer())]).fit(corpus, y)
+    >>> pipe['count'].transform(corpus).toarray()
+    array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
+           [0, 2, 0, 1, 0, 1, 1, 0, 1],
+           [1, 0, 0, 1, 1, 0, 1, 1, 1],
+           [0, 1, 1, 1, 0, 0, 1, 0, 1]])
+    >>> pipe['tfigm'].igm_
+    array([1.  , 0.25, 0.  , 0.  , 1.  , 1.  , 0.  , 1.  , 0.  ])
+    >>> pipe['tfigm'].coef_
+    array([1.15, 0.4 , 0.15, 0.15, 1.15, 1.15, 0.15, 1.15, 0.15])
+    >>> pipe.transform(corpus).shape
+    (4, 9)
+
+    References
+    ----------
+    Chen, Kewen, et al. "Turning from TF-IDF to TF-IGM for term weighting
+    in text classification." Expert Systems with Applications 66 (2016):
+    245-260.
+    """
+
+    def __init__(self, alpha=0.15, tf_scale=None):
+        self.alpha = alpha
+        self.tf_scale = tf_scale
+
+    def _fit(self, X, y):
+        """Learn the igm vector (global term weights)
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of (n_samples, n_features)
+            a matrix of term/token counts
+        y : array-like of shape (n_samples,)
+            target classes
+        """
+        tf_scale_map = {None: None, "sqrt": np.sqrt, "log1p": np.log1p}
+
+        if self.tf_scale not in tf_scale_map:
+            raise ValueError(
+                "tf_scale={} should be one of {}.".format(
+                    self.tf_scale, list(tf_scale_map)
+                )
+            )
+        self._tf_scale_func = tf_scale_map[self.tf_scale]
+
+        if not isinstance(self.alpha, (int, float)) or self.alpha <= 0:
+            raise ValueError(
+                "alpha={} must be a positive number.".format(self.alpha)
+            )
+
+        self._le = LabelEncoder().fit(y)
+        n_class = len(self._le.classes_)
+        class_freq = np.zeros((n_class, X.shape[1]))
+
+        X_nz = X != 0
+        if sp.issparse(X_nz) and X_nz.getformat() != "csr":
+            X_nz = X_nz.asformat("csr")
+
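+        # class_freq[c, t] is the fraction of documents of class c
+        # that contain term t (per-class document frequency)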
+        for idx, class_label in enumerate(self._le.classes_):
+            y_mask = y == class_label
+            n_samples = y_mask.sum()
+            class_freq[idx, :] = X_nz[y_mask].sum(axis=0) / n_samples
+
+        self._class_freq = class_freq
+        class_freq_sort = np.sort(self._class_freq, axis=0)
+        f1 = class_freq_sort[-1, :]
+
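+        # IGM(t) = f_1 / sum_r (r * f_r), where f_r is the class frequency
+        # of rank r once classes are sorted by decreasing frequency of t;
+        # class_freq_sort is ascending, hence the reversed rank vector.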
+        fk = (class_freq_sort * np.arange(n_class, 0, -1)[:, None]).sum(axis=0)
+        # avoid division by zero
+        igm = np.divide(f1, fk, out=np.zeros_like(f1), where=(fk != 0))
+        if n_class > 1:
+            # Although the Chen et al. paper states that it is not
+            # mandatory, we always rescale weights to [0, 1]: a term spread
+            # uniformly across classes has IGM = 2 / (n_class * (n_class + 1))
+            # (e.g. 1/3 for 2 classes), and the affine map below sends this
+            # minimum to 0 and the maximum of 1 to 1.
+            self.igm_ = ((1 + n_class) * n_class * igm - 2) / (
+                (1 + n_class) * n_class - 2
+            )
+        else:
+            self.igm_ = igm
+        # In the Chen et al. paper the weight is defined as
+        # 1 + lambda * igm; using alpha + igm_ with alpha = 1/lambda
+        # is equivalent up to a global scaling factor.
+        self.coef_ = self.alpha + self.igm_
+        return self
+
+    def fit(self, X, y):
+        """Learn the igm vector (global term weights)
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of (n_samples, n_features)
+            a matrix of term/token counts
+        y : array-like of shape (n_samples,)
+            target classes
+        """
+        X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
+        self._fit(X, y)
+        return self
+
+    def _transform(self, X):
+        """Transform a count matrix to a TF-IGM representation
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            A matrix of term/token counts.
+
+        Returns
+        -------
+        vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            The transformed matrix.
+        """
+        if self._tf_scale_func is not None:
+            X = self._tf_scale_func(X)
+
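+        # scale each feature column by its weight; for sparse input,
+        # multiplying by a diagonal matrix preserves sparsity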
+        if sp.issparse(X):
+            X_tr = X @ sp.diags(self.coef_)
+        else:
+            X_tr = X * self.coef_[None, :]
+        return X_tr
+
+    def transform(self, X):
+        """Transform a count matrix to a TF-IGM representation
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            A matrix of term/token counts.
+
+        Returns
+        -------
+        vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            The transformed matrix.
+        """
+        X = check_array(X, accept_sparse=["csr", "csc"])
+        X_tr = self._transform(X)
+        return X_tr
+
+    def fit_transform(self, X, y):
+        """Transform a count matrix to a TF-IGM representation
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            a matrix of term/token counts
+
+        Returns
+        -------
+        vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
+            transformed matrix
+        """
+        X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
+        self._fit(X, y)
+        X_tr = self._transform(X)
+        return X_tr
diff --git a/sklearn_extra/feature_weighting/tests/test_text.py b/sklearn_extra/feature_weighting/tests/test_text.py
new file mode 100644
index 00000000..412a67c1
--- /dev/null
+++ b/sklearn_extra/feature_weighting/tests/test_text.py
@@ -0,0 +1,97 @@
+# License: BSD 3 clause
+
+import numpy as np
+from numpy.testing import assert_allclose, assert_array_less
+import scipy.sparse as sp
+
+import pytest
+
+from sklearn_extra.feature_weighting import TfigmTransformer
+from sklearn.datasets import make_classification
+
+
+@pytest.mark.parametrize("array_format", ["dense", "csr", "csc", "coo"])
+def test_tfigm_transform(array_format):
+    X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
+    if array_format != "dense":
+        X = sp.csr_matrix(X).asformat(array_format)
+    y = np.array(["a", "b", "a", "c"])
+
+    alpha = 0.2
+    est = TfigmTransformer(alpha=alpha)
+    X_tr = est.fit_transform(X, y)
+
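+    # e.g. feature 0 occurs only in classes "b" and "c", so its raw
+    # IGM is 1/3, rescaled to (12 * 1/3 - 2) / (12 - 2) = 0.2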
+    assert_allclose(est.igm_, [0.20, 0.40, 0.0])
+    assert_allclose(est.igm_ + alpha, est.coef_)
+
+    assert X_tr.shape == X.shape
+    assert sp.issparse(X_tr) is (array_format != "dense")
+
+    if array_format == "dense":
+        assert_allclose(X * est.coef_[None, :], X_tr)
+    else:
+        assert_allclose(X.A * est.coef_[None, :], X_tr.A)
+
+
+def test_tfigm_synthetic():
+    X, y = make_classification(
+        n_samples=100,
+        n_features=10,
+        n_informative=5,
+        n_redundant=0,
+        random_state=0,
+        n_classes=5,
+        shuffle=False,
+    )
+    X = (X > 0).astype(np.float64)
+
+    est = TfigmTransformer()
+    est.fit(X, y)
+    # informative features have higher IGM weights than noisy ones
+    # (although here we lose a lot of information due to the
+    # thresholding of X).
+    assert est.igm_[:5].mean() / est.igm_[5:].mean() > 3
+
+
+@pytest.mark.parametrize("n_class", [2, 5])
+def test_tfigm_random_distribution(n_class):
+    rng = np.random.RandomState(0)
+    n_samples, n_features = 500, 4
+    X = rng.randint(2, size=(n_samples, n_features))
+    y = rng.randint(n_class, size=(n_samples,))
+
+    est = TfigmTransformer()
+    X_tr = est.fit_transform(X, y)
+
+    # all weights are strictly positive
+    assert_array_less(0, est.igm_)
+    # and close to zero, since none of the features are discriminant
+    assert_array_less(est.igm_, 0.05)
+
+
+def test_tfigm_valid_target():
+    X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
+    y = None
+
+    est = TfigmTransformer()
+    with pytest.raises(ValueError, match="y cannot be None"):
+        est.fit(X, y)
+
+    # check the degenerate single-class case
+    y = [1, 1, 1, 1]
+    est = TfigmTransformer()
+    est.fit(X, y)
+    assert_allclose(est.igm_, np.ones(3))
+
+
+def test_tfigm_invalid_params():
+    X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
+    y = [1, 1, 2, 2]
+
+    est = TfigmTransformer(alpha=-1)
+    with pytest.raises(ValueError, match="alpha=-1 must be a positive number"):
+        est.fit(X, y)
+
+    est = TfigmTransformer(tf_scale="unknown")
+    msg = r"tf_scale=unknown should be one of \[None, 'sqrt'"
+    with pytest.raises(ValueError, match=msg):
+        est.fit(X, y)
diff --git a/sklearn_extra/tests/test_common.py b/sklearn_extra/tests/test_common.py
index 587b8249..7cd67610 100644
--- a/sklearn_extra/tests/test_common.py
+++ b/sklearn_extra/tests/test_common.py
@@ -4,8 +4,15 @@
 from sklearn_extra.kernel_approximation import Fastfood
 from sklearn_extra.kernel_methods import EigenProClassifier, EigenProRegressor
 from sklearn_extra.cluster import KMedoids
+from sklearn_extra.feature_weighting import TfigmTransformer
 
-ALL_ESTIMATORS = [Fastfood, KMedoids, EigenProClassifier, EigenProRegressor]
+ALL_ESTIMATORS = [
+    Fastfood,
+    KMedoids,
+    EigenProClassifier,
+    EigenProRegressor,
+    TfigmTransformer,
+]
 
 if hasattr(estimator_checks, "parametrize_with_checks"):
     # Common tests are only run on scikit-learn 0.22+