Skip to content

Commit fbefd42

Browse files
authored
Merge pull request #25370 from asmeurer/array_api-portability
ENH: Make `numpy.array_api` more portable Original NumPy Commit: ad360324dbb3f43f0b40384e97e2f251b4140c2f
2 parents 44579a1 + 90a26ab commit fbefd42

14 files changed

+222
-125
lines changed

array_api_strict/_array_object.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from enum import IntEnum
2020
from ._creation_functions import asarray
2121
from ._dtypes import (
22+
_DType,
2223
_all_dtypes,
2324
_boolean_dtypes,
2425
_integer_dtypes,
@@ -39,6 +40,13 @@
3940

4041
import numpy as np
4142

43+
# Placeholder object to represent the "cpu" device (the only device NumPy
44+
# supports).
45+
class _cpu_device:
46+
def __repr__(self):
47+
return "CPU_DEVICE"
48+
49+
CPU_DEVICE = _cpu_device()
4250

4351
class Array:
4452
"""
@@ -75,11 +83,13 @@ def _new(cls, x, /):
7583
if isinstance(x, np.generic):
7684
# Convert the array scalar to a 0-D array
7785
x = np.asarray(x)
78-
if x.dtype not in _all_dtypes:
86+
_dtype = _DType(x.dtype)
87+
if _dtype not in _all_dtypes:
7988
raise TypeError(
8089
f"The array_api namespace does not support the dtype '{x.dtype}'"
8190
)
8291
obj._array = x
92+
obj._dtype = _dtype
8393
return obj
8494

8595
# Prevent Array() from working
@@ -101,7 +111,7 @@ def __repr__(self: Array, /) -> str:
101111
"""
102112
Performs the operation __repr__.
103113
"""
104-
suffix = f", dtype={self.dtype.name})"
114+
suffix = f", dtype={self.dtype})"
105115
if 0 in self.shape:
106116
prefix = "empty("
107117
mid = str(self.shape)
@@ -176,6 +186,8 @@ def _promote_scalar(self, scalar):
176186
integer that is too large to fit in a NumPy integer dtype, or
177187
TypeError when the scalar type is incompatible with the dtype of self.
178188
"""
189+
from ._data_type_functions import iinfo
190+
179191
# Note: Only Python scalar types that match the array dtype are
180192
# allowed.
181193
if isinstance(scalar, bool):
@@ -189,7 +201,7 @@ def _promote_scalar(self, scalar):
189201
"Python int scalars cannot be promoted with bool arrays"
190202
)
191203
if self.dtype in _integer_dtypes:
192-
info = np.iinfo(self.dtype)
204+
info = iinfo(self.dtype)
193205
if not (info.min <= scalar <= info.max):
194206
raise OverflowError(
195207
"Python int scalars must be within the bounds of the dtype for integer arrays"
@@ -215,7 +227,7 @@ def _promote_scalar(self, scalar):
215227
# behavior for integers within the bounds of the integer dtype.
216228
# Outside of those bounds we use the default NumPy behavior (either
217229
# cast or raise OverflowError).
218-
return Array._new(np.array(scalar, self.dtype))
230+
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype))
219231

220232
@staticmethod
221233
def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:
@@ -325,7 +337,9 @@ def _validate_index(self, key):
325337
for i in _key:
326338
if i is not None:
327339
nonexpanding_key.append(i)
328-
if isinstance(i, Array) or isinstance(i, np.ndarray):
340+
if isinstance(i, np.ndarray):
341+
raise IndexError("Index arrays for np.array_api must be np.array_api arrays")
342+
if isinstance(i, Array):
329343
if i.dtype in _boolean_dtypes:
330344
key_has_mask = True
331345
single_axes.append(i)
@@ -1067,7 +1081,7 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
10671081
def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
10681082
if stream is not None:
10691083
raise ValueError("The stream argument to to_device() is not supported")
1070-
if device == 'cpu':
1084+
if device == CPU_DEVICE:
10711085
return self
10721086
raise ValueError(f"Unsupported device {device!r}")
10731087

@@ -1078,11 +1092,11 @@ def dtype(self) -> Dtype:
10781092
10791093
See its docstring for more information.
10801094
"""
1081-
return self._array.dtype
1095+
return self._dtype
10821096

10831097
@property
10841098
def device(self) -> Device:
1085-
return "cpu"
1099+
return CPU_DEVICE
10861100

10871101
# Note: mT is new in array API spec (see matrix_transpose)
10881102
@property

array_api_strict/_creation_functions.py

+56-36
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,15 @@
1212
SupportsBufferProtocol,
1313
)
1414
from collections.abc import Sequence
15-
from ._dtypes import _all_dtypes
15+
from ._dtypes import _DType, _all_dtypes
1616

1717
import numpy as np
1818

1919

2020
def _check_valid_dtype(dtype):
2121
# Note: Only spelling dtypes as the dtype objects is supported.
22-
23-
# We use this instead of "dtype in _all_dtypes" because the dtype objects
24-
# define equality with the sorts of things we want to disallow.
25-
for d in (None,) + _all_dtypes:
26-
if dtype is d:
27-
return
28-
raise ValueError("dtype must be one of the supported dtypes")
22+
if dtype not in (None,) + _all_dtypes:
23+
raise ValueError("dtype must be one of the supported dtypes")
2924

3025

3126
def asarray(
@@ -50,10 +45,13 @@ def asarray(
5045
"""
5146
# _array_object imports in this file are inside the functions to avoid
5247
# circular imports
53-
from ._array_object import Array
48+
from ._array_object import Array, CPU_DEVICE
5449

