Skip to content

Commit 17e3ef9

Browse files
authored
Merge pull request numpy#18183 from touqir14/master
MAINT: Optimize numpy.count_nonzero for int types using SIMD operations
2 parents 7a18e4a + 85e2ce9 commit 17e3ef9

File tree

3 files changed

+183
-70
lines changed

3 files changed

+183
-70
lines changed

benchmarks/benchmarks/bench_core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class CountNonzero(Benchmark):
142142
params = [
143143
[1, 2, 3],
144144
[100, 10000, 1000000],
145-
[bool, int, str, object]
145+
[bool, np.int8, np.int16, np.int32, np.int64, str, object]
146146
]
147147

148148
def setup(self, numaxes, size, dtype):

numpy/core/src/multiarray/item_selection.c

+168-65
Original file line numberDiff line numberDiff line change
@@ -2130,7 +2130,6 @@ count_nonzero_bytes_384(const npy_uint64 * w)
21302130
}
21312131

21322132
#if NPY_SIMD
2133-
21342133
/* Count the zero bytes between `*d` and `end`, updating `*d` to point to where to keep counting from. */
21352134
static NPY_INLINE NPY_GCC_OPT_3 npyv_u8
21362135
count_zero_bytes_u8(const npy_uint8 **d, const npy_uint8 *end, npy_uint8 max_count)
@@ -2166,18 +2165,18 @@ count_zero_bytes_u16(const npy_uint8 **d, const npy_uint8 *end, npy_uint16 max_c
21662165
}
21672166
return vsum16;
21682167
}
2169-
2168+
#endif // NPY_SIMD
21702169
/*
21712170
* Counts the number of non-zero values in a raw array.
21722171
* The process of one loop iteration is shown below (taking SSE2 with 128-bit vectors as an example):
2173-
* |------------16 lanes---------|
2172+
* |------------16 lanes---------|
21742173
*[vsum8] 255 255 255 ... 255 255 255 255 count_zero_bytes_u8: counting 255*16 elements
21752174
* !!
2176-
* |------------8 lanes---------|
2175+
* |------------8 lanes---------|
21772176
*[vsum16] 65535 65535 65535 ... 65535 count_zero_bytes_u16: counting (2^16-1)*16 elements
21782177
* 65535 65535 65535 ... 65535
21792178
* !!
2180-
* |------------4 lanes---------|
2179+
* |------------4 lanes---------|
21812180
*[sum_32_0] 65535 65535 65535 65535 count_nonzero_bytes
21822181
* 65535 65535 65535 65535
21832182
*[sum_32_1] 65535 65535 65535 65535
@@ -2186,40 +2185,143 @@ count_zero_bytes_u16(const npy_uint8 **d, const npy_uint8 *end, npy_uint16 max_c
21862185
* (2^16-1)*16
21872186
*/
21882187
static NPY_INLINE NPY_GCC_OPT_3 npy_intp
2189-
count_nonzero_bytes(const npy_uint8 *d, npy_uintp unrollx)
2188+
count_nonzero_u8(const char *data, npy_intp bstride, npy_uintp len)
21902189
{
2191-
npy_intp zero_count = 0;
2192-
const npy_uint8 *end = d + unrollx;
2193-
while (d < end) {
2194-
npyv_u16x2 vsum16 = count_zero_bytes_u16(&d, end, NPY_MAX_UINT16);
2195-
npyv_u32x2 sum_32_0 = npyv_expand_u32_u16(vsum16.val[0]);
2196-
npyv_u32x2 sum_32_1 = npyv_expand_u32_u16(vsum16.val[1]);
2197-
zero_count += npyv_sum_u32(npyv_add_u32(
2198-
npyv_add_u32(sum_32_0.val[0], sum_32_0.val[1]),
2199-
npyv_add_u32(sum_32_1.val[0], sum_32_1.val[1])
2200-
));
2201-
}
2202-
return unrollx - zero_count;
2190+
npy_intp count = 0;
2191+
if (bstride == 1) {
2192+
#if NPY_SIMD
2193+
npy_uintp len_m = len & -npyv_nlanes_u8;
2194+
npy_uintp zcount = 0;
2195+
for (const char *end = data + len_m; data < end;) {
2196+
npyv_u16x2 vsum16 = count_zero_bytes_u16((const npy_uint8**)&data, (const npy_uint8*)end, NPY_MAX_UINT16);
2197+
npyv_u32x2 sum_32_0 = npyv_expand_u32_u16(vsum16.val[0]);
2198+
npyv_u32x2 sum_32_1 = npyv_expand_u32_u16(vsum16.val[1]);
2199+
zcount += npyv_sum_u32(npyv_add_u32(
2200+
npyv_add_u32(sum_32_0.val[0], sum_32_0.val[1]),
2201+
npyv_add_u32(sum_32_1.val[0], sum_32_1.val[1])
2202+
));
2203+
}
2204+
len -= len_m;
2205+
count = len_m - zcount;
2206+
#else
2207+
if (!NPY_ALIGNMENT_REQUIRED || npy_is_aligned(data, sizeof(npy_uint64))) {
2208+
int step = 6 * sizeof(npy_uint64);
2209+
int left_bytes = len % step;
2210+
for (const char *end = data + len; data < end - left_bytes; data += step) {
2211+
count += count_nonzero_bytes_384((const npy_uint64 *)data);
2212+
}
2213+
len = left_bytes;
2214+
}
2215+
#endif // NPY_SIMD
2216+
}
2217+
for (; len > 0; --len, data += bstride) {
2218+
count += (*data != 0);
2219+
}
2220+
return count;
22032221
}
22042222

