Commit aeaeba4

[release/2.2] Fuse gate_proj and up_proj in MLP of LLaMA (#2469)
* Fuse gate_proj and up_proj in MLP of LLaMA (#2430)
* Fuse gate_proj and up_proj in MLP of LLaMA
* fix clang-format
* Update run_quantization.py (#2471)

---------

Co-authored-by: jianan-gu <[email protected]>
1 parent de99dd7 commit aeaeba4

7 files changed, +265 -8 lines


csrc/cpu/aten/TPPGEMM.cpp

Lines changed: 21 additions & 0 deletions
@@ -9,6 +9,7 @@ namespace cpu {
IPEX_DEFINE_DISPATCH(tpp_linear_nobias_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_bias_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_gelu_kernel_stub);
+IPEX_DEFINE_DISPATCH(tpp_fused_gate_up_proj_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_silu_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_relu_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_add_kernel_stub);

@@ -38,6 +39,17 @@ at::Tensor tpp_linear_gelu_forward_cpu(
  return tpp_linear_gelu_kernel_stub(kCPU, t_in, t_wt, t_bias);
}

+at::Tensor tpp_fused_gate_up_proj_forward_cpu(
+    const at::Tensor& t_in,
+    const at::Tensor& t_wt_gate,
+    const at::Tensor& t_bias_gate,
+    const at::Tensor& t_wt_up,
+    const at::Tensor& t_bias_up,
+    c10::optional<int64_t> out_features) {
+  return tpp_fused_gate_up_proj_kernel_stub(
+      kCPU, t_in, t_wt_gate, t_bias_gate, t_wt_up, t_bias_up);
+}
+
at::Tensor tpp_linear_silu_forward_cpu(
    const at::Tensor& t_in,
    const at::Tensor& t_wt,

@@ -117,6 +129,15 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
      torch_ipex::cpu::tpp_linear_gelu_forward_cpu);
}

+TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
+  m.def(
+      "tpp_fused_gate_up_proj(Tensor t_in, Tensor t_wt_gate, Tensor t_bias_gate, Tensor t_wt_up, Tensor t_bias_up,int? out_features=None)-> Tensor out");
+  m.impl(
+      "tpp_fused_gate_up_proj",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::tpp_fused_gate_up_proj_forward_cpu);
+}
+
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
  m.def(
      "tpp_linear_add_add(Tensor t_in, Tensor t_in1, Tensor t_in2, Tensor t_wt, Tensor t_bias, float scale, int? out_features=None)-> Tensor out");

csrc/cpu/aten/TPPGEMM.h

Lines changed: 18 additions & 0 deletions
@@ -24,6 +24,14 @@ at::Tensor tpp_linear_gelu_forward_cpu(
    const at::Tensor& t_bias,
    c10::optional<int64_t> out_features);

+at::Tensor tpp_fused_gate_up_proj_forward_cpu(
+    const at::Tensor& t_in,
+    const at::Tensor& t_wt_gate,
+    const at::Tensor& t_bias_gate,
+    const at::Tensor& t_wt_up,
+    const at::Tensor& t_bias_up,
+    c10::optional<int64_t> out_features);
+
at::Tensor tpp_linear_silu_forward_cpu(
    const at::Tensor& t_in,
    const at::Tensor& t_wt,

@@ -71,6 +79,13 @@ using tpp_linear_bias_kernel_impl_fn =
using tpp_linear_gelu_kernel_impl_fn =
    at::Tensor (*)(const at::Tensor&, const at::Tensor&, const at::Tensor&);

+using tpp_fused_gate_up_proj_kernel_impl_fn = at::Tensor (*)(
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&);
+
using tpp_linear_silu_kernel_impl_fn =
    at::Tensor (*)(const at::Tensor&, const at::Tensor&, const at::Tensor&);

@@ -105,6 +120,9 @@ IPEX_DECLARE_DISPATCH(
IPEX_DECLARE_DISPATCH(
    tpp_linear_gelu_kernel_impl_fn,
    tpp_linear_gelu_kernel_stub);
+IPEX_DECLARE_DISPATCH(
+    tpp_fused_gate_up_proj_kernel_impl_fn,
+    tpp_fused_gate_up_proj_kernel_stub);
IPEX_DECLARE_DISPATCH(
    tpp_linear_silu_kernel_impl_fn,
    tpp_linear_silu_kernel_stub);

csrc/cpu/aten/kernels/TPPGEMMKrnl.cpp

Lines changed: 35 additions & 0 deletions
@@ -87,6 +87,38 @@ at::Tensor tpp_linear_gelu_kernel_impl(
  return t_out;
}

+at::Tensor tpp_fused_gate_up_proj_kernel_impl(
+    const at::Tensor& t_in,
+    const at::Tensor& t_wt_gate,
+    const at::Tensor& t_bias_gate,
+    const at::Tensor& t_wt_up,
+    const at::Tensor& t_bias_up) {
+  auto sizes = t_in.sizes().vec();
+  AT_ASSERT(
+      t_wt_gate.sizes() == t_wt_up.sizes(),
+      "Expect t_wt_gate.sizes() == t_wt_up.sizes()");
+  auto wt_sizes = t_wt_gate.sizes();
+  sizes[2] = wt_sizes[0] * wt_sizes[3];
+
+  auto t_out = t_in.new_empty(sizes);
+
+  auto dt = t_wt_gate.dtype();
+  if (dt == at::kFloat) {
+    torch_ipex::tpp::tpp_fused_gate_up_proj<float>(
+        t_in, t_wt_gate, t_bias_gate, t_wt_up, t_bias_up, t_out);
+  } else if (dt == at::kBFloat16) {
+    torch_ipex::tpp::tpp_fused_gate_up_proj<at::BFloat16>(
+        t_in, t_wt_gate, t_bias_gate, t_wt_up, t_bias_up, t_out);
+  } else {
+    AT_ASSERT(
+        0,
+        "TPP does not support current weight dtype %s:%d\n",
+        __FILE__,
+        __LINE__);
+  }
+  return t_out;
+}
+
at::Tensor tpp_linear_silu_kernel_impl(
    const at::Tensor& t_in,
    const at::Tensor& t_wt,

@@ -219,6 +251,9 @@ IPEX_REGISTER_DISPATCH(
IPEX_REGISTER_DISPATCH(
    tpp_linear_gelu_kernel_stub,
    &tpp_linear_gelu_kernel_impl);
+IPEX_REGISTER_DISPATCH(
+    tpp_fused_gate_up_proj_kernel_stub,
+    &tpp_fused_gate_up_proj_kernel_impl);
IPEX_REGISTER_DISPATCH(
    tpp_linear_relu_kernel_stub,
    &tpp_linear_relu_kernel_impl);

csrc/cpu/tpp/kernels/TPPGEMMKrnl.h

Lines changed: 138 additions & 0 deletions
@@ -42,6 +42,9 @@ REGISTER_LOCAL_SCOPE(
REGISTER_LOCAL_SCOPE(
    tpp_linear_silu_krnl,
    "tpp_linear_silu_krnl"); // linear bias + silu
+REGISTER_LOCAL_SCOPE(
+    tpp_fused_gate_up_proj_krnl,
+    "tpp_fused_gate_up_proj_krnl"); // fused gate_proj and up_proj
REGISTER_LOCAL_SCOPE(
    tpp_linear_relu_krnl,
    "tpp_linear_relu_krnl"); // linear bias + relu

@@ -521,6 +524,141 @@ inline void tpp_linear_gelu(
  }
}

+// Fused kernel for the gate_proj and the up_proj related computation in the
+// MLP of LLaMA. The ref computation of the kernel is:
+// act_fn(gate_proj(x)) * up_proj(x) where act_fn is silu, gate_proj and
+// up_proj are two nn.Linear with the same weight shapes and bias = False.
+// t_in is the input activation
+// t_wt_gate is the prepacked weight of the gate_proj
+// t_wt_up is the prepacked weight of the up_proj
+// t_bias_gate is the bias of the gate_proj
+// t_bias_up is the bias of the up_proj
+// t_out is the output result of the kernel
+template <typename T>
+inline void tpp_fused_gate_up_proj(
+    const at::Tensor& t_in,
+    const at::Tensor& t_wt_gate,
+    const at::Tensor& t_bias_gate,
+    const at::Tensor& t_wt_up,
+    const at::Tensor& t_bias_up,
+    at::Tensor& t_out) {
+  auto t_wt_gate_ = t_wt_gate;
+  auto t_wt_up_ = t_wt_up;
+  auto in_sizes = t_in.sizes();
+  auto BS = in_sizes[0] * in_sizes[1];
+  if (BS > FT_OPT_SIZE) { // first token compute
+    t_wt_gate_ = wt_tensor_for_first_token<T>(t_wt_gate_);
+    t_wt_up_ = wt_tensor_for_first_token<T>(t_wt_up_);
+    large_cache_opt = true;
+  }
+
+  auto wt_sizes = t_wt_gate_.sizes();
+  auto C = in_sizes[2];
+
+  auto Nc = wt_sizes[1];
+  auto Hc = C / Nc;
+  auto Nk = wt_sizes[0];
+  auto Hk = wt_sizes[3];
+  auto K = Nk * Hk;
+
+  auto t_wt_gate_V =
+      torch_ipex::tpp::wt_tensor_for_fwd(Nk, Hk, Nc, Hc, t_wt_gate_);
+  auto t_wt_up_V = torch_ipex::tpp::wt_tensor_for_fwd(Nk, Hk, Nc, Hc, t_wt_up_);
+
+  // This is used to store the intermediate result of the up_proj layer
+  auto t_out_tmp = at::empty_like(t_out);
+
+  auto in = GetVLAPtr<T>(t_in, {Nc, Hc});
+  auto wt_gate_V = GetVLAPtr<T>(t_wt_gate_V, {Nc, Hc * Hk});
+  auto wt_up_V = GetVLAPtr<T>(t_wt_up_V, {Nc, Hc * Hk});
+  auto bias_gate = GetVLAPtr<T>(t_bias_gate, {Hk});
+  auto bias_up = GetVLAPtr<T>(t_bias_up, {Hk});
+  auto out = GetVLAPtr<T>(t_out, {Nk, Hk});
+  auto out_tmp = GetVLAPtr<T>(t_out_tmp, {Nk, Hk});
+
+  auto Ncb = Nc;
+  auto BSb = 64L;
+  auto rem = BS % 64;
+  if (large_cache_opt)
+    Ncb = NCB_BLOCK_SIZE;
+
+  bool with_bias_gate = (t_bias_gate.numel() > 0);
+  bool with_bias_up = (t_bias_up.numel() > 0);
+  auto copy_bias_tpp = SCOPEIT(CpyBiasTPP<T>(BSb, Hk, K), BIAS);
+  auto copy_bias_tpp_rem = SCOPEIT(CpyBiasTPP<T>(rem, Hk, K), BIAS);
+  auto zero_tpp = SCOPEIT(SetZeroTPP<T>(BSb, Hk, K), EW_ZERO);
+  auto zero_tpp_rem = SCOPEIT(SetZeroTPP<T>(rem, Hk, K), EW_ZERO);
+  auto brgemm_tpp = SCOPEITGEMM(
+      (BrgemmTPP<T, T>(BSb, Hk, Hc, Hc, Hk * Hc, C, Hk, K, 1.0, 0, Ncb)));
+  auto brgemm_tpp_rem = SCOPEITGEMM(
+      (BrgemmTPP<T, T>(rem, Hk, Hc, Hc, Hk * Hc, C, Hk, K, 1.0, 0, Ncb)));
+  auto silu_fwd_tpp = SCOPEIT(SiLUFwdTPP<T>(BSb, Hk, K, K), ACT);
+  auto silu_fwd_tpp_rem = SCOPEIT(SiLUFwdTPP<T>(rem, Hk, K, K), ACT);
+  auto mul_tpp = SCOPEIT((MulTPP<T, T>(BSb, Hk, K, K)), EW_MUL);
+  auto mul_tpp_rem = SCOPEIT((MulTPP<T, T>(rem, Hk, K, K)), EW_MUL);
+
+  {
+    RECORD_SCOPE(tpp_fused_gate_up_proj_krnl, {t_in, t_wt_gate_V});
+
+    auto loop_scheme = large_cache_opt ? GEMM_LOOP_SCHEME : "aCb";
+    auto igemm_loop = torch_ipex::tpp::ThreadedLoop<3>(
+        {{0, Nc, Ncb, false}, {0, BS, BSb}, {Nk}}, loop_scheme);
+    igemm_loop(
+        [&](int* ind) {
+          int nc = ind[0], s1 = ind[1], nk = ind[2];
+          auto count = nc + Ncb < Nc ? Ncb : Nc - nc;
+          bool is_rem = (s1 + BSb > BS);
+          if (!is_rem) {
+            if (nc == 0) {
+              if (with_bias_gate) {
+                copy_bias_tpp(bias_gate[nk], out[s1][nk]);
+              } else {
+                zero_tpp(out[s1][nk]);
+              }
+
+              if (with_bias_up) {
+                copy_bias_tpp(bias_up[nk], out_tmp[s1][nk]);
+              } else {
+                zero_tpp(out_tmp[s1][nk]);
+              }
+            }
+            brgemm_tpp(in[s1][nc], wt_gate_V[nk][nc], out[s1][nk], count, true);
+            brgemm_tpp(
+                in[s1][nc], wt_up_V[nk][nc], out_tmp[s1][nk], count, true);
+            if (!(nc + Ncb < Nc)) { // last nc iter
+              silu_fwd_tpp(out[s1][nk], out[s1][nk]);
+              mul_tpp(out[s1][nk], out_tmp[s1][nk], out[s1][nk]);
+            }
+          } else {
+            if (nc == 0) {
+              if (with_bias_gate) {
+                copy_bias_tpp_rem(bias_gate[nk], out[s1][nk]);
+              } else {
+                zero_tpp_rem(out[s1][nk]);
+              }
+
+              if (with_bias_up) {
+                copy_bias_tpp_rem(bias_up[nk], out_tmp[s1][nk]);
+              } else {
+                zero_tpp_rem(out_tmp[s1][nk]);
+              }
+            }
+            brgemm_tpp_rem(
+                in[s1][nc], wt_gate_V[nk][nc], out[s1][nk], count, false);
+            brgemm_tpp_rem(
+                in[s1][nc], wt_up_V[nk][nc], out_tmp[s1][nk], count, false);
+            brgemm_tpp.config();
+            if (!(nc + Ncb < Nc)) { // last nc iter
+              silu_fwd_tpp_rem(out[s1][nk], out[s1][nk]);
+              mul_tpp_rem(out[s1][nk], out_tmp[s1][nk], out[s1][nk]);
+            }
+          }
+        },
+        [&]() { brgemm_tpp.config(); },
+        [&]() { brgemm_tpp.release(); });
+  }
+}
+
template <typename T>
inline void tpp_linear_add(
    const at::Tensor t_in,
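
The fusion in the loop body above amounts to this: each input block in[s1][nc] feeds two BRGEMMs (gate and up), and SiLU plus the elementwise multiply run once all input-channel blocks have been accumulated. Below is a rough, single-threaded Python sketch of that accumulation pattern on plain 2-D weights (no prepacking, no TPP); it illustrates the blocking idea, not the implementation.

import torch

def fused_gate_up_sketch(x, w_gate, w_up, ncb=16):
    # x: [BS, C]; w_gate, w_up: [K, C]; ncb plays the role of the Ncb block size.
    BS, C = x.shape
    K = w_gate.shape[0]
    acc_gate = torch.zeros(BS, K, dtype=x.dtype)
    acc_up = torch.zeros(BS, K, dtype=x.dtype)
    for c0 in range(0, C, ncb):                  # loop over input-channel blocks
        c1 = min(c0 + ncb, C)
        xb = x[:, c0:c1]                         # the shared input block
        acc_gate += xb @ w_gate[:, c0:c1].T      # gate_proj partial GEMM
        acc_up += xb @ w_up[:, c0:c1].T          # up_proj partial GEMM
    # The real kernel fuses this epilogue into the last input-channel iteration.
    return torch.nn.functional.silu(acc_gate) * acc_up

x = torch.randn(8, 64)
w_gate, w_up = torch.randn(32, 64), torch.randn(32, 64)
torch.testing.assert_close(
    fused_gate_up_sketch(x, w_gate, w_up),
    torch.nn.functional.silu(x @ w_gate.T) * (x @ w_up.T),
    rtol=1e-4,
    atol=1e-4,
)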

examples/cpu/inference/python/llm/single_instance/run_quantization.py

Lines changed: 2 additions & 1 deletion
@@ -536,10 +536,11 @@ def calib_func(prepared_model):
            op_type_dict=op_type_dict,
            smoothquant_args=smoothquant_args
        )
+        pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        prepared_model.save_qconf_summary(args.output_dir + "/best_configure.json")

    else:
-        qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=args.alpha)
+        qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=float(args.alpha))
        user_model = ipex.llm.optimize(
            user_model.eval(),
            dtype=amp_dtype,
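
For context on the float(args.alpha) change, here is a minimal illustration, assuming --alpha is parsed as a string (for example, so the flag can also accept a value like "auto"), so it must be cast to a number before being handed to get_smooth_quant_qconfig_mapping:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--alpha", default="0.5")   # no type=float, so the value arrives as str
args = parser.parse_args(["--alpha", "0.8"])
alpha = float(args.alpha)                        # mirrors the cast added in this commit
print(type(args.alpha).__name__, type(alpha).__name__)  # str float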

intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py

Lines changed: 1 addition & 7 deletions
@@ -379,22 +379,16 @@ def forward(self, x):
            and not self.linear_m.tpp_fallback
        ):
            x = x.to(self.dtype).contiguous()
-           x1 = torch.ops.torch_ipex.tpp_linear_silu(
+           return torch.ops.torch_ipex.tpp_fused_gate_up_proj(
                x,
                self.linear_s.weight.detach(),
                self.linear_s.bias.detach()
                if self.linear_s.bias is not None
                else x.new_empty(0),
-               self.linear_s.out_features,
-           )
-           return torch.ops.torch_ipex.tpp_linear_mul(
-               x,
-               x1,
                self.linear_m.weight.detach(),
                self.linear_m.bias.detach()
                if self.linear_m.bias is not None
                else x.new_empty(0),
-               self.linear_m.out_features,
            )
        else:  # fallback path
            return nn.functional.silu(self.linear_s(x)) * self.linear_m(x)
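
A hedged end-to-end sketch of how this fused path is normally reached on CPU: ipex.llm.optimize rewrites the LLaMA MLP into this fusion module, whose forward now issues a single tpp_fused_gate_up_proj call instead of tpp_linear_silu followed by tpp_linear_mul. The checkpoint name and generation arguments below are illustrative only, and it is assumed the TPP backend is selected for the bf16 path.

import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # assumption: any LLaMA-family checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).eval()

# Rewrites the model with IPEX's fused CPU modules (including the MLP fusion above).
model = ipex.llm.optimize(model, dtype=torch.bfloat16)

with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
    ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
    print(tokenizer.decode(model.generate(ids, max_new_tokens=8)[0]))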

tests/cpu/test_tpp_linear.py

Lines changed: 50 additions & 0 deletions
@@ -46,6 +46,16 @@ def forward(self, x):
        return torch.nn.functional.silu(self.mlp(x))


+class Linear_Gate_Up(torch.nn.Module):
+    def __init__(self, in_feature, out_feature, bias_gate, bias_up):
+        super(Linear_Gate_Up, self).__init__()
+        self.gate_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_gate)
+        self.up_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_up)
+
+    def forward(self, x):
+        return torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)
+
+
class Linear_relu(torch.nn.Module):
    def __init__(self):
        super(Linear_relu, self).__init__()

@@ -172,6 +182,46 @@ def test_tpp_linear_torchcompile(self):
            self.assertTrue(out.dtype == dtype)
        _disable_tpp()

+    def test_tpp_fused_gate_up_proj(self):
+        in_feature = 64
+        out_feature = 32
+
+        x = torch.randn(1, 4, in_feature)
+        x_tpp = copy.deepcopy(x)
+
+        with torch.no_grad():
+            for dtype, bias_gate, bias_up in itertools.product(
+                [torch.float, torch.bfloat16], [False, True], [False, True]
+            ):
+                model = Linear_Gate_Up(
+                    in_feature, out_feature, bias_gate, bias_up
+                ).eval()
+                if dtype == torch.bfloat16:
+                    x = x.to(torch.bfloat16)
+                    x_tpp = x_tpp.to(torch.bfloat16)
+                    model = model.to(torch.bfloat16)
+                ref_out = model(x)
+
+                _enable_tpp()
+                model = ipex.optimize(model, dtype=dtype)
+                out = torch.ops.torch_ipex.tpp_fused_gate_up_proj(
+                    x_tpp,
+                    model.gate_proj.weight,
+                    model.gate_proj.bias,
+                    model.up_proj.weight,
+                    model.up_proj.bias,
+                )
+
+                out_linear_silu = torch.ops.torch_ipex.tpp_linear_silu(
+                    x_tpp, model.gate_proj.weight, model.gate_proj.bias
+                )
+                out_tpp_ref = torch.ops.torch_ipex.tpp_linear_mul(
+                    x_tpp, out_linear_silu, model.up_proj.weight, model.up_proj.bias
+                )
+                self.assertEqual(out, out_tpp_ref)
+                self.assertEqual(out, ref_out)
+        _disable_tpp()
+
    def test_tpp_linear_gelu(self):
        x1 = torch.rand(1, 4, 4096)
        x2 = copy.deepcopy(x1)