Remove upsample_bilinear2d lowering and use decomposition #934
I think we should follow the data, so I am fine swapping if it is faster. I am curious why, though... Are the extra simplifications done by lowerings somehow backfiring and generating slower code? Have you looked at the generated code? Maybe there is some fix that could help other kernels too.
@jansel I'm not sure, to be honest. This is the Triton for the "lowering" version: https://pastebin.com/a3UXacPX
This is the Triton for the "decomposition" version: https://pastebin.com/jrGiQAQY
In this case there's a 10% perf improvement from the decomposition vs. the lowering. (Note: both of these have a pointwise op fused to the end, although my previous benchmarks didn't have any pointwise op.) I think they're mostly the same, except that the shapes are static for the "lowering" version and the ops are rearranged a bit. Maybe the fact that the loads are all bunched together is making perf slightly worse for the lowering version?

To be clear, I think the main benefit of this experiment (if its conclusions hold) is that we can probably mostly reuse decompositions for lowering ops, as long as they use indexing, instead of needing to re-implement them as lowerings. My main question re: benchmarking is whether there are any cases you think the lowering should handle better that I'm currently not testing. For example, perhaps some horizontal fusion of transposed upsample ops that we might be able to simplify?
Interesting, my first guess looking at the two kernels is just the ordering of ops. The loads on the faster kernel are "spread out" while the loads in the slow kernel are "bunched up". Perhaps we should explore a compiler pass that reorders ops within a kernel. Our current inductor kernels usually look like:
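For concreteness, a minimal sketch of that typical structure (illustrative only; the kernel name, arguments, and shapes are made up, not taken from the pastebins above):

```python
import triton
import triton.language as tl

@triton.jit
def fused_kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    # Hypothetical sketch of the grouping inductor typically emits.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    # 1) loads, bunched together at the top
    tmp0 = tl.load(in_ptr0 + xindex, xmask)
    tmp1 = tl.load(in_ptr1 + xindex, xmask)
    # 2) compute (indirect loads, if any, are forced down into this section)
    tmp2 = tmp0 + tmp1
    # 3) stores at the end
    tl.store(out_ptr0 + xindex, tmp2, xmask)
```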
When you have indirect loads, it moves those indirect loads into the "compute" section, because they must come after the address computation; that is what allows the spread-out pattern to be generated. My thinking behind that ordering was that it makes compiler analysis easier for Triton/LLVM, but I may have been wrong there. This is just one theory though; we should test it.
Let's ship this. I created a ticket to look into instruction reordering.
@@ -278,11 +280,6 @@ def scatter_reduce(self, dim: int, index, src, reduction_type, **kwargs):
    return self.clone().scatter_reduce_(dim, index, src, reduction_type, **kwargs)


@register_decomposition([aten.narrow])
Did you mean to remove this? I don't see a replacement decomp.
Yeah, these are both decompositions for "CompositeImplicitAutograd" ops that are no longer needed now that we always decompose CompositeImplicitAutograd ops (i.e., Ed's change).
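For reference, a sketch of what the deleted aten.narrow decomposition boils down to (illustrative, not necessarily the exact removed code): since narrow is CompositeImplicitAutograd, the dispatcher already rewrites it into slice, so an explicit registration adds nothing.

```python
import torch

aten = torch.ops.aten

# Illustrative sketch of the body behind the removed
# @register_decomposition([aten.narrow]) entry: narrow(x, dim, start, length)
# is just a slice along `dim` from start to start + length. Because narrow is
# CompositeImplicitAutograd, tracing already performs this rewrite.
def narrow_decomp_sketch(x, dim, start, length):
    return aten.slice(x, dim, start, start + length)

# Quick check against the built-in op (hypothetical usage):
x = torch.arange(24).reshape(2, 3, 4)
assert torch.equal(narrow_decomp_sketch(x, 2, 1, 3), x.narrow(2, 1, 3))
```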
So... there's been some discussion about "decompositions vs. lowerings" (see pytorch/pytorch#93623). In particular, do we sacrifice performance by using indexing-based decompositions instead of hand-written lowerings?
I think a decently representative example is https://github.com/pytorch/torchdynamo/blob/main/torchinductor/lowering.py#L1646.
The decomp is here: https://github.com/pytorch/pytorch/blob/master/torch/_decomp/decompositions.py#L1603
CUDA impl is here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
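To make the comparison concrete, here is a minimal sketch of what an indexing-based decomposition of upsample_bilinear2d amounts to (simplified to the align_corners=False path with no explicit scale factors; illustrative, not the actual torch._decomp code):

```python
import torch

def upsample_bilinear2d_sketch(x, out_h, out_w):
    # Everything below is plain aten ops plus advanced indexing, which inductor
    # can fuse with surrounding pointwise ops just like any other traced graph.
    n, c, in_h, in_w = x.shape
    scale_h, scale_w = in_h / out_h, in_w / out_w

    # Source coordinate for every output row/column (align_corners=False).
    oh = torch.arange(out_h, device=x.device, dtype=x.dtype)
    ow = torch.arange(out_w, device=x.device, dtype=x.dtype)
    src_h = ((oh + 0.5) * scale_h - 0.5).clamp(min=0)
    src_w = ((ow + 0.5) * scale_w - 0.5).clamp(min=0)

    # Integer neighbours and fractional blend weights.
    h0, w0 = src_h.floor().long(), src_w.floor().long()
    h1, w1 = (h0 + 1).clamp(max=in_h - 1), (w0 + 1).clamp(max=in_w - 1)
    dh = (src_h - h0).view(1, 1, out_h, 1)
    dw = (src_w - w0).view(1, 1, 1, out_w)

    # Gather the four neighbours via indexing and blend.
    top = x[:, :, h0][..., w0] * (1 - dw) + x[:, :, h0][..., w1] * dw
    bot = x[:, :, h1][..., w0] * (1 - dw) + x[:, :, h1][..., w1] * dw
    return top * (1 - dh) + bot * dh

# Sanity check against eager (hypothetical shapes):
x = torch.randn(2, 3, 8, 8)
ref = torch.nn.functional.interpolate(x, size=(16, 16), mode="bilinear", align_corners=False)
assert torch.allclose(upsample_bilinear2d_sketch(x, 16, 16), ref, atol=1e-5)
```

The point of writing it this way is that there is nothing backend-specific in the decomposition, so inductor can fuse it with neighbouring pointwise ops instead of relying on a hand-written lowering.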
So, let's actually benchmark the perf difference. Surprisingly, it seems like the decomp actually tends to be faster! (although not always).
By the way, this is how eager vs. decomp performs (inductor is much faster!)
Given these perf results, and the other benefits of unifying on decompositions instead of lowerings, I propose we delete the lowering and use the decomp instead.
@jansel Anything else you think I should check in my benchmarking?
This is my benchmark code.
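Not the author's actual script, but a rough, hypothetical sketch of this kind of eager-vs-inductor comparison (it uses the torch.compile entry point, whereas the original presumably drove torchdynamo/torchinductor directly; the shapes and the trailing relu are made up):

```python
import torch
import torch.nn.functional as F

def fn(x):
    y = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
    return y.relu()  # pointwise op fused onto the end, as in the benchmark above

compiled_fn = torch.compile(fn)  # inductor backend by default

def bench(f, x, iters=100, warmup=10):
    # Simple CUDA-event timing loop; returns ms per iteration.
    for _ in range(warmup):
        f(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        f(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(8, 64, 128, 128, device="cuda")
print("eager   :", bench(fn, x), "ms")
print("inductor:", bench(compiled_fn, x), "ms")
```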