@@ -2130,7 +2130,6 @@ count_nonzero_bytes_384(const npy_uint64 *w)
 }
 
 #if NPY_SIMD
-
 /* Count the zero bytes between `*d` and `end`, updating `*d` to point to where to keep counting from. */
 static NPY_INLINE NPY_GCC_OPT_3 npyv_u8
 count_zero_bytes_u8(const npy_uint8 **d, const npy_uint8 *end, npy_uint8 max_count)
@@ -2166,18 +2165,18 @@ count_zero_bytes_u16(const npy_uint8 **d, const npy_uint8 *end, npy_uint16 max_c
     }
     return vsum16;
 }
-
+#endif // NPY_SIMD
 /*
  * Counts the number of non-zero values in a raw array.
  * The one loop process is shown below(take SSE2 with 128bits vector for example):
- *            |------------16 lanes---------|
+ *           |------------16 lanes---------|
 *[vsum8]    255  255  255  ...  255  255  255  255  count_zero_bytes_u8: counting 255*16 elements
  *                           !!
- *            |------------8 lanes---------|
+ *           |------------8 lanes---------|
 *[vsum16]   65535  65535  65535  ...  65535  count_zero_bytes_u16: counting (2*16-1)*16 elements
  *          65535  65535  65535  ...  65535
  *                           !!
- *            |------------4 lanes---------|
+ *           |------------4 lanes---------|
 *[sum_32_0]  65535  65535  65535  65535  count_nonzero_bytes
  *           65535  65535  65535  65535
 *[sum_32_1]  65535  65535  65535  65535
@@ -2186,40 +2185,143 @@ count_zero_bytes_u16(const npy_uint8 **d, const npy_uint8 *end, npy_uint16 max_c
  *                     (2*16-1)*16
  */
 static NPY_INLINE NPY_GCC_OPT_3 npy_intp
-count_nonzero_bytes(const npy_uint8 *d, npy_uintp unrollx)
+count_nonzero_u8(const char *data, npy_intp bstride, npy_uintp len)
 {
-    npy_intp zero_count = 0;
-    const npy_uint8 *end = d + unrollx;
-    while (d < end) {
-        npyv_u16x2 vsum16 = count_zero_bytes_u16(&d, end, NPY_MAX_UINT16);
-        npyv_u32x2 sum_32_0 = npyv_expand_u32_u16(vsum16.val[0]);
-        npyv_u32x2 sum_32_1 = npyv_expand_u32_u16(vsum16.val[1]);
-        zero_count += npyv_sum_u32(npyv_add_u32(
-            npyv_add_u32(sum_32_0.val[0], sum_32_0.val[1]),
-            npyv_add_u32(sum_32_1.val[0], sum_32_1.val[1])
-        ));
-    }
-    return unrollx - zero_count;
+    npy_intp count = 0;
+    if (bstride == 1) {
+    #if NPY_SIMD
+        npy_uintp len_m = len & -npyv_nlanes_u8;
+        npy_uintp zcount = 0;
+        for (const char *end = data + len_m; data < end;) {
+            npyv_u16x2 vsum16 = count_zero_bytes_u16((const npy_uint8**)&data, (const npy_uint8*)end, NPY_MAX_UINT16);
+            npyv_u32x2 sum_32_0 = npyv_expand_u32_u16(vsum16.val[0]);
+            npyv_u32x2 sum_32_1 = npyv_expand_u32_u16(vsum16.val[1]);
+            zcount += npyv_sum_u32(npyv_add_u32(
+                npyv_add_u32(sum_32_0.val[0], sum_32_0.val[1]),
+                npyv_add_u32(sum_32_1.val[0], sum_32_1.val[1])
+            ));
+        }
+        len  -= len_m;
+        count = len_m - zcount;
+    #else
+        if (!NPY_ALIGNMENT_REQUIRED || npy_is_aligned(data, sizeof(npy_uint64))) {
+            int step = 6 * sizeof(npy_uint64);
+            int left_bytes = len % step;
+            for (const char *end = data + len; data < end - left_bytes; data += step) {
+                count += count_nonzero_bytes_384((const npy_uint64 *)data);
+            }
+            len = left_bytes;
+        }
+    #endif // NPY_SIMD
+    }
+    for (; len > 0; --len, data += bstride) {
+        count += (*data != 0);
+    }
+    return count;
 }
 
+static NPY_INLINE NPY_GCC_OPT_3 npy_intp
+count_nonzero_u16(const char *data, npy_intp bstride, npy_uintp len)
+{
+    npy_intp count = 0;
+#if NPY_SIMD
+    if (bstride == sizeof(npy_uint16)) {
+        npy_uintp zcount = 0, len_m = len & -npyv_nlanes_u16;
+        const npyv_u16 vone  = npyv_setall_u16(1);
+        const npyv_u16 vzero = npyv_zero_u16();
+
+        for (npy_uintp lenx = len_m; lenx > 0;) {
+            npyv_u16 vsum16 = npyv_zero_u16();
+            npy_uintp max16 = PyArray_MIN(lenx, NPY_MAX_UINT16*npyv_nlanes_u16);
+
+            for (const char *end = data + max16*bstride; data < end; data += NPY_SIMD_WIDTH) {
+                npyv_u16 mask = npyv_cvt_u16_b16(npyv_cmpeq_u16(npyv_load_u16((npy_uint16*)data), vzero));
+                         mask = npyv_and_u16(mask, vone);
+                vsum16 = npyv_add_u16(vsum16, mask);
+            }
+            lenx   -= max16;
+            zcount += npyv_sumup_u16(vsum16);
+        }
+        len  -= len_m;
+        count = len_m - zcount;
+    }
+#endif
+    for (; len > 0; --len, data += bstride) {
+        count += (*(npy_uint16*)data != 0);
+    }
+    return count;
+}
+
+static NPY_INLINE NPY_GCC_OPT_3 npy_intp
+count_nonzero_u32(const char *data, npy_intp bstride, npy_uintp len)
+{
+    npy_intp count = 0;
+#if NPY_SIMD
+    if (bstride == sizeof(npy_uint32)) {
+        const npy_uintp max_iter = NPY_MAX_UINT32*npyv_nlanes_u32;
+        const npy_uintp len_m    = (len > max_iter ? max_iter : len) & -npyv_nlanes_u32;
+        const npyv_u32 vone  = npyv_setall_u32(1);
+        const npyv_u32 vzero = npyv_zero_u32();
+
+        npyv_u32 vsum32 = npyv_zero_u32();
+        for (const char *end = data + len_m*bstride; data < end; data += NPY_SIMD_WIDTH) {
+            npyv_u32 mask = npyv_cvt_u32_b32(npyv_cmpeq_u32(npyv_load_u32((npy_uint32*)data), vzero));
+                     mask = npyv_and_u32(mask, vone);
+            vsum32 = npyv_add_u32(vsum32, mask);
+        }
+        const npyv_u32 maskevn = npyv_reinterpret_u32_u64(npyv_setall_u64(0xffffffffULL));
+        npyv_u64 odd  = npyv_shri_u64(npyv_reinterpret_u64_u32(vsum32), 32);
+        npyv_u64 even = npyv_reinterpret_u64_u32(npyv_and_u32(vsum32, maskevn));
+        count = len_m - npyv_sum_u64(npyv_add_u64(odd, even));
+        len  -= len_m;
+    }
+#endif
+    for (; len > 0; --len, data += bstride) {
+        count += (*(npy_uint32*)data != 0);
+    }
+    return count;
+}
+
+static NPY_INLINE NPY_GCC_OPT_3 npy_intp
+count_nonzero_u64(const char *data, npy_intp bstride, npy_uintp len)
+{
+    npy_intp count = 0;
+#if NPY_SIMD
+    if (bstride == sizeof(npy_uint64)) {
+        const npy_uintp len_m = len & -npyv_nlanes_u64;
+        const npyv_u64 vone  = npyv_setall_u64(1);
+        const npyv_u64 vzero = npyv_zero_u64();
+
+        npyv_u64 vsum64 = npyv_zero_u64();
+        for (const char *end = data + len_m*bstride; data < end; data += NPY_SIMD_WIDTH) {
+            npyv_u64 mask = npyv_cvt_u64_b64(npyv_cmpeq_u64(npyv_load_u64((npy_uint64*)data), vzero));
+                     mask = npyv_and_u64(mask, vone);
+            vsum64 = npyv_add_u64(vsum64, mask);
+        }
+        len  -= len_m;
+        count = len_m - npyv_sum_u64(vsum64);
+    }
 #endif
+    for (; len > 0; --len, data += bstride) {
+        count += (*(npy_uint64*)data != 0);
+    }
+    return count;
+}
 /*
  * Counts the number of True values in a raw boolean array. This
  * is a low-overhead function which does no heap allocations.
  *
  * Returns -1 on error.
  */
-NPY_NO_EXPORT npy_intp
-count_boolean_trues(int ndim, char *data, npy_intp const *ashape, npy_intp const *astrides)
+static NPY_GCC_OPT_3 npy_intp
+count_nonzero_int(int ndim, char *data, const npy_intp *ashape, const npy_intp *astrides, int elsize)
 {
-
+    assert(elsize <= 8);
     int idim;
     npy_intp shape[NPY_MAXDIMS], strides[NPY_MAXDIMS];
-    npy_intp i, coord[NPY_MAXDIMS];
-    npy_intp count = 0;
-    NPY_BEGIN_THREADS_DEF;
+    npy_intp coord[NPY_MAXDIMS];
 
-    /* Use raw iteration with no heap memory allocation */
+    // Use raw iteration with no heap memory allocation
     if (PyArray_PrepareOneRawArrayIter(
                     ndim, ashape,
                     data, astrides,
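For readers following the vsum8/vsum16 diagram above, here is a plain-C restatement of the widening accumulation that count_nonzero_u8 performs on its contiguous path. This is an illustrative sketch only — it is not part of the patch and uses none of the npyv_* universal intrinsics — but it shows why the narrow counters are flushed before they can overflow: the 8-bit accumulator absorbs at most 255 hits and the 16-bit accumulator at most 65535 before being folded into the full-width total.

#include <stdint.h>
#include <stddef.h>
#include <stdio.h>

/* Illustrative sketch, not NumPy code: scalar model of the widening
 * zero-count accumulation (vsum8 -> vsum16 -> full width). */
static size_t
count_nonzero_bytes_sketch(const uint8_t *d, size_t len)
{
    const size_t total = len;
    size_t zeros = 0;
    while (len > 0) {
        uint16_t zeros16 = 0;                          /* plays the role of vsum16 */
        size_t chunk16 = len < 65535 ? len : 65535;
        len -= chunk16;
        while (chunk16 > 0) {
            uint8_t zeros8 = 0;                        /* plays the role of vsum8 */
            size_t chunk8 = chunk16 < 255 ? chunk16 : 255;
            chunk16 -= chunk8;
            for (; chunk8 > 0; --chunk8, ++d) {
                zeros8 += (*d == 0);
            }
            zeros16 += zeros8;                         /* widen 8 -> 16 bits */
        }
        zeros += zeros16;                              /* widen 16 -> full width */
    }
    return total - zeros;                              /* non-zeros = total - zeros */
}

int main(void)
{
    uint8_t buf[1000] = {0};
    buf[3] = buf[999] = 7;
    printf("%zu\n", count_nonzero_bytes_sketch(buf, sizeof(buf)));  /* prints 2 */
    return 0;
}

The SIMD helpers in the patch do the same bookkeeping, only with one such counter per lane and a horizontal sum at each widening step.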
@@ -2228,51 +2330,44 @@ count_boolean_trues(int ndim, char *data, npy_intp const *ashape, npy_intp const
         return -1;
     }
 
-    /* Handle zero-sized array */
+    // Handle zero-sized array
     if (shape[0] == 0) {
         return 0;
     }
 
+    NPY_BEGIN_THREADS_DEF;
     NPY_BEGIN_THREADS_THRESHOLDED(shape[0]);
-    /* Special case for contiguous inner loop */
-    if (strides[0] == 1) {
-        NPY_RAW_ITER_START(idim, ndim, coord, shape) {
-            /* Process the innermost dimension */
-            const char *d = data;
-            const char *e = data + shape[0];
-#if NPY_SIMD
-            npy_uintp stride = shape[0] & -npyv_nlanes_u8;
-            count += count_nonzero_bytes((const npy_uint8 *)d, stride);
-            d += stride;
-#else
-            if (!NPY_ALIGNMENT_REQUIRED ||
-                    npy_is_aligned(d, sizeof(npy_uint64))) {
-                npy_uintp stride = 6 * sizeof(npy_uint64);
-                for (; d < e - (shape[0] % stride); d += stride) {
-                    count += count_nonzero_bytes_384((const npy_uint64 *)d);
-                }
-            }
-#endif
-            for (; d < e; ++d) {
-                count += (*d != 0);
-            }
-        } NPY_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides);
-    }
-    /* General inner loop */
-    else {
-        NPY_RAW_ITER_START(idim, ndim, coord, shape) {
-            char *d = data;
-            /* Process the innermost dimension */
-            for (i = 0; i < shape[0]; ++i, d += strides[0]) {
-                count += (*d != 0);
-            }
-        } NPY_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides);
+
+    #define NONZERO_CASE(LEN, SFX) \
+        case LEN: \
+            NPY_RAW_ITER_START(idim, ndim, coord, shape) { \
+                count += count_nonzero_##SFX(data, strides[0], shape[0]); \
+            } NPY_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides); \
+            break
+
+    npy_intp count = 0;
+    switch(elsize) {
+        NONZERO_CASE(1, u8);
+        NONZERO_CASE(2, u16);
+        NONZERO_CASE(4, u32);
+        NONZERO_CASE(8, u64);
     }
+    #undef NONZERO_CASE
 
     NPY_END_THREADS;
-
     return count;
 }
+/*
+ * Counts the number of True values in a raw boolean array. This
+ * is a low-overhead function which does no heap allocations.
+ *
+ * Returns -1 on error.
+ */
+NPY_NO_EXPORT NPY_GCC_OPT_3 npy_intp
+count_boolean_trues(int ndim, char *data, npy_intp const *ashape, npy_intp const *astrides)
+{
+    return count_nonzero_int(ndim, data, ashape, astrides, 1);
+}
 
 /*NUMPY_API
  * Counts the number of non-zero elements in the array.
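The NONZERO_CASE dispatch above hands the innermost dimension to one of the size-specific helpers from the previous hunk. The wider helpers (count_nonzero_u16/u32/u64) share a single pattern: compare each lane against zero, mask the all-ones compare result down to 1, accumulate per lane, and finish with one horizontal sum plus a scalar tail. Below is a minimal scalar model of that pattern for 32-bit elements; it is not NumPy code, assumes a contiguous buffer, and uses a fixed 4-lane width purely for illustration.

#include <stdint.h>
#include <stddef.h>
#include <stdio.h>

#define LANES 4   /* stand-in for npyv_nlanes_u32 on a 128-bit target */

/* Illustrative sketch, not NumPy code: compare-mask-accumulate zero
 * counting with a single horizontal reduction at the end. */
static size_t
count_nonzero_u32_sketch(const uint32_t *data, size_t len)
{
    size_t len_m = len & -(size_t)LANES;     /* largest multiple of the lane count */
    uint32_t vsum[LANES] = {0};              /* plays the role of vsum32 */
    for (size_t i = 0; i < len_m; i += LANES) {
        for (int l = 0; l < LANES; ++l) {
            uint32_t mask = (data[i + l] == 0) ? 0xffffffffu : 0u;  /* cmpeq */
            vsum[l] += mask & 1u;                                   /* and with vone */
        }
    }
    size_t zeros = 0;
    for (int l = 0; l < LANES; ++l) {        /* horizontal sum over the lanes */
        zeros += vsum[l];
    }
    size_t count = len_m - zeros;
    for (size_t i = len_m; i < len; ++i) {   /* scalar tail, as in the patch */
        count += (data[i] != 0);
    }
    return count;
}

int main(void)
{
    uint32_t a[10] = {0, 1, 0, 2, 3, 0, 0, 0, 4, 0};
    printf("%zu\n", count_nonzero_u32_sketch(a, 10));  /* prints 4 */
    return 0;
}

On real hardware the per-lane counters live in one vector register, so the inner loop costs roughly one compare, one and, and one add per vector of elements.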
@@ -2295,14 +2390,22 @@ PyArray_CountNonzero(PyArrayObject *self)
     npy_intp *strideptr, *innersizeptr;
     NPY_BEGIN_THREADS_DEF;
 
-    /* Special low-overhead version specific to the boolean type */
+    // Special low-overhead version specific to the boolean/int types
     dtype = PyArray_DESCR(self);
-    if (dtype->type_num == NPY_BOOL) {
-        return count_boolean_trues(PyArray_NDIM(self), PyArray_DATA(self),
-                                   PyArray_DIMS(self), PyArray_STRIDES(self));
+    switch(dtype->kind) {
+        case 'u':
+        case 'i':
+        case 'b':
+            if (dtype->elsize > 8) {
+                break;
+            }
+            return count_nonzero_int(
+                PyArray_NDIM(self), PyArray_BYTES(self), PyArray_DIMS(self),
+                PyArray_STRIDES(self), dtype->elsize
+            );
     }
-    nonzero = PyArray_DESCR(self)->f->nonzero;
 
+    nonzero = PyArray_DESCR(self)->f->nonzero;
     /* If it's a trivial one-dimensional loop, don't use an iterator */
     if (PyArray_TRIVIALLY_ITERABLE(self)) {
         needs_api = PyDataType_FLAGCHK(dtype, NPY_NEEDS_PYAPI);