2223+
static NPY_INLINE NPY_GCC_OPT_3 npy_intp
2224+
count_nonzero_u16(const char *data, npy_intp bstride, npy_uintp len)
2225+
{
2226+
npy_intp count = 0;
2227+
#if NPY_SIMD
2228+
if (bstride == sizeof(npy_uint16)) {
2229+
npy_uintp zcount = 0, len_m = len & -npyv_nlanes_u16;
2230+
const npyv_u16 vone = npyv_setall_u16(1);
2231+
const npyv_u16 vzero = npyv_zero_u16();
2232+
2233+
for (npy_uintp lenx = len_m; lenx > 0;) {
2234+
npyv_u16 vsum16 = npyv_zero_u16();
2235+
npy_uintp max16 = PyArray_MIN(lenx, NPY_MAX_UINT16*npyv_nlanes_u16);
2236+
2237+
for (const char *end = data + max16*bstride; data < end; data += NPY_SIMD_WIDTH) {
2238+
npyv_u16 mask = npyv_cvt_u16_b16(npyv_cmpeq_u16(npyv_load_u16((npy_uint16*)data), vzero));
2239+
mask = npyv_and_u16(mask, vone);
2240+
vsum16 = npyv_add_u16(vsum16, mask);
2241+
}
2242+
lenx -= max16;
2243+
zcount += npyv_sumup_u16(vsum16);
2244+
}
2245+
len -= len_m;
2246+
count = len_m - zcount;
2247+
}
2248+
#endif
2249+
for (; len > 0; --len, data += bstride) {
2250+
count += (*(npy_uint16*)data != 0);
2251+
}
2252+
return count;
2253+
}
2254+
2255+
static NPY_INLINE NPY_GCC_OPT_3 npy_intp
2256+
count_nonzero_u32(const char *data, npy_intp bstride, npy_uintp len)
2257+
{
2258+
npy_intp count = 0;
2259+
#if NPY_SIMD
2260+
if (bstride == sizeof(npy_uint32)) {
2261+
const npy_uintp max_iter = NPY_MAX_UINT32*npyv_nlanes_u32;
2262+
const npy_uintp len_m = (len > max_iter ? max_iter : len) & -npyv_nlanes_u32;
2263+
const npyv_u32 vone = npyv_setall_u32(1);
2264+
const npyv_u32 vzero = npyv_zero_u32();
2265+
2266+
npyv_u32 vsum32 = npyv_zero_u32();
2267+
for (const char *end = data + len_m*bstride; data < end; data += NPY_SIMD_WIDTH) {
2268+
npyv_u32 mask = npyv_cvt_u32_b32(npyv_cmpeq_u32(npyv_load_u32((npy_uint32*)data), vzero));
2269+
mask = npyv_and_u32(mask, vone);
2270+
vsum32 = npyv_add_u32(vsum32, mask);
2271+
}
2272+
const npyv_u32 maskevn = npyv_reinterpret_u32_u64(npyv_setall_u64(0xffffffffULL));
2273+
npyv_u64 odd = npyv_shri_u64(npyv_reinterpret_u64_u32(vsum32), 32);
2274+
npyv_u64 even = npyv_reinterpret_u64_u32(npyv_and_u32(vsum32, maskevn));
2275+
count = len_m - npyv_sum_u64(npyv_add_u64(odd, even));
2276+
len -= len_m;
2277+
}
2278+
#endif
2279+
for (; len > 0; --len, data += bstride) {
2280+
count += (*(npy_uint32*)data != 0);
2281+
}
2282+
return count;
2283+
}
2284+
2285+
static NPY_INLINE NPY_GCC_OPT_3 npy_intp
2286+
count_nonzero_u64(const char *data, npy_intp bstride, npy_uintp len)
2287+
{
2288+
npy_intp count = 0;
2289+
#if NPY_SIMD
2290+
if (bstride == sizeof(npy_uint64)) {
2291+
const npy_uintp len_m = len & -npyv_nlanes_u64;
2292+
const npyv_u64 vone = npyv_setall_u64(1);
2293+
const npyv_u64 vzero = npyv_zero_u64();
2294+
2295+
npyv_u64 vsum64 = npyv_zero_u64();
2296+
for (const char *end = data + len_m*bstride; data < end; data += NPY_SIMD_WIDTH) {
2297+
npyv_u64 mask = npyv_cvt_u64_b64(npyv_cmpeq_u64(npyv_load_u64((npy_uint64*)data), vzero));
2298+
mask = npyv_and_u64(mask, vone);
2299+
vsum64 = npyv_add_u64(vsum64, mask);
2300+
}
2301+
len -= len_m;
2302+
count = len_m - npyv_sum_u64(vsum64);
2303+
}
22052304
#endif
2305+
for (; len > 0; --len, data += bstride) {
2306+
count += (*(npy_uint64*)data != 0);
2307+
}
2308+
return count;
2309+
}
22062310
/*
22072311
* Counts the number of non-zero values in a raw array of 1-, 2-, 4- or 8-byte integer (or boolean) elements. This
22082312
* is a low-overhead function which does no heap allocations.
22092313
*
22102314
* Returns -1 on error.
22112315
*/
2212-
NPY_NO_EXPORT npy_intp
2213-
count_boolean_trues(int ndim, char *data, npy_intp const *ashape, npy_intp const *astrides)
2316+
static NPY_GCC_OPT_3 npy_intp
2317+
count_nonzero_int(int ndim, char *data, const npy_intp *ashape, const npy_intp *astrides, int elsize)
22142318
{
2215-
2319+
assert(elsize <= 8);
22162320
int idim;
22172321
npy_intp shape[NPY_MAXDIMS], strides[NPY_MAXDIMS];
2218-
npy_intp i, coord[NPY_MAXDIMS];
2219-
npy_intp count = 0;
2220-
NPY_BEGIN_THREADS_DEF;
2322+
npy_intp coord[NPY_MAXDIMS];
22212323

2222-
/* Use raw iteration with no heap memory allocation */
2324+
// Use raw iteration with no heap memory allocation
22232325
if (PyArray_PrepareOneRawArrayIter(
22242326
ndim, ashape,
22252327
data, astrides,
@@ -2228,51 +2330,44 @@ count_boolean_trues(int ndim, char *data, npy_intp const *ashape, npy_intp const
22282330
return -1;
22292331
}
22302332

2231-
/* Handle zero-sized array */
2333+
// Handle zero-sized array
22322334
if (shape[0] == 0) {
22332335
return 0;
22342336
}
22352337

2338+
NPY_BEGIN_THREADS_DEF;
22362339
NPY_BEGIN_THREADS_THRESHOLDED(shape[0]);
2237-
/* Special case for contiguous inner loop */
2238-
if (strides[0] == 1) {
2239-
NPY_RAW_ITER_START(idim, ndim, coord, shape) {
2240-
/* Process the innermost dimension */
2241-
const char *d = data;
2242-
const char *e = data + shape[0];
2243-
#if NPY_SIMD
2244-
npy_uintp stride = shape[0] & -npyv_nlanes_u8;
2245-
count += count_nonzero_bytes((const npy_uint8 *)d, stride);
2246-
d += stride;
2247-
#else
2248-
if (!NPY_ALIGNMENT_REQUIRED ||
2249-
npy_is_aligned(d, sizeof(npy_uint64))) {
2250-
npy_uintp stride = 6 * sizeof(npy_uint64);
2251-
for (; d < e - (shape[0] % stride); d += stride) {
2252-
count += count_nonzero_bytes_384((const npy_uint64 *)d);
2253-
}
2254-
}
2255-
#endif
2256-
for (; d < e; ++d) {
2257-
count += (*d != 0);
2258-
}
2259-
} NPY_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides);
2260-
}
2261-
/* General inner loop */
2262-
else {
2263-
NPY_RAW_ITER_START(idim, ndim, coord, shape) {
2264-
char *d = data;
2265-
/* Process the innermost dimension */
2266-
for (i = 0; i < shape[0]; ++i, d += strides[0]) {
2267-
count += (*d != 0);
2268-
}
2269-
} NPY_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides);
2340+
2341+
#define NONZERO_CASE(LEN, SFX) \
2342+
case LEN: \
2343+
NPY_RAW_ITER_START(idim, ndim, coord, shape) { \
2344+
count += count_nonzero_##SFX(data, strides[0], shape[0]); \
2345+
} NPY_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides); \
2346+
break
2347+
2348+
npy_intp count = 0;
2349+
switch(elsize) {
2350+
NONZERO_CASE(1, u8);
2351+
NONZERO_CASE(2, u16);
2352+
NONZERO_CASE(4, u32);
2353+
NONZERO_CASE(8, u64);
22702354
}
2355+
#undef NONZERO_CASE
22712356

22722357
NPY_END_THREADS;
2273-
22742358
return count;
22752359
}
2360+
/*
2361+
* Counts the number of True values in a raw boolean array. This
2362+
* is a low-overhead function which does no heap allocations.
2363+
*
2364+
* Returns -1 on error.
2365+
*/
2366+
NPY_NO_EXPORT NPY_GCC_OPT_3 npy_intp
2367+
count_boolean_trues(int ndim, char *data, npy_intp const *ashape, npy_intp const *astrides)
2368+
{
2369+
return count_nonzero_int(ndim, data, ashape, astrides, 1);
2370+
}
22762371

