
Add support for computing the cumulative sum to the standard #597

Closed
@steff456

Description

@steff456
Member

This RFC proposes adding a new API to the array API specification for computing the cumulative sum.

Overview

Based on array comparison data, the API is available in all the libraries in the PyData ecosystem.

Prior art

Proposal:

def cumsum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None) -> array
  • dtype kwarg is for consistency with sum et al

cc @oleksandr-pavlyk

Activity

asmeurer

asmeurer commented on Feb 15, 2023

@asmeurer
Member

There's also np.add.accumulate which may be relevant when looking at the comparison data.

rgommers

rgommers commented on Feb 15, 2023

@rgommers
Member

Based on previous experience with complaints about bad naming (see, e.g., scipy/scipy#12924 and https://www.reddit.com/r/programminghorror/comments/j6sd61/i_was_just_looking_at_the_documentation_for/) I would very much prefer not to enshrine the cumsum and cumprod (as proposed in gh-598) names in this API standard.

This is pretty niche functionality and arguably not "core" enough to the implementation of an array library for it to be in this standard at all, so I'd vote for leaving it out completely. In a compat layer it could perhaps be named cumulative_* like SciPy did with a few methods, if it's desired for these functions to be there.

soraros

soraros commented on Feb 16, 2023

@soraros
Contributor

Niche as it may be, I still think cumsum is pretty useful, as many indexing tricks depend on it. One such trick is turning "pauses" into "stairs":

p = jnp.zeros(10, int).at[jnp.array([1, 4, 8])].set(1)
# [0 1 0 0 1 0 0 0 1 0]
s = p.cumsum()
# [0 1 1 1 2 2 2 2 3 3]

This is especially true when one works with a statically shaped system like JAX, where these tricks are more or less required. Its uses in the implementation of jax.numpy.nonzero and jax.numpy.repeat are pretty typical. These are also good examples: ex1, ex2, ex3.

rgommers

rgommers commented on Feb 16, 2023

@rgommers
Member

This is especially true when one works with a statically shaped system like JAX, where these tricks are more or less required.

@soraros it seems like all that usage of cumsum in JAX is just a cumbersome way of working around not having boolean indexing support for in-place ops? And JAX should instead add some other primitive that's more suitable, rather than letting all users manually construct expanded integer index arrays to then use with .at[idx_cumsum].xxx?

# For an array `xs` and a boolean index `conds` of the same shape
# example JAX expression for `jax_filter`:
>>> xs = jnp.arange(5)
>>> conds = jnp.array([True, False, True, False, True])
>>> cumsum = jnp.cumsum(conds)
>>> cumsum
Array([1, 1, 2, 2, 3], dtype=int32)
>>> jnp.zeros_like(xs).at[cumsum - 1].add(jnp.where(conds, xs, 0))
Array([0, 2, 4, 0, 0], dtype=int32)

# NumPy:
>>> xs = np.arange(5)
>>> conds = np.array([True, False, True, False, True])
>>> xs[conds]
array([0, 2, 4])
>>> np.zeros_like(xs) + np.where(conds, xs, 0)  # note: not identical to JAX's `.at`, 0-padding not at the end
array([0, 0, 2, 0, 4])

That kind of JAX code looks very bad, and as a motivator for writing portable code based on a standard I don't think it's a positive. JAX could add support for x[ix_bool] += 1 tomorrow, and imho they should rather than make users write x.at[cumsum(ix_bool)].add(1). It's 100% equivalent. And for the "expanded boolean indexing" where the desired result is [0, 2, 4, 0, 0] as an output, it needs a new builtin function like jnp.bool_index(xs, conds). It has little to do with cumulative sums for data/statistical purposes, it's a type of indexing.
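To make the "expanded boolean indexing" result concrete, here is a minimal NumPy sketch of what such a function might return. The name `bool_index_padded` is hypothetical (it is neither a JAX nor a NumPy function), and only the 1-D case is handled. The second variant reproduces the cumsum-based scatter from the JAX snippet above so the two outputs can be compared.

```python
import numpy as np


def bool_index_padded(xs, conds, fill_value=0):
    """Hypothetical helper: selected elements first, padding after, same shape as xs."""
    out = np.full_like(xs, fill_value)
    selected = xs[conds]                  # length depends on the data
    out[: selected.shape[0]] = selected   # pack the selected values at the front
    return out


xs = np.arange(5)
conds = np.array([True, False, True, False, True])

padded = bool_index_padded(xs, conds)

# Same result via the cumsum-based scatter: every intermediate has a fixed shape.
scattered = np.zeros_like(xs)
np.add.at(scattered, np.cumsum(conds) - 1, np.where(conds, xs, 0))
```

Both `padded` and `scattered` hold the selected values followed by zero padding. The first form is easier to read but relies on a data-dependent intermediate shape, which is the part a statically shaped system cannot express directly.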

There's a known blocker (brought up by the JAX team before) for adding more in-place operator support beyond what they have now, which is that this standard should be able to better guarantee that += doesn't modify views in NumPy et al. But if we can make that guarantee at some point, then we have what we want here.

soraros

soraros commented on Feb 16, 2023

@soraros
Contributor

@soraros it seems like all that usage of cumsum in JAX is just a cumbersome way of working around not having boolean indexing support for in-place ops?

@rgommers Yes and no. It is a cumbersome workaround, but I also think the problem is more fundamental than that. JAX is essentially a front-end for XLA, and the primitives provided by XLA (for now) require static shapes. So the line that actually goes wrong is

>>> xs[ix_bool]
array([0, 2, 4])

Note that this code does work in JAX, though it is not jittable, because we don't know its output shape. Let's pretend for a moment that x[ix_bool] += 1 is syntactic sugar for x = x + where(ix_bool, 1, 0) (which works in JAX). The same problem appears when we want x[ix_bool] += [1, 3, 5]. Again, we somehow need to know the shape of the rhs, which is equivalent to knowing the shape of xs[ix_bool] as in the last example.

So what we really work around is the static shape requirement (recall the need for a size parameter in nonzero), which is not exclusive to JAX.
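The point about the rhs can be sketched in NumPy. This is only an illustration under an assumption that is not part of the discussion above: the rhs is padded to the full length of x so that every array has a fixed shape, and cumsum maps each selected position to its rhs element.

```python
import numpy as np

x = np.arange(5)
mask = np.array([True, False, True, False, True])

# Scalar rhs: the masked update is a plain elementwise expression.
in_place = x.copy()
in_place[mask] += 1
functional = x + np.where(mask, 1, 0)

# Array rhs, padded to a fixed length (the padding is an assumption of this sketch).
vals = np.array([1, 3, 5, 0, 0])
idx = np.cumsum(mask) - 1                  # position k of each selected element
fixed_shape = x + np.where(mask, vals[idx], 0)

# Reference result using data-dependent shapes.
reference = x.copy()
reference[mask] += vals[: mask.sum()]
```

`fixed_shape` and `reference` agree, but only the first avoids an intermediate whose length depends on the contents of `mask`.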

Now, for something a bit off-topic.

I think the JAX-style functional syntax a = a.at[...].set(...) for in-place operations looks (and arguably works) better than NumPy's, and I'd really like to have it in the array API. Some pros:

  • Looks familiar, and simulates the feel of in-place operation just fine.
  • Makes it clear that nothing is modified. This restricted access pattern would work with any accelerator-backed system. I think it would aid static analysis in systems like Numba as well.
  • More concise, can be chained, and sometimes expresses our intention better.
a = zeros(m)       # initializing a
a[I] += arange(n)  # semantically, still initializing a

# VS

# being concise here is not the important point
# this line becomes a "semantic block" for initialization
a = zeros(m).at[I].add(arange(n))  # initializing a
  • Can specify indexing mode, (more) easily.
# I think these are fairly cumbersome to represent in `numpy`, as we don't have kwargs for __getitem__
b = a.at[I].add(val, unique_indices=True)     # important info for accelerators
c = b.at[J].get(mode='fill', fill_value=nan)  # sure, we have `take`, but this is uniform and cool
rgommers

rgommers commented on Feb 17, 2023

@rgommers
Member

@soraros thanks for all this detail, it's very interesting actually. I think there's something to be said indeed for .at - especially if JAX can make a new cleaner API for it that doesn't rely on things like explicit use of cumsum. This discussion is effectively a follow-up to gh-84 (EDIT: and gh-24). What do you think about opening a new issue to continue this discussion, especially your "for something a bit off-topic" part?

soraros

soraros commented on Feb 17, 2023

@soraros
Contributor

@rgommers Glad you find the exchange interesting!
gh-84 is indeed interesting, I will read that thoroughly later.
I'm not sure how to proceed regarding opening a new issue to continue the discussion though (not sure about the exact topic and/or scope you have in mind). If it's not too much trouble, could you open a new issue so I (or maybe you) can move the relevant part there? Thanks in advance!

asmeurer

asmeurer commented on Feb 17, 2023

@asmeurer
Member

How do you implement cumulative sum using only array API functions (and without using a Python loop)?

oleksandr-pavlyk

oleksandr-pavlyk commented on Feb 17, 2023

@oleksandr-pavlyk
Contributor

One inefficient possibility:

In [4]: def cumsum(x):
   ...:     assert x.ndim == 1
   ...:     n = x.shape[0]
   ...:     return tril(ones((n,n,))) @ x
   ...:

In [5]: cumsum(np.array([1,2,3,4,5]))
Out[5]: array([ 1.,  3.,  6., 10., 15.])
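A self-contained variant of the same idea, as a sketch: it builds the lower-triangular matrix of ones with the input's dtype (so integer input stays integer, unlike the float output above) and sums along the last axis of an N-D array. The name `cumsum_via_matmul` is made up for this example, and NumPy stands in for an array API namespace.

```python
import numpy as np


def cumsum_via_matmul(x):
    """Cumulative sum along the last axis via a lower-triangular matrix of ones."""
    n = x.shape[-1]
    lower = np.tril(np.ones((n, n), dtype=x.dtype))
    # out[..., i] = sum_j lower[i, j] * x[..., j] = sum of x[..., :i + 1]
    return x @ lower.T


one_d = cumsum_via_matmul(np.array([1, 2, 3, 4, 5]))
two_d = cumsum_via_matmul(np.array([[1, 2, 3], [4, 5, 6]]))
```

This needs O(n²) memory and work for an axis of length n, which is why it is only a fallback and not a practical replacement for a dedicated function.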
changed the title from "Add `cumsum` to the standard" to "Add support for computing the cumulative sum to the standard" on Feb 19, 2023
rgommers

rgommers commented on Mar 9, 2023

@rgommers
Member

I'm not sure how to proceed regarding opening a new issue to continue the discussion though (not sure about the exact topic and/or scope you have in mind). If it's not too much trouble, could you open a new issue so I (or maybe you) can move the relevant part there? Thanks in advance!

Done now in gh-609 - sorry for the delay! I spent a lot of time refreshing my memory and trying to put together something more coherent. But it's tricky; the need for a new API in JAX to avoid cumsum seems clear, but if it were easy to define exactly how it'd look, the JAX devs would have done that by now I guess :)

shoyer

shoyer commented on Apr 3, 2023

@shoyer
Contributor

I think cumulative sum is a rather fundamental array operation and we should include it in the standard. None of the typical reasons for omitting a function from the API standard apply here:

  • It has well-defined behavior (aside from NumPy's funny flattening behavior if axis is omitted)
  • It is not particularly hard to implement in a distributed fashion or on accelerators, as evidenced by how it can be found in Dask and JAX.
  • It is not easy to implement in terms of other fundamental operations.
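For reference, the flattening behavior mentioned in the first point looks like this with NumPy's existing `np.cumsum`. This shows NumPy's current behavior only, not what the standard would specify.

```python
import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]])

flat = np.cumsum(a)          # axis omitted: the input is flattened first
rows = np.cumsum(a, axis=1)  # running sum within each row
cols = np.cumsum(a, axis=0)  # running sum down each column
```

With the axis omitted, the 2-D input produces a 1-D result of length `a.size`, whereas an explicit axis preserves the input shape.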

To give a few other examples of how I've used it:

  1. Calculating integrals
  2. Calculating moving window averages
  3. Calculating cumulative probabilities
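Two of these uses can be sketched with NumPy's `np.cumsum`. The helper name `moving_average` and the sample data are made up for this example; it is one common formulation, not necessarily the one the commenter had in mind.

```python
import numpy as np


def moving_average(x, window):
    """Mean over each full window, computed from differences of prefix sums."""
    c = np.cumsum(np.concatenate(([0.0], x)))  # c[k] = sum of x[:k]
    return (c[window:] - c[:-window]) / window


x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
ma = moving_average(x, 3)

# Cumulative probabilities from a discrete probability mass function.
pmf = np.array([0.1, 0.2, 0.3, 0.4])
cdf = np.cumsum(pmf)
```

The moving average costs O(n) regardless of the window size, because each window sum is a difference of two prefix sums.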

As a reference point on popularity, I see about twice as many uses of np.cumsum in Google's codebase as for np.roll or np.var.

I'll also second Ralf's point on calling it cumulative_sum rather than cumsum.

WarrenWeckesser

WarrenWeckesser commented on Apr 3, 2023

@WarrenWeckesser

It has well defined behavior...

Before locking in the specification, it would be worthwhile taking a look at numpy/numpy#6044, and numpy/numpy#14542 that is referenced from numpy/numpy#6044. My comment in numpy/numpy#14542 provides some evidence for the usefulness of allowing the result to include a prepended 0. (More generally, for other cumulative operations such as cumprod, the identity of the operation would be prepended.)
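A short NumPy sketch of the prepended-identity variant, built here with `np.concatenate` because `np.cumsum` has no option for it. The sample data is made up for illustration.

```python
import numpy as np

x = np.array([3, 1, 4, 1, 5])

# Prefix sums with a leading 0: c[k] is the sum of the first k elements.
c = np.concatenate(([0], np.cumsum(x)))

# Any contiguous range sum is then a single subtraction.
i, j = 1, 4
range_sum = c[j] - c[i]                   # equals x[i:j].sum()

# The analogue for products prepends the multiplicative identity 1.
p = np.concatenate(([1], np.cumprod(x)))
```

The result has one more element than the input along the summed axis, and `np.diff(c)` recovers `x` exactly.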

15 remaining items


Metadata

Labels: API extension (Adds new functions or objects to the API.)

Participants: @seberg, @asmeurer, @rgommers, @WarrenWeckesser, @shoyer