Array API Standard Support: signal#
This page explains some caveats of the signal module and provides (currently
incomplete) tables about the
CPU,
GPU and
JIT support.
Caveats#
JAX and CuPy provide alternative
implementations for some signal functions. When such a function is called, a
decorator decides which implementation to use by inspecting the xp parameter.
Hence, there can be, especially during CI testing, discrepancies in behavior between the default NumPy-based implementation and the JAX and CuPy backends. Skipping the incompatible backends in unit tests, as described in the Adding tests section, is the currently recommended workaround.
The functions are decorated by the code in file
scipy/signal/_support_alternative_backends.py:
1import functools
2import types
3from scipy._lib._array_api import (
4 is_cupy, is_jax, scipy_namespace_for, SCIPY_ARRAY_API, xp_capabilities
5)
6
7from ._signal_api import * # noqa: F403
8from . import _signal_api
9from . import _delegators
10__all__ = _signal_api.__all__
11
12
# Name of the submodule that is looked up on the alternative backends
# (``cupyx.scipy.<MODULE_NAME>`` and the JAX SciPy namespace).
MODULE_NAME = 'signal'

# jax.scipy.signal has only partial coverage of scipy.signal, so we keep the list
# of functions we can delegate to JAX
# https://jax.readthedocs.io/en/latest/jax.scipy.html
JAX_SIGNAL_FUNCS = [
    'fftconvolve', 'convolve', 'convolve2d', 'correlate', 'correlate2d',
    'csd', 'detrend', 'istft', 'welch'
]

# some cupyx.scipy.signal functions are incompatible with their scipy counterparts
CUPY_BLACKLIST = [
    'abcd_normalize', 'bessel', 'besselap', 'envelope', 'get_window', 'lfilter_zi',
    'sosfilt_zi', 'remez',
]

# freqz_sos is a sosfreqz rename, and cupy does not have the new name yet (in v13.x)
CUPY_RENAMES = {'freqz_sos': 'sosfreqz'}
31
32
def delegate_xp(delegator, module_name):
    """Create a decorator that dispatches a function to a CuPy/JAX namesake.

    Parameters
    ----------
    delegator : callable
        Called with the positional and keyword arguments of the wrapped
        function; its return value is used as the array namespace ``xp``.
    module_name : str
        Submodule name that is looked up as ``cupyx.scipy.<module_name>`` for
        CuPy, or as an attribute of ``scipy_namespace_for(xp)`` for JAX.

    Returns
    -------
    callable
        A decorator. The decorated function calls the CuPy or JAX
        implementation when ``xp`` is one of those namespaces and the function
        is eligible; otherwise it calls the original function.

    Raises
    ------
    TypeError
        Re-raised from `delegator`, except when the wrapped function is named
        ``tf2ss``, in which case NumPy is used as the namespace instead.
    """
    def inner(func):
        @functools.wraps(func)
        def wrapper(*args, **kwds):
            try:
                xp = delegator(*args, **kwds)
            except TypeError:
                # object arrays
                if func.__name__ == "tf2ss":
                    import numpy as np
                    xp = np
                else:
                    raise

            # try delegating to a cupyx/jax namesake
            if is_cupy(xp) and func.__name__ not in CUPY_BLACKLIST:
                # use the CuPy-side name where it differs from the SciPy one
                func_name = CUPY_RENAMES.get(func.__name__, func.__name__)

                # https://github.com/cupy/cupy/issues/8336
                import importlib
                cupyx_module = importlib.import_module(f"cupyx.scipy.{module_name}")
                cupyx_func = getattr(cupyx_module, func_name)
                # ``xp`` is not forwarded to the backend implementation
                kwds.pop('xp', None)
                return cupyx_func(*args, **kwds)
            elif is_jax(xp) and func.__name__ in JAX_SIGNAL_FUNCS:
                spx = scipy_namespace_for(xp)
                jax_module = getattr(spx, module_name)
                jax_func = getattr(jax_module, func.__name__)
                # ``xp`` is not forwarded to the backend implementation
                kwds.pop('xp', None)
                return jax_func(*args, **kwds)
            else:
                # the original function
                return func(*args, **kwds)
        return wrapper
    return inner
