12
12
SupportsBufferProtocol ,
13
13
)
14
14
from collections .abc import Sequence
15
- from ._dtypes import _all_dtypes
15
+ from ._dtypes import _DType , _all_dtypes
16
16
17
17
import numpy as np
18
18
19
19
20
20
def _check_valid_dtype (dtype ):
21
21
# Note: Only spelling dtypes as the dtype objects is supported.
22
-
23
- # We use this instead of "dtype in _all_dtypes" because the dtype objects
24
- # define equality with the sorts of things we want to disallow.
25
- for d in (None ,) + _all_dtypes :
26
- if dtype is d :
27
- return
28
- raise ValueError ("dtype must be one of the supported dtypes" )
22
+ if dtype not in (None ,) + _all_dtypes :
23
+ raise ValueError ("dtype must be one of the supported dtypes" )
29
24
30
25
31
26
def asarray (
@@ -50,10 +45,13 @@ def asarray(
50
45
"""
51
46
# _array_object imports in this file are inside the functions to avoid
52
47
# circular imports
53
- from ._array_object import Array
48
+ from ._array_object import Array , CPU_DEVICE
54
49
55
50
_check_valid_dtype (dtype )
56
- if device not in ["cpu" , None ]:
51
+ _np_dtype = None
52
+ if dtype is not None :
53
+ _np_dtype = dtype ._np_dtype
54
+ if device not in [CPU_DEVICE , None ]:
57
55
raise ValueError (f"Unsupported device { device !r} " )
58
56
if copy in (False , np ._CopyMode .IF_NEEDED ):
59
57
# Note: copy=False is not yet implemented in np.asarray
@@ -62,13 +60,13 @@ def asarray(
62
60
if dtype is not None and obj .dtype != dtype :
63
61
copy = True
64
62
if copy in (True , np ._CopyMode .ALWAYS ):
65
- return Array ._new (np .array (obj ._array , copy = True , dtype = dtype ))
63
+ return Array ._new (np .array (obj ._array , copy = True , dtype = _np_dtype ))
66
64
return obj
67
65
if dtype is None and isinstance (obj , int ) and (obj > 2 ** 64 or obj < - (2 ** 63 )):
68
66
# Give a better error message in this case. NumPy would convert this
69
67
# to an object array. TODO: This won't handle large integers in lists.
70
68
raise OverflowError ("Integer out of bounds for array dtypes" )
71
- res = np .asarray (obj , dtype = dtype )
69
+ res = np .asarray (obj , dtype = _np_dtype )
72
70
return Array ._new (res )
73
71
74
72
@@ -86,11 +84,13 @@ def arange(
86
84
87
85
See its docstring for more information.
88
86
"""
89
- from ._array_object import Array
87
+ from ._array_object import Array , CPU_DEVICE
90
88
91
89
_check_valid_dtype (dtype )
92
- if device not in ["cpu" , None ]:
90
+ if device not in [CPU_DEVICE , None ]:
93
91
raise ValueError (f"Unsupported device { device !r} " )
92
+ if dtype is not None :
93
+ dtype = dtype ._np_dtype
94
94
return Array ._new (np .arange (start , stop = stop , step = step , dtype = dtype ))
95
95
96
96
@@ -105,11 +105,13 @@ def empty(
105
105
106
106
See its docstring for more information.
107
107
"""
108
- from ._array_object import Array
108
+ from ._array_object import Array , CPU_DEVICE
109
109
110
110
_check_valid_dtype (dtype )
111
- if device not in ["cpu" , None ]:
111
+ if device not in [CPU_DEVICE , None ]:
112
112
raise ValueError (f"Unsupported device { device !r} " )
113
+ if dtype is not None :
114
+ dtype = dtype ._np_dtype
113
115
return Array ._new (np .empty (shape , dtype = dtype ))
114
116
115
117
@@ -121,11 +123,13 @@ def empty_like(
121
123
122
124
See its docstring for more information.
123
125
"""
124
- from ._array_object import Array
126
+ from ._array_object import Array , CPU_DEVICE
125
127
126
128
_check_valid_dtype (dtype )
127
- if device not in ["cpu" , None ]:
129
+ if device not in [CPU_DEVICE , None ]:
128
130
raise ValueError (f"Unsupported device { device !r} " )
131
+ if dtype is not None :
132
+ dtype = dtype ._np_dtype
129
133
return Array ._new (np .empty_like (x ._array , dtype = dtype ))
130
134
131
135
@@ -143,11 +147,13 @@ def eye(
143
147
144
148
See its docstring for more information.
145
149
"""
146
- from ._array_object import Array
150
+ from ._array_object import Array , CPU_DEVICE
147
151
148
152
_check_valid_dtype (dtype )
149
- if device not in ["cpu" , None ]:
153
+ if device not in [CPU_DEVICE , None ]:
150
154
raise ValueError (f"Unsupported device { device !r} " )
155
+ if dtype is not None :
156
+ dtype = dtype ._np_dtype
151
157
return Array ._new (np .eye (n_rows , M = n_cols , k = k , dtype = dtype ))
152
158
153
159
@@ -169,15 +175,17 @@ def full(
169
175
170
176
See its docstring for more information.
171
177
"""
172
- from ._array_object import Array
178
+ from ._array_object import Array , CPU_DEVICE
173
179
174
180
_check_valid_dtype (dtype )
175
- if device not in ["cpu" , None ]:
181
+ if device not in [CPU_DEVICE , None ]:
176
182
raise ValueError (f"Unsupported device { device !r} " )
177
183
if isinstance (fill_value , Array ) and fill_value .ndim == 0 :
178
184
fill_value = fill_value ._array
185
+ if dtype is not None :
186
+ dtype = dtype ._np_dtype
179
187
res = np .full (shape , fill_value , dtype = dtype )
180
- if res .dtype not in _all_dtypes :
188
+ if _DType ( res .dtype ) not in _all_dtypes :
181
189
# This will happen if the fill value is not something that NumPy
182
190
# coerces to one of the acceptable dtypes.
183
191
raise TypeError ("Invalid input to full" )
@@ -197,13 +205,15 @@ def full_like(
197
205
198
206
See its docstring for more information.
199
207
"""
200
- from ._array_object import Array
208
+ from ._array_object import Array , CPU_DEVICE
201
209
202
210
_check_valid_dtype (dtype )
203
- if device not in ["cpu" , None ]:
211
+ if device not in [CPU_DEVICE , None ]:
204
212
raise ValueError (f"Unsupported device { device !r} " )
213
+ if dtype is not None :
214
+ dtype = dtype ._np_dtype
205
215
res = np .full_like (x ._array , fill_value , dtype = dtype )
206
- if res .dtype not in _all_dtypes :
216
+ if _DType ( res .dtype ) not in _all_dtypes :
207
217
# This will happen if the fill value is not something that NumPy
208
218
# coerces to one of the acceptable dtypes.
209
219
raise TypeError ("Invalid input to full_like" )
@@ -225,11 +235,13 @@ def linspace(
225
235
226
236
See its docstring for more information.
227
237
"""
228
- from ._array_object import Array
238
+ from ._array_object import Array , CPU_DEVICE
229
239
230
240
_check_valid_dtype (dtype )
231
- if device not in ["cpu" , None ]:
241
+ if device not in [CPU_DEVICE , None ]:
232
242
raise ValueError (f"Unsupported device { device !r} " )
243
+ if dtype is not None :
244
+ dtype = dtype ._np_dtype
233
245
return Array ._new (np .linspace (start , stop , num , dtype = dtype , endpoint = endpoint ))
234
246
235
247
@@ -264,11 +276,13 @@ def ones(
264
276
265
277
See its docstring for more information.
266
278
"""
267
- from ._array_object import Array
279
+ from ._array_object import Array , CPU_DEVICE
268
280
269
281
_check_valid_dtype (dtype )
270
- if device not in ["cpu" , None ]:
282
+ if device not in [CPU_DEVICE , None ]:
271
283
raise ValueError (f"Unsupported device { device !r} " )
284
+ if dtype is not None :
285
+ dtype = dtype ._np_dtype
272
286
return Array ._new (np .ones (shape , dtype = dtype ))
273
287
274
288
@@ -280,11 +294,13 @@ def ones_like(
280
294
281
295
See its docstring for more information.
282
296
"""
283
- from ._array_object import Array
297
+ from ._array_object import Array , CPU_DEVICE
284
298
285
299
_check_valid_dtype (dtype )
286
- if device not in ["cpu" , None ]:
300
+ if device not in [CPU_DEVICE , None ]:
287
301
raise ValueError (f"Unsupported device { device !r} " )
302
+ if dtype is not None :
303
+ dtype = dtype ._np_dtype
288
304
return Array ._new (np .ones_like (x ._array , dtype = dtype ))
289
305
290
306
@@ -327,11 +343,13 @@ def zeros(
327
343
328
344
See its docstring for more information.
329
345
"""
330
- from ._array_object import Array
346
+ from ._array_object import Array , CPU_DEVICE
331
347
332
348
_check_valid_dtype (dtype )
333
- if device not in ["cpu" , None ]:
349
+ if device not in [CPU_DEVICE , None ]:
334
350
raise ValueError (f"Unsupported device { device !r} " )
351
+ if dtype is not None :
352
+ dtype = dtype ._np_dtype
335
353
return Array ._new (np .zeros (shape , dtype = dtype ))
336
354
337
355
@@ -343,9 +361,11 @@ def zeros_like(
343
361
344
362
See its docstring for more information.
345
363
"""
346
- from ._array_object import Array
364
+ from ._array_object import Array , CPU_DEVICE
347
365
348
366
_check_valid_dtype (dtype )
349
- if device not in ["cpu" , None ]:
367
+ if device not in [CPU_DEVICE , None ]:
350
368
raise ValueError (f"Unsupported device { device !r} " )
369
+ if dtype is not None :
370
+ dtype = dtype ._np_dtype
351
371
return Array ._new (np .zeros_like (x ._array , dtype = dtype ))
0 commit comments