5550
_check_valid_dtype(dtype)
56-
if device not in ["cpu", None]:
51+
_np_dtype = None
52+
if dtype is not None:
53+
_np_dtype = dtype._np_dtype
54+
if device not in [CPU_DEVICE, None]:
5755
raise ValueError(f"Unsupported device {device!r}")
5856
if copy in (False, np._CopyMode.IF_NEEDED):
5957
# Note: copy=False is not yet implemented in np.asarray
@@ -62,13 +60,13 @@ def asarray(
6260
if dtype is not None and obj.dtype != dtype:
6361
copy = True
6462
if copy in (True, np._CopyMode.ALWAYS):
65-
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
63+
return Array._new(np.array(obj._array, copy=True, dtype=_np_dtype))
6664
return obj
6765
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
6866
# Give a better error message in this case. NumPy would convert this
6967
# to an object array. TODO: This won't handle large integers in lists.
7068
raise OverflowError("Integer out of bounds for array dtypes")
71-
res = np.asarray(obj, dtype=dtype)
69+
res = np.asarray(obj, dtype=_np_dtype)
7270
return Array._new(res)
7371

7472

@@ -86,11 +84,13 @@ def arange(
8684
8785
See its docstring for more information.
8886
"""
89-
from ._array_object import Array
87+
from ._array_object import Array, CPU_DEVICE
9088

9189
_check_valid_dtype(dtype)
92-
if device not in ["cpu", None]:
90+
if device not in [CPU_DEVICE, None]:
9391
raise ValueError(f"Unsupported device {device!r}")
92+
if dtype is not None:
93+
dtype = dtype._np_dtype
9494
return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
9595

9696

@@ -105,11 +105,13 @@ def empty(
105105
106106
See its docstring for more information.
107107
"""
108-
from ._array_object import Array
108+
from ._array_object import Array, CPU_DEVICE
109109

110110
_check_valid_dtype(dtype)
111-
if device not in ["cpu", None]:
111+
if device not in [CPU_DEVICE, None]:
112112
raise ValueError(f"Unsupported device {device!r}")
113+
if dtype is not None:
114+
dtype = dtype._np_dtype
113115
return Array._new(np.empty(shape, dtype=dtype))
114116

115117

@@ -121,11 +123,13 @@ def empty_like(
121123
122124
See its docstring for more information.
123125
"""
124-
from ._array_object import Array
126+
from ._array_object import Array, CPU_DEVICE
125127

126128
_check_valid_dtype(dtype)
127-
if device not in ["cpu", None]:
129+
if device not in [CPU_DEVICE, None]:
128130
raise ValueError(f"Unsupported device {device!r}")
131+
if dtype is not None:
132+
dtype = dtype._np_dtype
129133
return Array._new(np.empty_like(x._array, dtype=dtype))
130134

131135

@@ -143,11 +147,13 @@ def eye(
143147
144148
See its docstring for more information.
145149
"""
146-
from ._array_object import Array
150+
from ._array_object import Array, CPU_DEVICE
147151

148152
_check_valid_dtype(dtype)
149-
if device not in ["cpu", None]:
153+
if device not in [CPU_DEVICE, None]:
150154
raise ValueError(f"Unsupported device {device!r}")
155+
if dtype is not None:
156+
dtype = dtype._np_dtype
151157
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
152158

153159

@@ -169,15 +175,17 @@ def full(
169175
170176
See its docstring for more information.
171177
"""
172-
from ._array_object import Array
178+
from ._array_object import Array, CPU_DEVICE
173179

174180
_check_valid_dtype(dtype)
175-
if device not in ["cpu", None]:
181+
if device not in [CPU_DEVICE, None]:
176182
raise ValueError(f"Unsupported device {device!r}")
177183
if isinstance(fill_value, Array) and fill_value.ndim == 0:
178184
fill_value = fill_value._array
185+
if dtype is not None:
186+
dtype = dtype._np_dtype
179187
res = np.full(shape, fill_value, dtype=dtype)
180-
if res.dtype not in _all_dtypes:
188+
if _DType(res.dtype) not in _all_dtypes:
181189
# This will happen if the fill value is not something that NumPy
182190
# coerces to one of the acceptable dtypes.
183191
raise TypeError("Invalid input to full")
@@ -197,13 +205,15 @@ def full_like(
197205
198206
See its docstring for more information.
199207
"""
200-
from ._array_object import Array
208+
from ._array_object import Array, CPU_DEVICE
201209

202210
_check_valid_dtype(dtype)
203-
if device not in ["cpu", None]:
211+
if device not in [CPU_DEVICE, None]:
204212
raise ValueError(f"Unsupported device {device!r}")
213+
if dtype is not None:
214+
dtype = dtype._np_dtype
205215
res = np.full_like(x._array, fill_value, dtype=dtype)
206-
if res.dtype not in _all_dtypes:
216+
if _DType(res.dtype) not in _all_dtypes:
207217
# This will happen if the fill value is not something that NumPy
208218
# coerces to one of the acceptable dtypes.
209219
raise TypeError("Invalid input to full_like")
@@ -225,11 +235,13 @@ def linspace(
225235
226236
See its docstring for more information.
227237
"""
228-
from ._array_object import Array
238+
from ._array_object import Array, CPU_DEVICE
229239

230240
_check_valid_dtype(dtype)
231-
if device not in ["cpu", None]:
241+
if device not in [CPU_DEVICE, None]:
232242
raise ValueError(f"Unsupported device {device!r}")
243+
if dtype is not None:
244+
dtype = dtype._np_dtype
233245
return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
234246

235247

@@ -264,11 +276,13 @@ def ones(
264276
265277
See its docstring for more information.
266278
"""
267-
from ._array_object import Array
279+
from ._array_object import Array, CPU_DEVICE
268280

269281
_check_valid_dtype(dtype)
270-
if device not in ["cpu", None]:
282+
if device not in [CPU_DEVICE, None]:
271283
raise ValueError(f"Unsupported device {device!r}")
284+
if dtype is not None:
285+
dtype = dtype._np_dtype
272286
return Array._new(np.ones(shape, dtype=dtype))
273287

274288

@@ -280,11 +294,13 @@ def ones_like(
280294
281295
See its docstring for more information.
282296
"""
283-
from ._array_object import Array
297+
from ._array_object import Array, CPU_DEVICE
284298

285299
_check_valid_dtype(dtype)
286-
if device not in ["cpu", None]:
300+
if device not in [CPU_DEVICE, None]:
287301
raise ValueError(f"Unsupported device {device!r}")
302+
if dtype is not None:
303+
dtype = dtype._np_dtype
288304
return Array._new(np.ones_like(x._array, dtype=dtype))
289305

290306

@@ -327,11 +343,13 @@ def zeros(
327343
328344
See its docstring for more information.
329345
"""
330-
from ._array_object import Array
346+
from ._array_object import Array, CPU_DEVICE
331347

332348
_check_valid_dtype(dtype)
333-
if device not in ["cpu", None]:
349+
if device not in [CPU_DEVICE, None]:
334350
raise ValueError(f"Unsupported device {device!r}")
351+
if dtype is not None:
352+
dtype = dtype._np_dtype
335353
return Array._new(np.zeros(shape, dtype=dtype))
336354

337355

@@ -343,9 +361,11 @@ def zeros_like(
343361
344362
See its docstring for more information.
345363
"""
346-
from ._array_object import Array
364+
from ._array_object import Array, CPU_DEVICE
347365

348366
_check_valid_dtype(dtype)
349-
if device not in ["cpu", None]:
367+
if device not in [CPU_DEVICE, None]:
350368
raise ValueError(f"Unsupported device {device!r}")
369+
if dtype is not None:
370+
dtype = dtype._np_dtype
351371
return Array._new(np.zeros_like(x._array, dtype=dtype))

array_api_strict/_data_type_functions.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ._array_object import Array
44
from ._dtypes import (
5+
_DType,
56
_all_dtypes,
67
_boolean_dtypes,
78
_signed_integer_dtypes,
@@ -27,7 +28,7 @@
2728
def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
2829
if not copy and dtype == x.dtype:
2930
return x
30-
return Array._new(x._array.astype(dtype=dtype, copy=copy))
31+
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy))
3132

3233

3334
def broadcast_arrays(*arrays: Array) -> List[Array]:
@@ -107,6 +108,8 @@ def finfo(type: Union[Dtype, Array], /) -> finfo_object:
107108
108109
See its docstring for more information.
109110
"""
111+
if isinstance(type, _DType):
112+
type = type._np_dtype
110113
fi = np.finfo(type)
111114
# Note: The types of the float data here are float, whereas in NumPy they
112115
# are scalars of the corresponding float dtype.
@@ -126,6 +129,8 @@ def iinfo(type: Union[Dtype, Array], /) -> iinfo_object:
126129
127130
See its docstring for more information.
128131
"""
132+
if isinstance(type, _DType):
133+
type = type._np_dtype
129134
ii = np.iinfo(type)
130135
return iinfo_object(ii.bits, ii.max, ii.min, ii.dtype)
131136

0 commit comments

Comments
 (0)