68
69
# Although most of these functions currently exist in CuPy and some in JAX,
# there are no alternative backend tests for any of them in the current
# test suite. Each will be documented as np_only until tests are added.
untested = {
    "argrelextrema",
    "argrelmax",
    "argrelmin",
    "band_stop_obj",
    "check_NOLA",
    "chirp",
    "coherence",
    "csd",
    "czt_points",
    "dbode",
    "dfreqresp",
    "dlsim",
    "dstep",
    "find_peaks",
    "find_peaks_cwt",
    "freqresp",
    "gausspulse",
    "lombscargle",
    "lsim",
    "max_len_seq",
    "peak_prominences",
    "peak_widths",
    "periodogram",
    "place_pols",
    "sawtooth",
    "sepfir2d",
    "square",
    "ss2tf",
    "ss2zpk",
    "step",
    "sweep_poly",
    "symiirorder1",
    "symiirorder2",
    "tf2ss",
    "unit_impulse",
    "welch",
    "zoom_fft",
    "zpk2ss",
}
113
114
def get_default_capabilities(func_name, delegator):
    """Return the ``xp_capabilities`` decorator used when no override exists.

    Parameters
    ----------
    func_name : str
        Name of the function being decorated.
    delegator : callable or None
        The function's delegator, or None if it has none.

    Returns
    -------
    The result of ``xp_capabilities(np_only=True)`` when `delegator` is None
    or `func_name` is in ``untested``; otherwise the result of
    ``xp_capabilities()``.
    """
    if delegator is None or func_name in untested:
        return xp_capabilities(np_only=True)
    return xp_capabilities()
