[inductor] Lower aten.index_add_/aten.index_add #885

Closed
Tracked by #93757
desertfire opened this issue Aug 18, 2022 · 7 comments · Fixed by #1292

Comments

desertfire commented Aug 18, 2022

Should be fairly similar to how aten.index_put_/aten.index_put is currently handled; see the existing index_put lowering.

@desertfire desertfire changed the title aten.index_add_ [inductor] Lower aten.index_add_ Aug 18, 2022
@desertfire desertfire changed the title [inductor] Lower aten.index_add_ [inductor] Lower aten.index_add_/aten.index_add Aug 18, 2022
@eellison eellison self-assigned this Aug 19, 2022
eellison commented

I looked into this a bit, and I'm not sure it has that much in common with index_put, since index_add can accumulate arbitrarily many indices into the same location. The usage also seems to be pretty different:

For fastNLP_Bert you have inputs like: Operator: aten.index_put_.default cnt: 1, ((T([6, 476], i64), [T([6], i64), T([6], i64)], T([], i64)), {}) (i.e., indices of length 6).

For hf_Longformer, e.g., you have the following: ((T([2359296], f16), 0, T([4718592], i64), T([4718592], f16)), {}) (i.e., indices of length 4718592).

I can't think of a lowering of index_add within inductor that makes much sense, at least within the current IR. Maybe a custom template op (similar to matmul/conv) would make sense? Or maybe there's a way to generalize this pattern with embedding_bag. Both are pretty common in HuggingFace models.

cc @ngimel @jansel

jansel commented Aug 24, 2022

You should be able to map this to ir.Scatter(..., scatter_mode="atomic_add").

ir.Scatter is very similar to ir.Pointwise, except:

ir.Pointwise computes

out[x] = inner_fn(x)

while ir.Scatter computes

out[output_indexer(x)] = inner_fn(x)

If you set scatter_mode="atomic_add" it instead does

out[output_indexer(x)] += inner_fn(x)

in a thread-safe way.

The iteration ranges of the scatter then become the iteration space of the scatter loops (which will be different from the size of the output).

The lowering should be something like this (warning, this is untested):

# assumed imports, following torchinductor's lowering.py conventions:
#   from torchinductor import ir
#   from torchinductor.virtualized import V, ops
def index_add_(self, dim, index, source, *, alpha=1):
    index_loader = index.make_loader()
    source_loader = source.make_loader()

    def output_indexer(idx):
        # Replace the coordinate along `dim` with the value loaded from
        # `index`, so each element of `source` lands at the right slice of self.
        idx = list(idx)
        idx[dim] = ops.indirect_indexing(index_loader([idx[dim]]))
        return idx

    def fn(idx):
        # The value to accumulate: alpha * source[idx]
        return ops.mul(
            ops.constant(alpha, source.get_dtype()),
            source_loader(idx),
        )

    # Iterate over the elements of `source` (not of the output) and
    # atomically accumulate into self at the scattered positions.
    scatter = ir.Scatter(
        device=self.get_device(),
        dtype=self.get_dtype(),
        inner_fn=fn,
        ranges=list(source.get_size()),
        output_indexer=output_indexer,
        scatter_mode="atomic_add",
    )
    buffer = ir.ComputedBuffer(
        None,
        ir.MutationLayout(self),
        scatter,
    )
    buffer.name = V.graph.register_buffer(buffer)
    return self
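
For reference, a lowering like this would presumably be hooked up with inductor's register_lowering decorator; a hypothetical registration (the import path is an assumption based on torchinductor's lowering.py) could look like:

import torch
from torchinductor.lowering import register_lowering

aten = torch.ops.aten

@register_lowering(aten.index_add_)
def index_add_(self, dim, index, source, *, alpha=1):
    ...  # body as in the sketch above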

ngimel commented Aug 24, 2022

For completeness, we should also add index_select (which can be done as Pointwise, similar to index).
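
A minimal sketch of what that Pointwise lowering could look like, following the conventions of the index_add_ sketch above (helper names and signatures are assumptions, untested):

def index_select(x, dim, index):
    x_loader = x.make_loader()
    index_loader = index.make_loader()

    # output shape: x's shape, with the size along `dim` replaced by len(index)
    sizes = list(x.get_size())
    sizes[dim] = index.get_size()[0]

    def inner_fn(out_idx):
        # map the output coordinate along `dim` through the index tensor
        src_idx = list(out_idx)
        src_idx[dim] = ops.indirect_indexing(index_loader([out_idx[dim]]))
        return x_loader(src_idx)

    return ir.Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=sizes,
    )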

@desertfire desertfire assigned lezcano and unassigned eellison Sep 13, 2022
ngimel commented Sep 13, 2022

@eellison index_put also accumulates arbitrarily many indices into the same location.

lezcano commented Sep 14, 2022

This and index_copy can be done as decompositions in terms of index_put, doing something like x[:, :, ..., idx] = tensor. Writing a lowering should also be easy, as it's pretty much the same as index_select, which already has a lowering. I was going to implement both and see how they compare, but a priori I think they should be comparable, so the decomposition should be the better option.
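
A minimal sketch of that decomposition (the function name and the dim normalization are illustrative; index_put's accumulate flag does the summation):

import torch

def index_add_decomp(x, dim, index, source, *, alpha=1):
    # place `index` at position `dim`; the leading None slots mean
    # "keep the whole dimension", i.e. the x[:, :, ..., idx] pattern,
    # and accumulate=True sums values that hit the same location
    dim = dim % x.dim()
    indices = [None] * dim + [index]
    return torch.ops.aten.index_put(x, indices, source * alpha, True)

# quick check against eager index_add, with a repeated index to
# exercise accumulation
x = torch.zeros(3, 4)
index = torch.tensor([0, 2, 0])
source = torch.ones(3, 3)
assert torch.allclose(
    index_add_decomp(x, 1, index, source),
    torch.index_add(x, 1, index, source),
)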

lezcano commented Sep 14, 2022

The decompositions are here: pytorch/pytorch#85002.

@eellison if you implement the lowering, you can benchmark it against the decomposition and see how they fare.
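
A hypothetical micro-benchmark along those lines, reusing the hf_Longformer shapes quoted earlier (using the torchdynamo.optimize("inductor") entry point is an assumption about how either path would be exercised):

import torch
import torchdynamo
from torch.utils import benchmark

@torchdynamo.optimize("inductor")
def f(x, index, source):
    # compiled either via the lowering or via the decomposition,
    # depending on which is registered
    return torch.index_add(x, 0, index, source)

x = torch.zeros(2359296, device="cuda", dtype=torch.float16)
index = torch.randint(0, x.numel(), (4718592,), device="cuda")
source = torch.rand(4718592, device="cuda", dtype=torch.float16)

t = benchmark.Timer(
    stmt="f(x, index, source)",
    globals={"f": f, "x": x, "index": index, "source": source},
)
print(t.timeit(100))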

lezcano commented Sep 14, 2022

Of course, if you guys think that the decomposition is more convenient than the lowering, I could implement the decomposition for index_select and remove the lowering.
