RFC: add take_along_axis to take values along a specified dimension #808

Closed
@kgryte

This RFC proposes the addition of a new API in the array API specification for taking values from an input array by matching one-dimensional index and data slices.

Overview

Based on array comparison data, the API is available across most major array libraries in the PyData ecosystem.

take_along_axis was previously discussed in #177 as a potential standardization candidate and has been mentioned in downstream usage. As indexing with multidimensional integer arrays (see #669) is not yet supported in the specification, the specification lacks a means to concisely select multiple values along multiple one-dimensional slices. This RFC aims to fill this gap.

Additionally, even with advanced indexing, replicating take_along_axis is more verbose and trickier to get right. For example, consider:

In [1]: import numpy as np

In [2]: a = np.array([[10,30,20], [60,40,50]])

In [3]: a
Out[3]:
array([[10, 30, 20],
       [60, 40, 50]])

In [4]: indices = np.array([[2,0,1],[1,2,0]])

In [5]: indices
Out[5]:
array([[2, 0, 1],
       [1, 2, 0]])

In [6]: np.take_along_axis(a, indices, axis=1)
Out[6]:
array([[20, 10, 30],
       [40, 50, 60]])
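
For reference, the behavior above can be spelled out as an explicit loop. This is only a sketch for the 2-dimensional, axis=1 case; the helper name take_along_axis1_loop is hypothetical and not part of NumPy or the proposal. Each output element is out[i, j] = a[i, indices[i, j]].

```python
import numpy as np

def take_along_axis1_loop(a, indices):
    """Reference loop for axis=1 on 2-D inputs (hypothetical helper)."""
    rows, cols = indices.shape
    out = np.empty((rows, cols), dtype=a.dtype)
    for i in range(rows):
        for j in range(cols):
            # Row i of `indices` selects values from row i of `a`.
            out[i, j] = a[i, indices[i, j]]
    return out

a = np.array([[10, 30, 20], [60, 40, 50]])
indices = np.array([[2, 0, 1], [1, 2, 0]])

loop_result = take_along_axis1_loop(a, indices)
```

The loop result should agree with np.take_along_axis(a, indices, axis=1) for these inputs.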

To replicate with advanced indexing,

In [7]: i0 = np.arange(a.shape[0])[:, np.newaxis]

In [8]: a[i0, indices]
Out[8]:
array([[20, 10, 30],
       [40, 50, 60]])

where we need to create an integer index (with expanded dimensions) for the first dimension, which can then be broadcast against the integer index array indices. For higher-dimensional arrays, replicating take_along_axis becomes even more verbose. E.g., for a 3-dimensional array,

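import numpy as np

a = np.random.rand(3, 4, 5)
indices = np.random.randint(5, size=(3, 4, 2))

# Index arrays for the first two dimensions, shaped to broadcast against indices
i0 = np.arange(a.shape[0])[:, np.newaxis, np.newaxis]
i1 = np.arange(a.shape[1])[np.newaxis, :, np.newaxis]

# Equivalent to np.take_along_axis(a, indices, axis=2)
result = a[i0, i1, indices]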

In general, while "advanced indexing" can be used for replicating take_along_axis, doing so is less ergonomic and has a higher likelihood of mistakes.
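
To illustrate how much bookkeeping the general case needs, here is a minimal sketch of a helper that emulates take_along_axis with integer-array indexing for an arbitrary axis. The name take_along_axis_via_indexing is hypothetical, and the sketch assumes that indices has the same number of dimensions as a and matches a's shape on every dimension except axis (no broadcasting of non-axis dimensions is handled).

```python
import numpy as np

def take_along_axis_via_indexing(a, indices, axis):
    """Emulate np.take_along_axis with integer-array indexing (hypothetical helper).

    Assumes indices.ndim == a.ndim and matching shapes on all non-axis dimensions.
    """
    axis = axis % a.ndim  # normalize negative axes
    # One "open mesh" index per dimension, each shaped (1, ..., n, ..., 1).
    index = list(np.indices(indices.shape, sparse=True))
    # Along `axis`, use the caller-supplied indices instead of 0..n-1.
    index[axis] = indices
    return a[tuple(index)]

rng = np.random.default_rng(0)
a = rng.random((3, 4, 5))
indices = rng.integers(5, size=(3, 4, 2))

result = take_along_axis_via_indexing(a, indices, axis=2)
```

Even in this compact form, the caller has to normalize the axis, build one index array per dimension, and keep the broadcast shapes aligned, which is the bookkeeping a dedicated function avoids.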

Prior art

Proposal

def take_along_axis(x: array, indices: array, /, axis: int) -> array
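
As a rough illustration of how the proposed signature behaves at call sites, the following sketch wraps NumPy's existing function. It is not specification text: rejecting axis=None with a TypeError is an assumption made here for illustration, since the RFC only states that axis=None is not proposed.

```python
import numpy as np

def take_along_axis(x, indices, /, axis):
    """Sketch of the proposed signature on top of NumPy (illustrative only)."""
    if axis is None:
        # Assumption: flattening semantics are not proposed, so reject None.
        raise TypeError("axis must be an integer; axis=None is not supported")
    return np.take_along_axis(x, indices, axis=axis)

a = np.array([[10, 30, 20], [60, 40, 50]])
indices = np.array([[2, 0, 1], [1, 2, 0]])

by_position = take_along_axis(a, indices, 1)      # axis passed positionally
by_keyword = take_along_axis(a, indices, axis=1)  # axis passed by keyword
```

Because x and indices are positional-only, a call such as take_along_axis(x=a, indices=indices, axis=1) raises a TypeError in this sketch.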

Notes

Questions

  • NumPy and its kin allow axis to be None in order to indicate that the input array x should be flattened prior to indexing. This allows consistency with NumPy's sort and argsort functions. However, the specification for sort and argsort does not support None (i.e., flattening). Accordingly, this RFC does not propose supporting axis=None. Are we okay with this?
  • NumPy and its kin allow all arguments to be passed either positionally or by keyword. This RFC makes x and indices positional-only and allows axis to be passed either positionally or by keyword. Any concerns?
  • As elsewhere with this specification, presumably PyTorch will be okay aliasing take_along_dim as take_along_axis and dim as axis to ensure spec compliance?
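
For context on the first question, the sketch below shows the NumPy behavior being discussed: pairing take_along_axis with argsort along an axis, and the axis=None form in which the input is treated as flattened and the indices must be one-dimensional. It uses only existing NumPy functions; the variable names are illustrative.

```python
import numpy as np

a = np.array([[10, 30, 20], [60, 40, 50]])

# Pairing with argsort: gathering in sorted order reproduces np.sort.
order = np.argsort(a, axis=1)
sorted_rows = np.take_along_axis(a, order, axis=1)

# NumPy's axis=None form: `a` is treated as flattened, indices are 1-D.
flat_order = np.argsort(a, axis=None)
sorted_flat = np.take_along_axis(a, flat_order, axis=None)
```

Under the proposal, only the first form (an integer axis) would be covered by the specification.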

Labels: API extension, Needs Discussion, RFC, topic: Indexing

Status: Stage 1