
Commit 0f27c26

liangan1, XiaobingSuper, and EikanWang authored
concat linear (#278)
* Enable concat linear with same input
* Add UT for concat linear
* Remove unused code
* Add UT for concat linear
* Remove unused code
* Fix clang-format issue
* Fix clang-format issue
* Add BF16 UT for concat linear
* Remove print code in UT
* Remove the "torch_ipex::ipex_linear" pass from the concat linear pass:
  1) torch_ipex::ipex_linear may use block format weight; if so, concatenating weights along dim=0 may crash.
  2) Please use level="O0" for ipex.optimize if you want to enable concat linear, since "O1" pre-packs linear weights into block format.
* Fix linear+add UT failure
* Add concat linear path for ipex linear
* Remove debug info
* Keep linear ops type after concat
* Fix clang-format issue
* Fix code format issue

Co-authored-by: XiaobingSuper <[email protected]>
Co-authored-by: Wang Weihan <[email protected]>
1 parent f63ed97 commit 0f27c26
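For orientation, here is a minimal Python sketch of the rewrite this commit implements (an illustration only, not the C++ graph pass itself; the helper name `concat_linear` is hypothetical). Two linears that consume the same input are fused into one GEMM by concatenating weights and biases along dim=0, then splitting the output back apart. Plain (non-prepacked) weights are assumed, which is why the pass skips block-format torch_ipex::ipex_linear weights:

```python
import torch
import torch.nn as nn

def concat_linear(l1: nn.Linear, l2: nn.Linear, x: torch.Tensor):
    # Stack the two weight matrices row-wise: (out1, in) + (out2, in) -> (out1+out2, in).
    weight = torch.cat([l1.weight, l2.weight], dim=0)
    bias = torch.cat([l1.bias, l2.bias], dim=0)
    # One GEMM instead of two.
    out = nn.functional.linear(x, weight, bias)
    # Split the fused output back into the two original results.
    return out.split([l1.out_features, l2.out_features], dim=-1)

l1, l2 = nn.Linear(5, 50), nn.Linear(5, 60)
x = torch.rand(8, 5)
y1, y2 = concat_linear(l1, l2, x)
assert torch.allclose(y1, l1(x), atol=1e-5)
assert torch.allclose(y2, l2(x), atol=1e-5)
```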

6 files changed (+423 / -9 lines)

intel_extension_for_pytorch/frontend.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -237,4 +237,4 @@ def enable_onednn_fusion(enabled):
     if enabled:
         core.enable_jit_opt()
     else:
-        core.disable_jit_opt()
+        core.disable_jit_opt()
```

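The function touched by this hunk is the public toggle for oneDNN JIT fusion; a minimal usage sketch (the placement of trace/inspect calls is illustrative):

```python
import torch
import intel_extension_for_pytorch as ipex

# Turn oneDNN fusion off to inspect the unfused TorchScript graph,
# then turn it back on for optimized runs.
ipex.enable_onednn_fusion(False)
# ... torch.jit.trace(...) and graph inspection here ...
ipex.enable_onednn_fusion(True)
```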
tests/cpu/test_jit.py

Lines changed: 75 additions & 6 deletions
```diff
@@ -289,7 +289,8 @@ def __init__(self, in_channels, out_channels, **kwargs):
         self.linear1 = nn.Linear(in_channels, out_channels, **kwargs)
 
     def forward(self, x):
-        return torch.add(self.linear(x),self.linear1(x))
+        x1 = x.clone()
+        return torch.add(self.linear(x),self.linear1(x1))
 
 class Linear_Reshape_Relu(nn.Module):
     def __init__(self, in_channels, out_channels,dest_shape, **kwargs):
@@ -523,6 +524,21 @@ def forward(self, x, y, z):
         x = x + y + z
         return self.layernorm(x)
 
+class ModMultLinear(nn.Module):
+    def __init__(self, w1_dim, w2_dim):
+        super(ModMultLinear, self).__init__()
+        self.linear1 = nn.Linear(5, w1_dim)
+        self.linear2 = nn.Linear(5, w2_dim)
+        self.linear3 = nn.Linear(w1_dim, 5)
+        self.linear4 = nn.Linear(w1_dim, 5)
+
+    def forward(self, x):
+        res1 = self.linear1(x)
+        res2 = self.linear2(x)
+        res3 = self.linear3(res1)
+        res4 = self.linear4(res1)
+        return res1, res2, res3, res4
+
 class Tester(TestCase):
 
     def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None, levels=['O0','O1']):
@@ -559,7 +575,6 @@ def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None, lev
         trace_graph = trace_fused_model.graph_for(x)
         fused_tresult = trace_fused_model(x)
         self.assertEqual(result, fused_tresult)
-
         # check if the fused node exists in the graph
         if kind_in_graph is not None:
             self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
@@ -632,7 +647,62 @@ def test_jit_freeze(self):
         self.assertTrue(all(n.kind() != node for n in freeze_graph.nodes()))
         # prepack op need note in none freeze model
         self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))
-
+
+    def test_concat_linear(self):
+        def check_op_count(graph_str, op_names=[]):
+            count = 0
+            node_list = graph_str.strip().split("\n")
+            for node in node_list:
+                for op_name in op_names:
+                    if op_name in node:
+                        count += 1
+            return count
+        origin_model = ModMultLinear(50, 60).eval()
+
+        test_val1 = torch.rand([50, 5])
+        # call mkl path(fp32)
+        model = ipex.optimize(origin_model, dtype=torch.float32)
+        ori_res = model(test_val1)
+        model_jit = torch.jit.trace(model,(test_val1))
+        graph_ori = str(model_jit.graph_for(test_val1))
+        linear_count_ori = check_op_count(graph_ori, ["aten::linear"])
+        self.assertEqual(linear_count_ori, 4)
+        model_jit = torch.jit.freeze(model_jit)
+        jit_res = model_jit(test_val1)
+        self.assertEqual(ori_res, jit_res)
+        graph_opt = str(model_jit.graph_for(test_val1))
+        linear_count_ori = check_op_count(graph_opt, ["aten::linear"])
+        self.assertEqual(linear_count_ori, 2)
+        # call onednn path(fp32)
+        model = ipex.optimize(origin_model, dtype=torch.float32, auto_kernel_selection=True)
+        ori_res = model(test_val1)
+        model_jit = torch.jit.trace(model,(test_val1))
+        graph_ori = str(model_jit.graph_for(test_val1))
+        linear_count_ori = check_op_count(graph_ori, ["ipex_prepack::linear_run"])
+        self.assertEqual(linear_count_ori, 4)
+        model_jit = torch.jit.freeze(model_jit)
+        jit_res = model_jit(test_val1)
+        self.assertEqual(ori_res, jit_res)
+        graph_opt = str(model_jit.graph_for(test_val1))
+        linear_count_ori = check_op_count(graph_opt, ["ipex_prepack::linear_run"])
+        self.assertEqual(linear_count_ori, 2)
+
+        model = ipex.optimize(origin_model, dtype=torch.bfloat16)
+        test_val1 = test_val1.bfloat16()
+        with torch.cpu.amp.autocast(), torch.no_grad():
+            ori_res = model(test_val1)
+            model_jit = torch.jit.trace(model,(test_val1))
+            graph_ori = str(model_jit.graph_for(test_val1))
+            linear_count_ori = check_op_count(graph_ori, ["ipex_prepack::linear_run"])
+            self.assertEqual(linear_count_ori, 4)
+            model_jit = torch.jit.freeze(model_jit)
+            model_jit(test_val1)
+            graph_opt = str(model_jit.graph_for(test_val1))
+            jit_res = model_jit(test_val1)
+            self.assertEqual(ori_res[1], jit_res[1])
+            linear_count_ori = check_op_count(graph_opt, ["ipex_prepack::linear_run"])
+            self.assertEqual(linear_count_ori, 2)
+
     def test_add_layernorm(self):
         bs = 56
         seq_len = 384
@@ -647,7 +717,7 @@ def test_add_layernorm(self):
         self.assertEqual(jit_res, ori_res)
         node = "ipex::add_layernorm"
         self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))
-
+
         a_bf16 = a.to(torch.bfloat16)
         b_bf16 = b.to(torch.bfloat16)
         with torch.cpu.amp.autocast():
@@ -669,7 +739,6 @@ def test_add_layernorm(self):
         node = "ipex::add_layernorm"
         self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))
 
-
     def test_mha_scores_calculation(self):
         def _test_pure_bf16(model, trace_model, mat1, mat2, bias, prec=3e-2):
             mat1_bf16 = mat1.to(torch.bfloat16)
@@ -1027,7 +1096,7 @@ def _deconv_params_list():
             "output_padding": [0], # TODO: fix output_padding >1.
             "groups": [1, 2],
             "dilation": [1, 2],
-        }
+        }
 
         params_list = []
 
```
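For reference, a condensed standalone version of the flow the new test exercises (assuming an IPEX build that includes this commit; `ModMultLinear` is the module defined in the diff above):

```python
import torch
import intel_extension_for_pytorch as ipex

model = ModMultLinear(50, 60).eval()  # four linears, two sharing each input
x = torch.rand(50, 5)

model = ipex.optimize(model, dtype=torch.float32)
with torch.no_grad():
    traced = torch.jit.trace(model, (x,))
    frozen = torch.jit.freeze(traced)
    frozen(x)  # warm up so the JIT runs its optimization passes
    # After the concat linear pass, the graph should contain 2 linear
    # nodes instead of 4.
    print(frozen.graph_for(x))
```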
torch_ipex/csrc/cpu/Linear.cpp

Lines changed: 3 additions & 1 deletion
```diff
@@ -237,7 +237,9 @@ at::Tensor ipex_linear(
 namespace {
 
 TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
-  m.def("ipex_linear(Tensor input, Tensor weight, int out_features, int in_features, Tensor? bias_opt) -> Tensor", torch_ipex::cpu::ipex_linear);
+  m.def(
+      "ipex_linear(Tensor input, Tensor weight, int out_features, int in_features, Tensor? bias) -> Tensor",
+      torch_ipex::cpu::ipex_linear);
 }
 
 }
```