Skip to content

Commit 593ef5f

Browse files
authored
ENH: Speed up trim_zeros (numpy#16911)
* Added a benchmark for `trim_zeros()`
* Improve the performance of `np.trim_zeros()`
* Increase the variety of the tests

Fall back to the old `np.trim_zeros()` implementation if an exception is encountered. Emit a `DeprecationWarning` in such a case.

* DEP,REL: Added a deprecation release note
1 parent 8f60522 commit 593ef5f

File tree

5 files changed

+146
-21
lines changed

5 files changed

+146
-21
lines changed
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from .common import Benchmark
2+
3+
import numpy as np
4+
5+
# dtypes used to parametrize the ``TrimZeros`` benchmark
_FLOAT = np.dtype('float64')
_COMPLEX = np.dtype('complex128')
_INT = np.dtype('int64')
_BOOL = np.dtype('bool')
9+
10+
11+
class TrimZeros(Benchmark):
    """Time ``np.trim_zeros`` on 1-D arrays padded with zeros on both ends."""

    param_names = ["dtype", "size"]
    params = [
        [_INT, _FLOAT, _COMPLEX, _BOOL],
        [3000, 30_000, 300_000]
    ]

    def setup(self, dtype, size):
        """Build ``self.array`` as ``n`` zeros, ``n`` non-zeros, ``n`` zeros.

        ``n`` is ``size // 3``; the stacked array is cast to `dtype`.
        """
        n = size // 3
        self.array = np.hstack([
            np.zeros(n),
            # Draw from [1, 2) rather than the default [0, 1): casting values
            # below 1 to an integer dtype truncates them to 0, which would
            # turn the int64 case into an all-zero array.
            np.random.uniform(1, 2, size=n),
            np.zeros(n),
        ]).astype(dtype)

    def time_trim_zeros(self, dtype, size):
        """Call ``np.trim_zeros`` on the prepared array with default `trim`."""
        np.trim_zeros(self.array)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
``trim_zeros`` now requires a 1D array compatible with ``ndarray.astype(bool)``
2+
-------------------------------------------------------------------------------
3+
The ``trim_zeros`` function will, in the future, require an array with the
4+
following two properties:
5+
6+
* It must be 1D.
7+
* It must be convertible into a boolean array.

numpy/core/tests/test_deprecations.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def test_deprecated(self):
615615
self.assert_deprecated(round, args=(scalar,))
616616
self.assert_deprecated(round, args=(scalar, 0))
617617
self.assert_deprecated(round, args=(scalar,), kwargs={'ndigits': 0})
618-
618+
619619
def test_not_deprecated(self):
620620
for scalar_type in self.not_deprecated_types:
621621
scalar = scalar_type(0)
@@ -706,3 +706,21 @@ def test_deprecated(self):
706706
# And when it is an assignment into a lower dimensional subarray:
707707
self.assert_deprecated(lambda: np.array([arr, [0]], dtype=np.float64))
708708
self.assert_deprecated(lambda: np.array([[0], arr], dtype=np.float64))
709+
710+
711+
class TestTrimZeros(_DeprecationTestCase):
    # Numpy 1.20.0, 2020-07-31
    # Inputs: a nested (10, 10) list and a 1-D array of strings.
    @pytest.mark.parametrize("arr", [np.random.rand(10, 10).tolist(),
                                     np.random.rand(10).astype(str)])
    def test_deprecated(self, arr):
        """Check that ``np.trim_zeros(arr)`` emits a chained DeprecationWarning.

        The warning must carry a ``ValueError`` as its ``__cause__``.
        """
        with warnings.catch_warnings():
            # Promote the warning to an exception so that it can be caught
            # and its ``__cause__`` inspected.
            warnings.simplefilter('error', DeprecationWarning)
            try:
                np.trim_zeros(arr)
            except DeprecationWarning as ex:
                assert_(isinstance(ex.__cause__, ValueError))
            else:
                raise AssertionError("No error raised during function call")

        # The old implementation is expected to return these inputs unchanged.
        out = np.lib.function_base._trim_zeros_old(arr)
        assert_array_equal(arr, out)

numpy/lib/function_base.py

+59-8
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def asarray_chkfinite(a, dtype=None, order=None):
433433
By default, the data-type is inferred from the input data.
434434
order : {'C', 'F', 'A', 'K'}, optional
435435
Memory layout. 'A' and 'K' depend on the order of input array a.
436-
'C' row-major (C-style),
436+
'C' row-major (C-style),
437437
'F' column-major (Fortran-style) memory representation.
438438
'A' (any) means 'F' if `a` is Fortran contiguous, 'C' otherwise
439439
'K' (keep) preserve input order
@@ -1624,6 +1624,57 @@ def trim_zeros(filt, trim='fb'):
16241624
>>> np.trim_zeros([0, 1, 2, 0])
16251625
[1, 2]
16261626
1627+
"""
1628+
try:
1629+
return _trim_zeros_new(filt, trim)
1630+
except Exception as ex:
1631+
# Numpy 1.20.0, 2020-07-31
1632+
warning = DeprecationWarning(
1633+
"in the future trim_zeros will require a 1-D array as input "
1634+
"that is compatible with ndarray.astype(bool)"
1635+
)
1636+
warning.__cause__ = ex
1637+
warnings.warn(warning, stacklevel=3)
1638+
1639+
# Fall back to the old implementation if an exception is encountered
1640+
# Note that the same exception may or may not be raised here as well
1641+
return _trim_zeros_old(filt, trim)
1642+
1643+
1644+
def _trim_zeros_new(filt, trim='fb'):
    """
    Newer optimized implementation of ``trim_zeros()``.

    Parameters
    ----------
    filt : 1-D array or sequence
        Input; it is viewed as a boolean array to locate the first and last
        truthy elements.
    trim : str, optional
        Case-insensitive; an ``'f'`` trims from the front and a ``'b'`` trims
        from the back. Default is ``'fb'``.

    Returns
    -------
    trimmed
        A slice of `filt` itself (not of the boolean view), so the input type
        is whatever slicing `filt` yields. A size-zero input is returned
        as-is.

    Raises
    ------
    ValueError
        If the array obtained from `filt` is not exactly one-dimensional.

    """
    arr = np.asanyarray(filt).astype(bool, copy=False)

    if arr.ndim != 1:
        raise ValueError('trim_zeros requires an array of exactly one dimension')
    elif not len(arr):
        return filt

    trim_upper = trim.upper()
    # ``None`` bounds leave the corresponding side of the slice untrimmed
    first = last = None

    if 'F' in trim_upper:
        # On a boolean array ``argmax`` returns the index of the first True
        first = arr.argmax()
        # If `arr[first] is False` then so are all other elements
        if not arr[first]:
            return filt[:0]

    if 'B' in trim_upper:
        # Same search on the reversed view, converted to an exclusive stop
        last = len(arr) - arr[::-1].argmax()
        # If `arr[last - 1] is False` then so are all other elements
        if not arr[last - 1]:
            return filt[:0]

    return filt[first:last]
1670+
1671+
def _trim_zeros_old(filt, trim='fb'):
1672+
"""
1673+
Older unoptimized implementation of ``trim_zeros()``.
1674+
1675+
Used as fallback in case an exception is encountered
1676+
in ``_trim_zeros_new()``.
1677+
16271678
"""
16281679
first = 0
16291680
trim = trim.upper()
@@ -2546,11 +2597,11 @@ def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue):
25462597
for backwards compatibility with previous versions of this function. These
25472598
arguments had no effect on the return values of the function and can be
25482599
safely ignored in this and previous versions of numpy.
2549-
2600+
25502601
Examples
2551-
--------
2602+
--------
25522603
In this example we generate two random arrays, ``xarr`` and ``yarr``, and
2553-
compute the row-wise and column-wise Pearson correlation coefficients,
2604+
compute the row-wise and column-wise Pearson correlation coefficients,
25542605
``R``. Since ``rowvar`` is true by default, we first find the row-wise
25552606
Pearson correlation coefficients between the variables of ``xarr``.
25562607
@@ -2566,11 +2617,11 @@ def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue):
25662617
array([[ 1. , 0.99256089, -0.68080986],
25672618
[ 0.99256089, 1. , -0.76492172],
25682619
[-0.68080986, -0.76492172, 1. ]])
2569-
2570-
If we add another set of variables and observations ``yarr``, we can
2620+
2621+
If we add another set of variables and observations ``yarr``, we can
25712622
compute the row-wise Pearson correlation coefficients between the
25722623
variables in ``xarr`` and ``yarr``.
2573-
2624+
25742625
>>> yarr = rng.random((3, 3))
25752626
>>> yarr
25762627
array([[0.45038594, 0.37079802, 0.92676499],
@@ -2592,7 +2643,7 @@ def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue):
25922643
1. ]])
25932644
25942645
Finally if we use the option ``rowvar=False``, the columns are now
2595-
being treated as the variables and we will find the column-wise Pearson
2646+
being treated as the variables and we will find the column-wise Pearson
25962647
correlation coefficients between variables in ``xarr`` and ``yarr``.
25972648
25982649
>>> R3 = np.corrcoef(xarr, yarr, rowvar=False)

numpy/lib/tests/test_function_base.py

+34-12
Original file line numberDiff line numberDiff line change
@@ -1166,25 +1166,47 @@ def test_subclass(self):
11661166

11671167
class TestTrimZeros:
    """Tests for ``trim_zeros`` over integer, float, complex and object input."""

    # All four arrays share one layout: two leading falsy entries, one
    # interior falsy entry and one trailing falsy entry.
    a = np.array([0, 0, 1, 0, 2, 3, 4, 0])
    b = a.astype(float)
    c = a.astype(complex)
    d = np.array([None, [], 1, False, 'b', 3.0, range(4), b''], dtype=object)

    def values(self):
        """Return a generator over the class-level test arrays."""
        attr_names = ('a', 'b', 'c', 'd')
        return (getattr(self, name) for name in attr_names)

    def test_basic(self):
        # Default `trim` strips both ends
        slc = np.s_[2:-1]
        for arr in self.values():
            res = trim_zeros(arr)
            assert_array_equal(res, arr[slc])

    def test_leading_skip(self):
        # trim='b': the leading entries are kept, only the back is stripped
        slc = np.s_[:-1]
        for arr in self.values():
            res = trim_zeros(arr, trim='b')
            assert_array_equal(res, arr[slc])

    def test_trailing_skip(self):
        # trim='F': the trailing entry is kept, only the front is stripped
        slc = np.s_[2:]
        for arr in self.values():
            res = trim_zeros(arr, trim='F')
            assert_array_equal(res, arr[slc])

    def test_all_zero(self):
        # An all-zero input must trim to length zero from either side
        for _arr in self.values():
            arr = np.zeros_like(_arr, dtype=_arr.dtype)

            res1 = trim_zeros(arr, trim='B')
            assert len(res1) == 0

            res2 = trim_zeros(arr, trim='f')
            assert len(res2) == 0

    def test_size_zero(self):
        # A size-zero input must come back equal to itself
        arr = np.zeros(0)
        res = trim_zeros(arr)
        assert_array_equal(arr, res)
11881210

11891211

11901212
class TestExtins:

0 commit comments

Comments
 (0)