119
120bilinear_extra_note = \
121 """CuPy does not accept complex inputs.
122
123 """
124
125uses_choose_conv_extra_note = \
126 """CuPy does not support inputs with ``ndim>1`` when ``method="auto"``
127 but does support higher dimensional arrays for ``method="direct"``
128 and ``method="fft"``.
129
130 """
131
132resample_poly_extra_note = \
133 """CuPy only supports ``padtype="constant"``.
134
135 """
136
137upfirdn_extra_note = \
138 """CuPy only supports ``mode="constant"`` and ``cval=0.0``.
139
140 """
141
142xord_extra_note = \
143 """The ``torch`` backend on GPU does not support the case where
144 `wp` and `ws` specify a Bandstop filter.
145
146 """
147
148convolve2d_extra_note = \
149 """The JAX backend only supports ``boundary="fill"`` and ``fillvalue=0``.
150
151 """
152
153zpk2tf_extra_note = \
154 """The CuPy and JAX backends both support only 1d input.
155
156 """
157
# Per-function ``xp_capabilities`` decorators. A function listed here uses the
# given entry instead of the result of ``get_default_capabilities``.
capabilities_overrides = {
    "bessel": xp_capabilities(cpu_only=True, jax_jit=False, allow_dask_compute=True),
    "bilinear": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="Uses np.polynomial.Polynomial",
                                extra_note=bilinear_extra_note),
    "bilinear_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                    jax_jit=False, allow_dask_compute=True),
    "butter": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "buttord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                               jax_jit=False, allow_dask_compute=True,
                               extra_note=xord_extra_note),
    "cheb1ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheb2ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheby1": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),

    "cheby2": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "cont2discrete": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "convolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                allow_dask_compute=True,
                                extra_note=uses_choose_conv_extra_note),
    "convolve2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                  allow_dask_compute=True,
                                  extra_note=convolve2d_extra_note),
    "correlate": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                 allow_dask_compute=True,
                                 extra_note=uses_choose_conv_extra_note),
    "correlate2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                   allow_dask_compute=True,
                                   extra_note=convolve2d_extra_note),
    "correlation_lags": xp_capabilities(out_of_scope=True),
    "cspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "cspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "cspline2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "czt": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "deconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                  allow_dask_compute=True,
                                  skip_backends=[("jax.numpy", "item assignment")]),
    "decimate": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "detrend": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                               allow_dask_compute=True),
    "dimpulse": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "dlti": xp_capabilities(np_only=True,
                            reason="works in CuPy but delegation isn't set up yet"),
    "ellip": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                             allow_dask_compute=True,
                             reason="scipy.special.ellipk"),
    "ellipord": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="scipy.special.ellipk"),
    "findfreqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "firls": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False,
                             reason="lstsq"),
    "firwin": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                              jax_jit=False, allow_dask_compute=True),
    "firwin2": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               jax_jit=False, allow_dask_compute=True,
                               reason="firwin uses np.interp"),
    "fftconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"]),
    "freqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "freqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqz_sos": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "group_delay": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                   jax_jit=False, allow_dask_compute=True),
    "hilbert": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "item assignment")],
    ),
    "hilbert2": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "item assignment")],
    ),
    "invres": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "invresz": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "iircomb": xp_capabilities(xfail_backends=[("jax.numpy", "inaccurate")]),
    "iirfilter": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "kaiser_atten": xp_capabilities(
        out_of_scope=True, reason="scalars in, scalars out"
    ),
    "kaiser_beta": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "kaiserord": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "lfilter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False),
    "lfilter_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
                                  jax_jit=False),
    "lfiltic": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True),
    "lp2bp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2bp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2bs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2bs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2lp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2lp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2hp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2hp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lti": xp_capabilities(np_only=True,
                           reason="works in CuPy but delegation isn't set up yet"),
    "medfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False,
                               reason="uses scipy.ndimage.rank_filter"),
    "medfilt2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 allow_dask_compute=True, jax_jit=False,
                                 reason="c extension module"),
    "minimum_phase": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                     allow_dask_compute=True, jax_jit=False),
    "normalize": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "oaconvolve": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "fails all around")],
        xfail_backends=[("dask.array", "wrong answer")],
    ),
    "order_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                    allow_dask_compute=True, jax_jit=False,
                                    reason="uses scipy.ndimage.rank_filter"),
    "qspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "qspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "qspline2d": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "remez": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False),
    "resample": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            ("dask.array", "XXX something in dask"),
            ("jax.numpy", "XXX: immutable arrays"),
        ]
    ),
    "resample_poly": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        jax_jit=False, skip_backends=[("dask.array", "XXX something in dask")],
        extra_note=resample_poly_extra_note,
    ),
    "residue": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "residuez": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "savgol_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False,
                                     reason="convolve1d is cpu-only"),
    "sepfir2d": xp_capabilities(np_only=True),
    "sos2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "sos2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "sosfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True),
    "sosfiltfilt": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            (
                "dask.array",
                "sosfiltfilt directly sets shape attributes on arrays"
                " which dask doesn't like"
            ),
            ("torch", "negative strides"),
            ("jax.numpy", "sosfilt works in-place"),
        ],
    ),
    "sosfreqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True),
    "spline_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False, allow_dask_compute=True),
    "tf2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "tf2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "unique_roots": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "upfirdn": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True,
                               reason="Cython implementation",
                               extra_note=upfirdn_extra_note),
    "vectorstrength": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                      allow_dask_compute=True, jax_jit=False),
    "wiener": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                              allow_dask_compute=True, jax_jit=False,
                              reason="uses scipy.signal.correlate"),
    "zpk2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "zpk2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True,
                              extra_note=zpk2tf_extra_note),
    "spectrogram": xp_capabilities(out_of_scope=True),  # legacy
    "stft": xp_capabilities(out_of_scope=True),  # legacy
    "istft": xp_capabilities(out_of_scope=True),  # legacy
    "check_COLA": xp_capabilities(out_of_scope=True),  # legacy
}
371
372
# ### decorate ###
# Rebind every public name of ``_signal_api`` in this module's namespace,
# wrapped with backend delegation (when enabled) and capability metadata.
for obj_name in _signal_api.__all__:
    bare_obj = getattr(_signal_api, obj_name)
    # a delegator is the function named ``<obj_name>_signature`` in _delegators
    delegator = getattr(_delegators, obj_name + "_signature", None)

    # delegation requires both SCIPY_ARRAY_API and an available delegator
    if SCIPY_ARRAY_API and delegator is not None:
        f = delegate_xp(delegator, MODULE_NAME)(bare_obj)
    else:
        f = bare_obj

    # the capabilities decorator is applied to everything except modules
    if not isinstance(f, types.ModuleType):
        capabilities = capabilities_overrides.get(
            obj_name, get_default_capabilities(obj_name, delegator)
        )
        f = capabilities(f)

    # add the decorated function to the namespace, to be imported in __init__.py
    vars()[obj_name] = f
Note that a function is only wrapped with the delegation decorator if
SCIPY_ARRAY_API is enabled and a matching <function name>_signature function is
defined in the file scipy/signal/_delegators.py; the xp_capabilities decorator
is applied in either case. E.g., for firwin, the signature function looks like this:
def firwin_signature(numtaps, cutoff, *args, **kwds):
    """Return the array namespace to use for a ``firwin`` call.

    Only `cutoff` is inspected; `numtaps` and all remaining arguments are
    accepted but ignored.

    Returns
    -------
    ``np_compat`` when `cutoff` is an ``int`` or ``float``; otherwise the
    result of ``array_namespace(cutoff)``.
    """
    if isinstance(cutoff, int | float):
        # a Python scalar carries no array namespace of its own
        xp = np_compat
    else:
        xp = array_namespace(cutoff)
    return xp
Support on CPU#
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function |
torch |
jax |
dask |
|---|---|---|---|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
Support on GPU#
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function |
cupy |
torch |
jax |
|---|---|---|---|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
Support with JIT#
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function |
jax |
|---|---|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✔️ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
N/A |
|
N/A |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |