
[Inductor] [CPU] Vectorization not supporting python pass-in scalar double in speech_transformer #93446


Open
Valentine233 opened this issue Nov 23, 2022 · 0 comments
Assignees
Labels
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@Valentine233 (Collaborator)

Description

Comparing the performance of speech_transformer with the inductor and IPEX backends, inductor reaches only 0.68x of IPEX. The main reason is that vectorization does not support a Python pass-in scalar double.

Profiling and Code snippet

(profiling screenshot showing the kernel_cpp_* hotspots)

kernel_cpp_8 = async_compile.cpp('''
#include "/tmp/tmp8ofgbidl/rp/crpdeql3xwpfmcyakwtqpzihz525if6mt25mozau77xvmnh7vqyu.h"
extern "C" void kernel(float* __restrict__ in_out_ptr0,
                       const bool* __restrict__ in_ptr0,
                       const double* __restrict__ in_ptr2,
                       const float* __restrict__ in_ptr3,
                       float* __restrict__ out_ptr0,
                       float* __restrict__ out_ptr2,
                       float* __restrict__ out_ptr4)
{
    auto in_ptr1 = in_out_ptr0;
    auto out_ptr1 = in_out_ptr0;
    auto out_ptr3 = in_out_ptr0;
    #pragma omp parallel num_threads(28)
    {
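        // Loop nest 1: masked scale + per-row max for the softmax. The Python scalar
        // self.temperature arrives as in_ptr2 (double*) and is loaded and cast element
        // by element, so this nest stays scalar instead of using at::vec.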
        #pragma omp for 
        for(long i0=0; i0<80; i0+=1)
        {
            #pragma GCC ivdep
            for(long i1=0; i1<204; i1+=1)
            {
                {
                    {
                        float tmp7 = -std::numeric_limits<float>::infinity();
                        for(long i2=0; i2<204; i2+=1)
                        {
                            {
                                auto tmp0 = in_ptr0[i2 + (204*(i0 % 10))];
                                auto tmp2 = in_ptr1[i2 + (204*i1) + (41616*i0)];
                                auto tmp3 = in_ptr2[0];
                                auto tmp1 = -std::numeric_limits<float>::infinity();
                                auto tmp4 = static_cast<float>(tmp3);
                                auto tmp5 = tmp2 / tmp4;
                                auto tmp6 = tmp0 ? tmp1 : tmp5;
                                tmp7 = std::max(tmp7, tmp6);
                            }
                        }
                        out_ptr0[i1 + (204*i0)] = tmp7;
                    }
                }
            }
        }
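        // Loop nest 2: recompute the masked scaled value, subtract the row max and
        // exponentiate; scalar for the same reason (per-element division by in_ptr2[0]).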
        #pragma omp for 
        for(long i0=0; i0<80; i0+=1)
        {
            #pragma GCC ivdep
            for(long i1=0; i1<204; i1+=1)
            {
                #pragma GCC ivdep
                for(long i2=0; i2<204; i2+=1)
                {
                    {
                        {
                            auto tmp0 = in_ptr0[i2 + (204*(i0 % 10))];
                            auto tmp2 = in_ptr1[i2 + (204*i1) + (41616*i0)];
                            auto tmp3 = in_ptr2[0];
                            auto tmp7 = out_ptr0[i1 + (204*i0)];
                            auto tmp1 = -std::numeric_limits<float>::infinity();
                            auto tmp4 = static_cast<float>(tmp3);
                            auto tmp5 = tmp2 / tmp4;
                            auto tmp6 = tmp0 ? tmp1 : tmp5;
                            auto tmp8 = tmp6 - tmp7;
                            auto tmp9 = std::exp(tmp8);
                            out_ptr1[i2 + (204*i1) + (41616*i0)] = tmp9;
                        }
                    }
                }
            }
        }
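        // Loop nest 3: row-wise sum of the exp values; float-only, hence vectorized.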
        #pragma omp for 
        for(long i0=0; i0<16320; i0+=1)
        {
            {
                #pragma omp declare reduction(+:at::vec::Vectorized<float>:omp_out += omp_in) initializer(omp_priv={{0}})
                float tmp1 = 0;
                auto tmp1_vec = at::vec::Vectorized<float>(tmp1);
                for(long i1=0; i1<12; i1+=1)
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + (16*i1) + (204*i0));
                    tmp1_vec += tmp0;
                }
                tmp1 = at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>&y) {return x + y;}, tmp1_vec);
                #pragma omp simd simdlen(8)  reduction(+:tmp1)
                for(long i1=192; i1<204; i1+=1)
                {
                    auto tmp0 = out_ptr1[i1 + (204*i0)];
                    tmp1 += tmp0;
                }
                out_ptr2[i0] = tmp1;
            }
        }
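        // Loop nest 4: normalize by the row sum (softmax output); vectorized.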
        #pragma omp for 
        for(long i0=0; i0<16320; i0+=1)
        {
            for(long i1=0; i1<12; i1+=1)
            {
                auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr1 + (16*i1) + (204*i0));
                auto tmp1 = at::vec::Vectorized<float>(out_ptr2[i0]);
                auto tmp2 = tmp0 / tmp1;
                tmp2.store(out_ptr3 + (16*i1) + (204*i0));
            }
            #pragma omp simd simdlen(8) 
            for(long i1=192; i1<204; i1+=1)
            {
                auto tmp0 = out_ptr1[i1 + (204*i0)];
                auto tmp1 = out_ptr2[i0];
                auto tmp2 = tmp0 / tmp1;
                out_ptr3[i1 + (204*i0)] = tmp2;
            }
        }
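        // Loop nest 5: vectorized copy of in_ptr3 into out_ptr4 with permuted strides.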
        #pragma omp for  collapse(2)
        for(long i0=0; i0<8; i0+=1)
        {
            for(long i1=0; i1<2040; i1+=1)
            {
                for(long i2=0; i2<4; i2+=1)
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr3 + (16*i2) + (64*i0) + (512*i1));
                    tmp0.store(out_ptr4 + (16*i2) + (64*i1) + (130560*i0));
                }
                #pragma omp simd simdlen(8) 
                for(long i2=64; i2<64; i2+=1)
                {
                    auto tmp0 = in_ptr3[i2 + (64*i0) + (512*i1)];
                    out_ptr4[i2 + (64*i1) + (130560*i0)] = tmp0;
                }
            }
        }
    }
}
''')

According to the profiling analysis, the bottlenecks are kernel_cpp_8, kernel_cpp_14, kernel_cpp_20, kernel_cpp_2, kernel_cpp_32 and kernel_cpp_26, all of which implement the same Python code snippet:

attn = attn / self.temperature
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)

Since self.temperature in Python (the const double* __restrict__ in_ptr2 argument in the generated C++) is a double scalar, vectorization is not applied to the loop nests that use it.
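
For context, the snippet comes from the scaled dot-product attention in speech_transformer. Below is a minimal sketch of the pattern together with a driver that compiles it through inductor; the class name, shapes and dim=2 softmax are approximations rather than a copy of the torchbench source, and depending on the PyTorch build the float attribute may be lifted to a 0-dim double input (as in the kernel above) or specialized to a compile-time constant.

import numpy as np
import torch
import torch.nn as nn
import torch._dynamo as dynamo

class ScaledDotProductAttention(nn.Module):
    # Sketch of the attention block that lowers to kernel_cpp_8 above.
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature   # plain Python float, the scalar double in question
        self.softmax = nn.Softmax(dim=2)

    def forward(self, attn, mask):
        attn = attn / self.temperature          # scalar double division
        attn = attn.masked_fill(mask, -np.inf)  # bool mask, cf. in_ptr0
        attn = self.softmax(attn)               # max/exp/sum/div loop nests above
        return attn

# Illustrative shapes, roughly matching the loop bounds in the generated kernel.
mod = ScaledDotProductAttention(temperature=8.0)
compiled = dynamo.optimize("inductor")(mod)
attn = torch.randn(80, 204, 204)
mask = torch.zeros(80, 1, 204, dtype=torch.bool)
out = compiled(attn, mask)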

Minified repro

python benchmarks/dynamo/torchbench.py --performance --float32 -dcpu -n50 --inductor --no-skip --dashboard -k "speech_transformer" --cold_start_latency --channels-last

cc @EikanWang @jgong5

@SherlockNoMad added the triaged label on Nov 28, 2022
@malfet transferred this issue from pytorch/torchdynamo on Feb 1, 2023