22772372
/*NUMPY_API
22782373
* Counts the number of non-zero elements in the array.
@@ -2295,14 +2390,22 @@ PyArray_CountNonzero(PyArrayObject *self)
22952390
npy_intp *strideptr, *innersizeptr;
22962391
NPY_BEGIN_THREADS_DEF;
22972392

2298-
/* Special low-overhead version specific to the boolean type */
2393+
// Special low-overhead version specific to the boolean/int types
22992394
dtype = PyArray_DESCR(self);
2300-
if (dtype->type_num == NPY_BOOL) {
2301-
return count_boolean_trues(PyArray_NDIM(self), PyArray_DATA(self),
2302-
PyArray_DIMS(self), PyArray_STRIDES(self));
2395+
switch(dtype->kind) {
2396+
case 'u':
2397+
case 'i':
2398+
case 'b':
2399+
if (dtype->elsize > 8) {
2400+
break;
2401+
}
2402+
return count_nonzero_int(
2403+
PyArray_NDIM(self), PyArray_BYTES(self), PyArray_DIMS(self),
2404+
PyArray_STRIDES(self), dtype->elsize
2405+
);
23032406
}
2304-
nonzero = PyArray_DESCR(self)->f->nonzero;
23052407

2408+
nonzero = PyArray_DESCR(self)->f->nonzero;
23062409
/* If it's a trivial one-dimensional loop, don't use an iterator */
23072410
if (PyArray_TRIVIALLY_ITERABLE(self)) {
23082411
needs_api = PyDataType_FLAGCHK(dtype, NPY_NEEDS_PYAPI);

numpy/core/tests/test_numeric.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -1266,20 +1266,30 @@ def test_nonzero_onedim(self):
12661266
assert_equal(np.count_nonzero(x), 4)
12671267
assert_equal(np.nonzero(x), ([0, 2, 3, 6],))
12681268

1269-
x = np.array([(1, 2), (0, 0), (1, 1), (-1, 3), (0, 7)],
1270-
dtype=[('a', 'i4'), ('b', 'i2')])
1269+
# x = np.array([(1, 2), (0, 0), (1, 1), (-1, 3), (0, 7)],
1270+
# dtype=[('a', 'i4'), ('b', 'i2')])
1271+
x = np.array([(1, 2, -5, -3), (0, 0, 2, 7), (1, 1, 0, 1), (-1, 3, 1, 0), (0, 7, 0, 4)],
1272+
dtype=[('a', 'i4'), ('b', 'i2'), ('c', 'i1'), ('d', 'i8')])
12711273
assert_equal(np.count_nonzero(x['a']), 3)
12721274
assert_equal(np.count_nonzero(x['b']), 4)
1275+
assert_equal(np.count_nonzero(x['c']), 3)
1276+
assert_equal(np.count_nonzero(x['d']), 4)
12731277
assert_equal(np.nonzero(x['a']), ([0, 2, 3],))
12741278
assert_equal(np.nonzero(x['b']), ([0, 2, 3, 4],))
12751279

12761280
def test_nonzero_twodim(self):
12771281
x = np.array([[0, 1, 0], [2, 0, 3]])
1278-
assert_equal(np.count_nonzero(x), 3)
1282+
assert_equal(np.count_nonzero(x.astype('i1')), 3)
1283+
assert_equal(np.count_nonzero(x.astype('i2')), 3)
1284+
assert_equal(np.count_nonzero(x.astype('i4')), 3)
1285+
assert_equal(np.count_nonzero(x.astype('i8')), 3)
12791286
assert_equal(np.nonzero(x), ([0, 1, 1], [1, 0, 2]))
12801287

12811288
x = np.eye(3)
1282-
assert_equal(np.count_nonzero(x), 3)
1289+
assert_equal(np.count_nonzero(x.astype('i1')), 3)
1290+
assert_equal(np.count_nonzero(x.astype('i2')), 3)
1291+
assert_equal(np.count_nonzero(x.astype('i4')), 3)
1292+
assert_equal(np.count_nonzero(x.astype('i8')), 3)
12831293
assert_equal(np.nonzero(x), ([0, 1, 2], [0, 1, 2]))
12841294

12851295
x = np.array([[(0, 1), (0, 0), (1, 11)],

0 commit comments

Comments
 (0)