Skip to content
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ Estimators
- :class:`decomposition.PCA` (with `svd_solver="full"`,
`svd_solver="randomized"` and `power_iteration_normalizer="QR"`)
- :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
- :class:`preprocessing.MaxAbsScaler`
- :class:`preprocessing.MinMaxScaler`

Tools
Expand Down
6 changes: 3 additions & 3 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ Changelog
`full` and `randomized` solvers (with QR power iterations). See
:ref:`array_api` for more details.
:pr:`26315` and :pr:`27098` by :user:`Mateusz Sokół <mtsokol>`,
:user:`Olivier Grisel <ogrisel>` and :user:` Edoardo Abati <EdAbati>`.
:user:`Olivier Grisel <ogrisel>` and :user:`Edoardo Abati <EdAbati>`.

- |Enhancement| :func:`decomposition.non_negative_factorization`, :class:`decomposition.NMF`,
and :class:`decomposition.MiniBatchNMF` now support :class:`scipy.sparse.sparray`
Expand Down Expand Up @@ -205,11 +205,11 @@ Changelog
when `sparse_output=True` and the output is configured to be pandas.
:pr:`26931` by `Thomas Fan`_.

- |MajorFeature| :class:`preprocessing.MinMaxScaler` now
- |MajorFeature| :class:`preprocessing.MinMaxScaler` and :class:`preprocessing.MaxAbsScaler` now
supports the `Array API <https://data-apis.org/array-api/latest/>`_. Array API
support is considered experimental and might evolve without being subject to
our usual rolling deprecation cycle policy. See
:ref:`array_api` for more details. :pr:`26243` by `Tim Head`_.
:ref:`array_api` for more details. :pr:`26243` by `Tim Head`_ and :pr:`27110` by :user:`Edoardo Abati <EdAbati>`.

:mod:`sklearn.tree`
...................
Expand Down
18 changes: 13 additions & 5 deletions sklearn/preprocessing/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,25 +1215,27 @@ def partial_fit(self, X, y=None):
self : object
Fitted scaler.
"""
xp, _ = get_namespace(X)

first_pass = not hasattr(self, "n_samples_seen_")
X = self._validate_data(
X,
reset=first_pass,
accept_sparse=("csr", "csc"),
dtype=FLOAT_DTYPES,
dtype=_array_api.supported_float_dtypes(xp),
force_all_finite="allow-nan",
)

if sparse.issparse(X):
mins, maxs = min_max_axis(X, axis=0, ignore_nan=True)
max_abs = np.maximum(np.abs(mins), np.abs(maxs))
else:
max_abs = np.nanmax(np.abs(X), axis=0)
max_abs = _array_api._nanmax(xp.abs(X), axis=0)

if first_pass:
self.n_samples_seen_ = X.shape[0]
else:
max_abs = np.maximum(self.max_abs_, max_abs)
max_abs = xp.maximum(self.max_abs_, max_abs)
self.n_samples_seen_ += X.shape[0]

self.max_abs_ = max_abs
Expand All @@ -1254,12 +1256,15 @@ def transform(self, X):
Transformed array.
"""
check_is_fitted(self)

xp, _ = get_namespace(X)

X = self._validate_data(
X,
accept_sparse=("csr", "csc"),
copy=self.copy,
reset=False,
dtype=FLOAT_DTYPES,
dtype=_array_api.supported_float_dtypes(xp),
force_all_finite="allow-nan",
)

Expand All @@ -1283,11 +1288,14 @@ def inverse_transform(self, X):
Transformed array.
"""
check_is_fitted(self)

xp, _ = get_namespace(X)

X = check_array(
X,
accept_sparse=("csr", "csc"),
copy=self.copy,
dtype=FLOAT_DTYPES,
dtype=_array_api.supported_float_dtypes(xp),
force_all_finite="allow-nan",
)

Expand Down
6 changes: 2 additions & 4 deletions sklearn/preprocessing/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,12 +701,10 @@ def test_standard_check_array_of_inverse_transform():
)
@pytest.mark.parametrize(
"estimator",
[MinMaxScaler()],
[MaxAbsScaler(), MinMaxScaler()],
ids=_get_check_estimator_ids,
)
def test_minmaxscaler_array_api_compliance(
estimator, check, array_namespace, device, dtype
):
def test_scaler_array_api_compliance(estimator, check, array_namespace, device, dtype):
name = estimator.__class__.__name__
check(name, estimator, array_namespace, device=device, dtype=dtype)

Expand Down
4 changes: 2 additions & 2 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def _nanmin(X, axis=None):

else:
mask = xp.isnan(X)
X = xp.min(xp.where(mask, xp.asarray(+xp.inf), X), axis=axis)
X = xp.min(xp.where(mask, xp.asarray(+xp.inf, device=device(X)), X), axis=axis)
# Replace Infs from all NaN slices with NaN again
mask = xp.all(mask, axis=axis)
if xp.any(mask):
Expand All @@ -512,7 +512,7 @@ def _nanmax(X, axis=None):

else:
mask = xp.isnan(X)
X = xp.max(xp.where(mask, xp.asarray(-xp.inf), X), axis=axis)
X = xp.max(xp.where(mask, xp.asarray(-xp.inf, device=device(X)), X), axis=axis)
# Replace Infs from all NaN slices with NaN again
mask = xp.all(mask, axis=axis)
if xp.any(mask):
Expand Down