This RFC proposes the addition of a new API in the array API specification for taking values from an input array by matching one-dimensional index and data slices.
Overview
Based on array comparison data, the API is available across most major array libraries in the PyData ecosystem.
take_along_axis was previously discussed in #177 as a potential standardization candidate and has been mentioned in downstream usage. Because the specification does not yet support indexing with multidimensional integer arrays (see #669), it lacks a concise means of selecting multiple values along multiple one-dimensional slices. This RFC aims to fill this gap.
Additionally, even with advanced indexing, replicating take_along_axis is more verbose and trickier to get right. For example, consider:
In [1]: import numpy as np
In [2]: a = np.array([[10,30,20], [60,40,50]])
In [3]: a
Out[3]:
array([[10, 30, 20],
       [60, 40, 50]])
In [4]: indices = np.array([[2,0,1],[1,2,0]])
In [5]: indices
Out[5]:
array([[2, 0, 1],
       [1, 2, 0]])
In [6]: np.take_along_axis(a, indices, axis=1)
Out[6]:
array([[20, 10, 30],
       [40, 50, 60]])
To replicate this with advanced indexing:
In [7]: i0 = np.arange(a.shape[0])[:, np.newaxis]
In [8]: a[i0, indices]
Out[8]:
array([[20, 10, 30],
       [40, 50, 60]])
where we need to create an integer index (with expanded dimensions) for the first dimension, which can then be broadcast against the integer index array indices. For higher-dimensional arrays, replicating take_along_axis becomes even more verbose. E.g., for a 3-dimensional array,
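# Equivalent to np.take_along_axis(a, indices, axis=2)
a = np.random.rand(3, 4, 5)
indices = np.random.randint(5, size=(3, 4, 2))
# Index arrays for the first two dimensions, expanded to broadcast against indices
i0 = np.arange(a.shape[0])[:, np.newaxis, np.newaxis]
i1 = np.arange(a.shape[1])[np.newaxis, :, np.newaxis]
result = a[i0, i1, indices]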
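As a quick sanity check, here is a minimal sketch confirming that the hand-built 3-dimensional index selects the same values as NumPy's existing np.take_along_axis along the last axis, and showing the shape the three index arrays broadcast to. The seeded generator (in place of np.random.rand and np.random.randint) is an illustrative choice for reproducibility only.

```python
import numpy as np

# Seeded generator so the check is reproducible (an illustrative choice).
rng = np.random.default_rng(0)
a = rng.random((3, 4, 5))
indices = rng.integers(5, size=(3, 4, 2))

# Index arrays for the leading dimensions, expanded to shapes (3, 1, 1)
# and (1, 4, 1) so that they broadcast against `indices`.
i0 = np.arange(a.shape[0])[:, np.newaxis, np.newaxis]
i1 = np.arange(a.shape[1])[np.newaxis, :, np.newaxis]

# All three index arrays broadcast to the shape of `indices`.
broadcast_shape = np.broadcast_shapes(i0.shape, i1.shape, indices.shape)

manual = a[i0, i1, indices]
expected = np.take_along_axis(a, indices, axis=2)
print(broadcast_shape)                   # (3, 4, 2)
print(np.array_equal(manual, expected))  # True
```

The result therefore always has the shape of indices, which is the shape take_along_axis returns.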
In general, while "advanced indexing" can be used for replicating take_along_axis, doing so is less ergonomic and has a higher likelihood of mistakes.
Prior art
- NumPy: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
- CuPy: https://docs.cupy.dev/en/stable/reference/generated/cupy.take_along_axis.html
- Dask: (does not currently implement)
- PyTorch: https://pytorch.org/docs/stable/generated/torch.take_along_dim.html
  - Named take_along_dim, rather than take_along_axis
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/take_along_axis
- JAX: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.take_along_axis.html
  - Supports additional mode and fill_value kwargs
Proposal
def take_along_axis(x: array, indices: array, /, axis: int) -> array
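Because the proposed semantics follow NumPy's existing function, np.take_along_axis can stand in for the proposed API in a quick illustration. A common pattern, and the one behind numpy/numpy#8708, is mapping argsort output back to values; indices with a length-1 axis select a single value per slice. The variable names below are illustrative only.

```python
import numpy as np

# NumPy's existing function stands in for the proposed
# take_along_axis(x, indices, /, axis) in this illustration.
x = np.array([[10, 30, 20], [60, 40, 50]])

# argsort returns, for each row, the indices that would sort that row.
idx = np.argsort(x, axis=1)

# Taking along the same axis maps those indices back to values,
# which reproduces a row-wise sort.
sorted_rows = np.take_along_axis(x, idx, axis=1)

# Indices of shape (2, 1) select one value per row: here, each row's maximum.
max_idx = np.expand_dims(np.argmax(x, axis=1), axis=1)
row_max = np.take_along_axis(x, max_idx, axis=1)

print(sorted_rows)  # [[10 20 30]
                    #  [40 50 60]]
print(row_max)      # [[30]
                    #  [60]]
```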
Notes
- For the Array API compat layer, libraries without this functionality can use the workaround proposed in a comment on numpy/numpy#8708 ("ENH: Add a function that performs the indexing needed to map argsort to sort") and implemented in numpy/numpy#11105 ("ENH: Add (put|take)_along_axis"). It can be written in pure Python, without needing any additional Array API functions.
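A minimal sketch of such a fallback, written against NumPy for concreteness. The name take_along_axis_sketch is hypothetical, and the code assumes the wrapped library supports NumPy-style integer-array (advanced) indexing with broadcasting. In the spirit of the approach from numpy/numpy#11105, it builds one broadcastable index per dimension and passes the caller's indices through unchanged along axis.

```python
import numpy as np


def take_along_axis_sketch(x, indices, /, axis):
    """Hypothetical fallback for take_along_axis built on advanced indexing."""
    ndim = x.ndim
    if indices.ndim != ndim:
        raise ValueError("x and indices must have the same number of dimensions")
    if not (-ndim <= axis < ndim):
        raise ValueError(f"axis {axis} is out of bounds for an array of dimension {ndim}")
    axis %= ndim  # normalize negative axes

    index = []
    for dim in range(ndim):
        if dim == axis:
            # Along `axis`, use the caller-supplied indices directly.
            index.append(indices)
        else:
            # Elsewhere, use arange(n) reshaped so it varies only along `dim`.
            shape = [1] * ndim
            shape[dim] = x.shape[dim]
            index.append(np.arange(x.shape[dim]).reshape(shape))
    # The index arrays broadcast together; the result has the broadcast shape.
    return x[tuple(index)]


a = np.array([[10, 30, 20], [60, 40, 50]])
indices = np.array([[2, 0, 1], [1, 2, 0]])
print(take_along_axis_sketch(a, indices, axis=1))
# [[20 10 30]
#  [40 50 60]]
```

For axis=1 on a 2-D input, this builds exactly the i0 index from the Overview; for other axes it generalizes the same construction.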
Questions
- NumPy and its kin allow axis to be None in order to indicate that the input array x should be flattened prior to indexing. This allows consistency with NumPy's sort and argsort functions. However, the specification for sort and argsort does not support None (i.e., flattening). Accordingly, this RFC does not propose supporting axis=None. Are we okay with this?
- NumPy and kin allow keyword and positional arguments. This RFC makes x and indices positional-only and allows axis to be passed either positionally or by keyword. Any concerns?
- As elsewhere with this specification, presumably PyTorch will be okay aliasing take_along_dim as take_along_axis and dim as axis to ensure spec compliance?
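Without axis=None, the flattening behavior remains expressible by reshaping explicitly before sorting and indexing. In NumPy (used here only for illustration), the NumPy-only spelling and a spec-style spelling that avoids axis=None give the same result; the variable names are illustrative.

```python
import numpy as np

x = np.array([[10, 30, 20], [60, 40, 50]])

# NumPy-only spelling: axis=None flattens x before sorting and indexing.
numpy_style = np.take_along_axis(x, np.argsort(x, axis=None), axis=None)

# Spec-style spelling: flatten explicitly, then work along axis 0.
flat = np.reshape(x, (-1,))
spec_style = np.take_along_axis(flat, np.argsort(flat, axis=0), axis=0)

print(spec_style)  # [10 20 30 40 50 60]
```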