Skip to content

Commit bdd4e2e

Browse files
authored
ENH: ARM Neon implementation with intrinsic for np.argmax. (numpy#16375)
* Neon implementation with intrinsic for bool argmax
1 parent 1b212bd commit bdd4e2e

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

benchmarks/benchmarks/bench_reduce.py

+9
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ def time_min(self, dtype):
5858
def time_max(self, dtype):
5959
np.max(self.d)
6060

61+
class ArgMax(Benchmark):
62+
params = [np.float32, bool]
63+
param_names = ['dtype']
64+
65+
def setup(self, dtype):
66+
self.d = np.zeros(200000, dtype=dtype)
67+
68+
def time_argmax(self, dtype):
69+
np.argmax(self.d)
6170

6271
class SmallReduction(Benchmark):
6372
def setup(self):

numpy/core/src/multiarray/arraytypes.c.src

+25-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
#include "arrayobject.h"
2828
#include "alloc.h"
2929
#include "typeinfo.h"
30+
#if defined(__ARM_NEON__) || defined (__ARM_NEON)
31+
#include <arm_neon.h>
32+
#endif
3033
#ifdef NPY_HAVE_SSE2_INTRINSICS
3134
#include <emmintrin.h>
3235
#endif
@@ -3070,7 +3073,15 @@ finish:
30703073
** ARGFUNC **
30713074
*****************************************************************************
30723075
*/
3073-
3076+
#if defined(__ARM_NEON__) || defined (__ARM_NEON)
3077+
int32_t _mm_movemask_epi8_neon(uint8x16_t input)
3078+
{
3079+
int8x8_t m0 = vcreate_s8(0x0706050403020100ULL);
3080+
uint8x16_t v0 = vshlq_u8(vshrq_n_u8(input, 7), vcombine_s8(m0, m0));
3081+
uint64x2_t v1 = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v0)));
3082+
return (int)vgetq_lane_u64(v1, 0) + ((int)vgetq_lane_u64(v1, 1) << 8);
3083+
}
3084+
#endif
30743085
#define _LESS_THAN_OR_EQUAL(a,b) ((a) <= (b))
30753086

30763087
static int
@@ -3091,6 +3102,19 @@ BOOL_argmax(npy_bool *ip, npy_intp n, npy_intp *max_ind,
30913102
break;
30923103
}
30933104
}
3105+
#else
3106+
#if defined(__ARM_NEON__) || defined (__ARM_NEON)
3107+
uint8x16_t zero = vdupq_n_u8(0);
3108+
for(; i < n - (n % 32); i+=32) {
3109+
uint8x16_t d1 = vld1q_u8((char *)&ip[i]);
3110+
uint8x16_t d2 = vld1q_u8((char *)&ip[i + 16]);
3111+
d1 = vceqq_u8(d1, zero);
3112+
d2 = vceqq_u8(d2, zero);
3113+
if(_mm_movemask_epi8_neon(vminq_u8(d1, d2)) != 0xFFFF) {
3114+
break;
3115+
}
3116+
}
3117+
#endif
30943118
#endif
30953119
for (; i < n; i++) {
30963120
if (ip[i]) {

0 commit comments

Comments
 (0)