[inductor] Lower aten.index_add_/aten.index_add #885
Comments
I looked into this a bit, I'm not sure it has that much in common with For e.g. I can't think of a lowering of |
You should be able to map this to
If you set
In a thread-safe way. The iteration ranges of the scatter then become the iteration space of the scatter loops (which will be different from the size of the output). The lowering should be something like this (warning, this is untested):

    def index_add_(self, dim, index, source, *, alpha=1):
        index_loader = index.make_loader()
        source_loader = source.make_loader()

        def output_indexer(index):
            index = list(index)
            index[dim] = ops.indirect_indexing(index_loader([index[dim]]))
            return index

        def fn(index):
            return ops.mul(
                ops.constant(alpha, source.get_dtype()),
                source_loader(index)
            )

        scatter = ir.Scatter(
            device=self.get_device(),
            dtype=self.get_dtype(),
            inner_fn=fn,
            ranges=list(source.get_size()),
            output_indexer=output_indexer,
            scatter_mode="atomic_add"
        )
        buffer = ir.ComputedBuffer(
            None,
            ir.MutationLayout(self),
            scatter,
        )
        buffer.name = V.graph.register_buffer(buffer)
        return self
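
To make the iteration-space point concrete, here is a minimal pure-Python model of what the scatter in the sketch is meant to compute. It is an illustration only: it uses nested lists instead of tensors, a sequential `+=` stands in for the atomic add, and the helper names (`get_item`, `add_item`, `shape_of`, `index_add_reference`) are made up for this example and are not part of inductor.

```python
from itertools import product


def get_item(tensor, idx):
    """Read a nested-list 'tensor' at a multi-dimensional index."""
    for i in idx:
        tensor = tensor[i]
    return tensor


def add_item(tensor, idx, value):
    """Accumulate value into a nested-list 'tensor' at a multi-dimensional index."""
    for i in idx[:-1]:
        tensor = tensor[i]
    tensor[idx[-1]] += value


def shape_of(tensor):
    """Shape of a non-empty rectangular nested list."""
    shape = []
    while isinstance(tensor, list):
        shape.append(len(tensor))
        tensor = tensor[0]
    return shape


def index_add_reference(self_, dim, index, source, alpha=1):
    """Model of the scatter loop: iterate over source, accumulate into self_."""
    # The loop ranges come from source, not from the output.
    for src_idx in product(*(range(n) for n in shape_of(source))):
        out_idx = list(src_idx)
        # Same remapping as output_indexer: only the `dim` coordinate is indirect.
        out_idx[dim] = index[src_idx[dim]]
        # Plain += stands in for the atomic add; this loop is sequential.
        add_item(self_, out_idx, alpha * get_item(source, src_idx))
    return self_


out = [[0, 0], [0, 0], [0, 0]]
src = [[1, 2], [3, 4], [5, 6]]
index_add_reference(out, 0, [0, 2, 0], src, alpha=2)
print(out)  # [[12, 16], [0, 0], [6, 8]]
```

Rows 0 and 2 of `src` both land in row 0 of `out`, which is why the real kernel needs an atomic add when those writes happen in parallel.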
For completeness, we should also add |
@eellison |
This and |
The decompositions are here: pytorch/pytorch#85002. @eellison if you implement the lowering, you can benchmark it against the decomposition to see how they fare.
Of course, if you guys think that the decomposition is more convenient than the lowering, I could implement the decomposition for |
Should be fairly similar to how aten.index_put_/aten.index_put is currently handled, see:
torchdynamo/torchinductor/lowering.py, line 1445 in 7cc850c
torchdynamo/torchinductor/decomposition.py, line 357 in 7cc850c
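
As a rough illustration of why the two ops are close relatives, here is a 1-D pure-Python model in which index_add is written on top of an accumulate-style index_put. This is a sketch on plain lists, not the code at the referenced lines; the function names `index_put_accumulate` and `index_add_1d` are hypothetical.

```python
def index_put_accumulate(dest, indices, values):
    """1-D model of index_put with accumulate=True: duplicate indices add up."""
    result = list(dest)  # out-of-place, like aten.index_put
    for i, v in zip(indices, values):
        result[i] += v
    return result


def index_add_1d(dest, index, source, alpha=1):
    """1-D index_add expressed through the accumulate-style index_put model."""
    return index_put_accumulate(dest, index, [alpha * s for s in source])


print(index_add_1d([10, 20, 30], [0, 0, 2], [1, 2, 3], alpha=2))  # [16, 20, 36]
```

The only differences in this model are the `alpha` scaling and, in the general case, that index_add indexes along a single `dim` while index_put takes one index per dimension.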