Remove upsample_bilinear2d lowering and use decomposition #934

Merged (1 commit) on Aug 21, 2022

Conversation

Chillee (Contributor) commented Aug 21, 2022

So... there's been some discussion about "decompositions vs. lowerings" (see pytorch/pytorch#93623). In particular, do we sacrifice performance from using indexing-style lowerings as opposed to decomps?

I think a decently representative example is https://github.com/pytorch/torchdynamo/blob/main/torchinductor/lowering.py#L1646.

The decomp is here: https://github.com/pytorch/pytorch/blob/master/torch/_decomp/decompositions.py#L1603

CUDA impl is here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
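
For concreteness, the decomposition expresses the op entirely with indexing and pointwise math: compute source coordinates, gather the four neighbors, blend. A simplified pure-PyTorch sketch of the align_corners=True path (illustrative only; the helper name is made up and this is not the actual code in decompositions.py):

import torch

def upsample_bilinear2d_sketch(x, out_h, out_w):
    # x: (B, C, H, W); only the align_corners=True path, for brevity.
    _, _, in_h, in_w = x.shape
    dev = x.device
    # Source coordinate for every output row/column.
    ys = torch.linspace(0, in_h - 1, out_h, device=dev)
    xs = torch.linspace(0, in_w - 1, out_w, device=dev)
    y0, x0 = ys.floor().long(), xs.floor().long()
    y1, x1 = (y0 + 1).clamp(max=in_h - 1), (x0 + 1).clamp(max=in_w - 1)
    wy = (ys - y0).view(1, 1, out_h, 1)
    wx = (xs - x0).view(1, 1, 1, out_w)
    # Gather the four neighbors with plain advanced indexing.
    def gather(yi, xi):
        return x[:, :, yi, :][:, :, :, xi]
    v00, v01 = gather(y0, x0), gather(y0, x1)
    v10, v11 = gather(y1, x0), gather(y1, x1)
    return ((1 - wy) * (1 - wx) * v00 + (1 - wy) * wx * v01
            + wy * (1 - wx) * v10 + wy * wx * v11)

# Sanity check against eager:
x = torch.randn(2, 3, 8, 8)
ref = torch.nn.functional.interpolate(x, size=(16, 12), mode="bilinear", align_corners=True)
torch.testing.assert_close(upsample_bilinear2d_sketch(x, 16, 12), ref)

Everything here is indexing plus pointwise math, which is the kind of graph inductor can handle without a hand-written lowering.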

So, let's actually benchmark the perf difference. Surprisingly, it seems like the decomp actually tends to be faster! (although not always).

[image: benchmark table comparing inductor runtimes with the decomposition vs. the lowering]

By the way, this is how eager vs. decomp performs (inductor is much faster!)

[image: benchmark table comparing eager vs. the inductor-compiled decomposition]

Given these perf results, and the other benefits of unifying on decompositions instead of lowerings, I propose we delete the lowering and use the decomp instead.

@jansel Anything else you think I should check in my benchmarking?

This is my benchmark code.

Script
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.compile import aot_module, nop, print_compile
from torchinductor.compile_fx import compile_fx_inner
import torch.nn as nn
import torchinductor
from functorch import vmap
from torchinductor.decomposition import decompositions
from torchinductor.compile_fx import cudagraphify
import time

B = 128
C = 64

def bench(f, inp):
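    # Average wall-clock time per call over `iters` calls, reported in microseconds.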
    iters = 100
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        f(*inp)
    torch.cuda.synchronize()
    return (time.time() - begin)*1e6/iters

def compare(H, W, scaling, align_corners):
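    # Trace f with and without inductor's decomposition table, compile both with inductor, and time them against the cudagraph-wrapped eager op.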
    def f(x):
        val = torch.ops.aten.upsample_bilinear2d.vec(x, None, align_corners, scaling)
        return (val,)

    inp = [torch.randn(B, C, H, W, device='cuda')]
    decomposed = make_fx(f, decomposition_table=decompositions, tracing_mode="fake")(*inp)
    non_decomposed = make_fx(f, tracing_mode="fake")(*inp)

    compiled_decomposed = compile_fx_inner(decomposed, inp)
    compiled_nondecomposed = compile_fx_inner(non_decomposed, inp)

    cuda_f = cudagraphify(f, inp)
    orig_time = bench(cuda_f, inp)
    decomp_time = bench(compiled_decomposed, inp)
    nondecomp_time = bench(compiled_nondecomposed, inp)
    print(f"{orig_time:.2f}, {decomp_time:.2f}, {nondecomp_time:.2f}")
    return orig_time, decomp_time, nondecomp_time

f = open('results.txt', 'w')
for H in [128, 256, 512]:
    for W in [128, 256, 512]:
        for scaling in ([0.5, 0.5], [0.5, 1.0], [1.5, 1.5], [2.0, 1.0]):
            for align_corners in [True, False]:
                print(f"({H}, {W}), {scaling}, {align_corners}")
                f.write(f"{H},{W}, {scaling[0]}, {scaling[1]}, {align_corners}\n")
                res = compare(H, W, scaling, align_corners)
                f.write(f"{res[0]},{res[1]},{res[2]}\n")
                f.flush()

jansel (Contributor) commented Aug 21, 2022

I think we should follow the data, so I am fine swapping if it is faster.

I am curious why, though... Are the extra simplifications done by lowerings somehow backfiring and generating slower code? Have you looked at the generated code? Maybe there is some fix that could help other kernels too.

Chillee (Contributor, Author) commented Aug 21, 2022

@jansel I'm not sure to be honest.

This is the triton for the "lowering" version: https://pastebin.com/a3UXacPX

This is the triton for the "decomposition" version: https://pastebin.com/jrGiQAQY

In this case there's a 10% perf improvement from the decomposition vs. the lowering.

(Note: Both of these have a pointwise op fused to the end, although my previous benchmarks didn't have any pointwise op).

I think they're mostly the same, except that the shapes are static for the "lowering" version, and the ops are rearranged a bit.

Maybe the fact that the loads are all bunched together is making perf slightly worse for the lowering version?

To be clear, I think the main benefit of this experiment (if its conclusions hold) is that we can probably mostly reuse decompositions for lowering ops as long as they're using indexing, instead of needing to re-implement them as lowerings.

My main question re: benchmarking is whether there are any cases you think the lowering should handle better that I'm currently not testing. For example, perhaps some horizontal fusion of transposed upsample ops that we might be able to simplify?

jansel (Contributor) commented Aug 21, 2022

Interesting, my first guess looking at the two kernels is just the ordering of ops. The loads on the faster kernel are "spread out" while the loads in the slow kernel are "bunched up".

Perhaps we should explore a compiler pass that reorders ops within a kernel.

Our current inductor kernels usually look like:

<all of the loads>
<all of the compute>
<all of the stores>

When you have indirect loads, it moves those indirect loads into the "compute" section, because they must come after the address computation. Thus allowing that spread out pattern to be generated.

My thinking behind that ordering was that it makes compiler analysis easier for Triton/LLVM. I may have been wrong there.

This is just one theory though, we should test it.
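
To make the two layouts concrete, here is a toy pair of hand-written Triton kernels (illustrative only, not inductor output; the names and the computation are made up), one in each ordering:

import triton
import triton.language as tl

@triton.jit
def bunched(in0, in1, out, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # Current inductor layout: all of the loads, then all of the compute, then the store.
    a = tl.load(in0 + offs, mask=mask)
    b = tl.load(in1 + offs, mask=mask)
    tmp0 = a * 0.5
    tmp1 = b * 0.5
    tl.store(out + offs, tmp0 + tmp1, mask=mask)

@triton.jit
def spread(in0, in1, out, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # "Spread out" layout: each load sits next to the compute that consumes it.
    a = tl.load(in0 + offs, mask=mask)
    tmp0 = a * 0.5
    b = tl.load(in1 + offs, mask=mask)
    tmp1 = b * 0.5
    tl.store(out + offs, tmp0 + tmp1, mask=mask)

A reordering pass would effectively rewrite the first form into something closer to the second.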

jansel (Contributor) left a review comment

Let's ship this. I created a ticket to look into instruction reordering.

@@ -278,11 +280,6 @@ def scatter_reduce(self, dim: int, index, src, reduction_type, **kwargs):
return self.clone().scatter_reduce_(dim, index, src, reduction_type, **kwargs)


@register_decomposition([aten.narrow])

Contributor

Did you mean to remove this? I don't see a replacement decomp.

Contributor Author

Yeah these are both decompositions for "CompositeImplicitAutograd" ops that are no longer needed now that we always decompose CompositeImplicitAutograd ops (i.e. Ed's change).

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L3749

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L5411
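
For reference, decomps like these are one-liners in terms of other aten ops. A sketch of what the removed narrow decomposition amounts to (not necessarily the exact removed code):

import torch
aten = torch.ops.aten

def narrow_decomp(x, dim, start, length):
    # narrow(x, dim, start, length) is just a slice over [start, start + length).
    return aten.slice(x, dim, start, start + length)

t = torch.arange(12).reshape(3, 4)
assert torch.equal(narrow_decomp(t, 1, 1, 2), t.narrow(1, 1, 2))

Since CompositeImplicitAutograd ops now get decomposed automatically during tracing, registering this by hand is redundant.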

Chillee merged commit e062891 into main on Aug 21, 2022