Closed
Description
This RFC proposes adding a new API to the array API specification for computing the cumulative sum.
Overview
Based on array comparison data, the API is available in all the libraries in the PyData ecosystem.
Prior art
- NumPy: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
- PyTorch: https://pytorch.org/docs/stable/generated/torch.cumsum.html
- MXNet: https://mxnet.apache.org/versions/1.6/api/r/docs/api/mx.nd.cumsum.html
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/math/cumsum
Proposal:
def cumsum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None) -> array
The `dtype` kwarg is for consistency with `sum` et al.
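As a rough illustration of the semantics under discussion, the sketch below uses NumPy's existing `np.cumsum` as a stand-in. This is an assumption for illustration only: the proposed standard function may differ in details (for example, NumPy accepts only a single integer `axis`, not a tuple).

```python
import numpy as np

x = np.array([[1, 2, 3],
              [4, 5, 6]])

# axis=None (the default): NumPy flattens the input before accumulating.
flat = np.cumsum(x)

# axis=0: accumulate down each column.
by_rows = np.cumsum(x, axis=0)

# dtype: accumulate and return in the requested dtype, as with np.sum.
as_float = np.cumsum(x, axis=1, dtype=np.float64)
```

The `dtype` argument matters mostly for small integer inputs, where accumulating in the input dtype could overflow.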
Activity
asmeurer commented on Feb 15, 2023
There's also `np.add.accumulate`, which may be relevant when looking at the comparison data.
rgommers commented on Feb 15, 2023
Based on previous experience with complaints about bad naming (see, e.g., scipy/scipy#12924 and https://www.reddit.com/r/programminghorror/comments/j6sd61/i_was_just_looking_at_the_documentation_for/) I would very much prefer not to enshrine the `cumsum` and `cumprod` (as proposed in gh-598) names in this API standard.
This is pretty niche functionality and arguably not "core" enough to implementing an array library for it to be in this standard at all, so I'd vote for leaving it out completely. In a compat layer it could perhaps be named `cumulative_*` like SciPy did with a few methods, if it's desired for these functions to be there.
soraros commented on Feb 16, 2023
Niche it may be, I still think `cumsum` is pretty useful, as many indexing tricks depend on it. One such trick is turning "pauses" into "stairs":
This is especially true when one works with a statically-shaped system like JAX, where these tricks are more or less required. Its usages in the implementation of `jax.numpy.nonzero` and `jax.numpy.repeat` are pretty typical. These are also good examples: ex1, ex2, ex3.
rgommers commented on Feb 16, 2023
@soraros it seems like all that usage of `cumsum` in JAX is just a cumbersome way of working around not having boolean indexing support for in-place ops? And JAX should instead add some other primitive that's more suitable, rather than letting all users manually construct expanded integer index arrays to then use with `.at[idx_cumsum].xxx`?
That kind of JAX code looks very bad, and as a motivator for writing portable code based on a standard I don't think it's a positive. JAX could add support for `x[ix_bool] += 1` tomorrow, and imho they should rather than make users write `x.at[cumsum(ix_bool)].add(1)`. It's 100% equivalent. And for the "expanded boolean indexing" where the desired result is `[0, 2, 4, 0, 0]` as an output, it needs a new builtin function like `jnp.bool_index(xs, conds)`. It has little to do with cumulative sums for data/statistical purposes, it's a type of indexing.
There's a known blocker (brought up by the JAX team before) for adding more in-place operator support beyond what they have now, which is that this standard should be able to better guarantee that `+=` doesn't modify views in NumPy et al. But if we can make that guarantee at some point, then we have what we want here.
soraros commented on Feb 16, 2023
@rgommers Yes and no. It is a cumbersome workaround, but I also think the problem is more fundamental than that. JAX is essentially a front-end for XLA, and the primitives provided by XLA (for now) require static shapes. So the line that actually goes wrong is
Note this code does work in JAX, though it is not jittable, because we don't know its output shape. Let's pretend `x[ix_bool] += 1` is syntax sugar for `x = x + where(ix_bool, 1, 0)` (which works in JAX) for a moment. The same problem appears when we want `x[ix_bool] += [1, 3, 5]`. Again, we somehow need to know the shape of the rhs, which is equivalent to knowing the shape of `xs[ix_bool]` as in the last example.
So what we really work around is the static shape requirement (recall the need for a `size` parameter for `nonzero`), which is not exclusive to JAX.
Now, for something a bit off-topic.
I think the JAX style functional syntax `a = a.at[...].set(...)` for in-place operation looks (and arguably works) better than `numpy`, and I'd really like to have it for array api. Some pros:
Numba as well.
rgommers commented on Feb 17, 2023
@soraros thanks for all this detail, it's very interesting actually. I think there's something to be said indeed for `.at` - especially if JAX can make a new cleaner API for it that doesn't rely on things like explicit use of `cumsum`. This discussion is effectively a follow-up to gh-84 (EDIT: and gh-24). What do you think about opening a new issue to continue this discussion, especially your "for something a bit off-topic" part?
soraros commented on Feb 17, 2023
@rgommers Glad you find the exchange interesting!
gh-84 is indeed interesting, I will read that thoroughly later.
I'm not sure how to proceed regarding opening a new issue to continue the discussion though (not sure about the exact topic and/or scope you have in mind). If it's not too much trouble, could you open a new issue so I (or maybe you) can move the relevant part there? Thanks in advance!
asmeurer commented on Feb 17, 2023
How do you implement cumulative sum using only array API functions (and without using a Python loop)?
oleksandr-pavlyk commented on Feb 17, 2023
One inefficient possibility:
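The snippet originally posted here did not survive extraction. As a stand-in, and not necessarily what the commenter wrote, here is a minimal sketch of one loop-free approach: multiply by a lower-triangular matrix of ones. It assumes a 1-D input and uses NumPy in place of a generic array API namespace; `cumsum_1d_via_matmul` is a hypothetical name. The `tril`, `ones`, and `matmul` functions it relies on are part of the array API standard.

```python
import numpy as np  # stand-in for an array API namespace (assumption)

def cumsum_1d_via_matmul(x):
    """Cumulative sum of a 1-D array using O(n^2) work and memory."""
    n = x.shape[0]
    # Lower-triangular matrix of ones: row i selects elements 0..i of x.
    lower = np.tril(np.ones((n, n), dtype=x.dtype))
    return np.matmul(lower, x)

result = cumsum_1d_via_matmul(np.asarray([1, 2, 3, 4]))
```

The quadratic cost is what makes this inefficient: a dedicated cumulative sum needs only linear work.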
[-]Add `cumsum` to the standard[/-][+]Add support for computing the cumulative sum to the standard[/+]
`.at` for simulating in-place ops #609
rgommers commented on Mar 9, 2023
Done now in gh-609 - sorry for the delay! I spent a lot of time refreshing my memory and trying to put together something more coherent. But it's tricky; the need for new API in JAX to avoid `cumsum` seems clear, but if it was easy to exactly define how it'd look, the JAX devs would have done that by now I guess :)
shoyer commented on Apr 3, 2023
I think cumulative sum is a rather fundamental array operation and we should include it in the standard. None of the typical reasons for omitting a function from the API standard apply here:
`axis` is omitted)
To give a few other examples of how I've used it:
As a reference point on popularity, I see about twice as many uses of `np.cumsum` in Google's codebase as for `np.roll` or `np.var`.
I'll also second Ralf's point on calling it `cumulative_sum` rather than `cumsum`.
WarrenWeckesser commented on Apr 3, 2023
Before locking in the specification, it would be worthwhile taking a look at numpy/numpy#6044, and numpy/numpy#14542 that is referenced from numpy/numpy#6044. My comment in numpy/numpy#14542 provides some evidence for the usefulness of allowing the result to include a prepended 0. (More generally, for other cumulative operations such as `cumprod`, the identity of the operation would be prepended.)
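To make the prepended-0 idea concrete, here is a small sketch using NumPy. The helper name `cumulative_sum_with_initial` is hypothetical and is not an API from the linked NumPy issues; it only illustrates why the extra leading element is handy, for example when turning segment lengths into slice offsets.

```python
import numpy as np

def cumulative_sum_with_initial(x):
    """1-D cumulative sum with the additive identity (0) prepended."""
    zero = np.zeros(1, dtype=x.dtype)
    return np.concatenate([zero, np.cumsum(x, dtype=x.dtype)])

counts = np.array([3, 1, 2])
offsets = cumulative_sum_with_initial(counts)
# offsets[i]:offsets[i + 1] delimits the i-th segment of a flat buffer.
```

The result has one more element than the input, so consecutive pairs give the start and end of each segment without special-casing the first one.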