@@ -289,7 +289,8 @@ def __init__(self, in_channels, out_channels, **kwargs):
        self.linear1 = nn.Linear(in_channels, out_channels, **kwargs)

    def forward(self, x):
-        return torch.add(self.linear(x), self.linear1(x))
+        x1 = x.clone()
+        return torch.add(self.linear(x), self.linear1(x1))

class Linear_Reshape_Relu(nn.Module):
    def __init__(self, in_channels, out_channels, dest_shape, **kwargs):
@@ -523,6 +524,21 @@ def forward(self, x, y, z):
        x = x + y + z
        return self.layernorm(x)

+class ModMultLinear(nn.Module):
+    def __init__(self, w1_dim, w2_dim):
+        super(ModMultLinear, self).__init__()
+        self.linear1 = nn.Linear(5, w1_dim)
+        self.linear2 = nn.Linear(5, w2_dim)
+        self.linear3 = nn.Linear(w1_dim, 5)
+        self.linear4 = nn.Linear(w1_dim, 5)
+
+    def forward(self, x):
+        res1 = self.linear1(x)
+        res2 = self.linear2(x)
+        res3 = self.linear3(res1)
+        res4 = self.linear4(res1)
+        return res1, res2, res3, res4
+
class Tester(TestCase):

    def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None, levels=['O0', 'O1']):
@@ -559,7 +575,6 @@ def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None, lev
        trace_graph = trace_fused_model.graph_for(x)
        fused_tresult = trace_fused_model(x)
        self.assertEqual(result, fused_tresult)
-
        # check if the fused node exists in the graph
        if kind_in_graph is not None:
            self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
@@ -632,7 +647,62 @@ def test_jit_freeze(self):
        self.assertTrue(all(n.kind() != node for n in freeze_graph.nodes()))
        # the prepack op should still be present in the non-frozen model
        self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))
-
+
+    def test_concat_linear(self):
+        def check_op_count(graph_str, op_names=[]):
+            count = 0
+            node_list = graph_str.strip().split("\n")
+            for node in node_list:
+                for op_name in op_names:
+                    if op_name in node:
+                        count += 1
+            return count
+        origin_model = ModMultLinear(50, 60).eval()
+
+        test_val1 = torch.rand([50, 5])
+        # call the MKL path (fp32)
+        model = ipex.optimize(origin_model, dtype=torch.float32)
+        ori_res = model(test_val1)
+        model_jit = torch.jit.trace(model, (test_val1))
+        graph_ori = str(model_jit.graph_for(test_val1))
+        linear_count_ori = check_op_count(graph_ori, ["aten::linear"])
+        self.assertEqual(linear_count_ori, 4)
+        model_jit = torch.jit.freeze(model_jit)
+        jit_res = model_jit(test_val1)
+        self.assertEqual(ori_res, jit_res)
+        graph_opt = str(model_jit.graph_for(test_val1))
+        linear_count_ori = check_op_count(graph_opt, ["aten::linear"])
+        self.assertEqual(linear_count_ori, 2)
+        # call the oneDNN path (fp32)
+        model = ipex.optimize(origin_model, dtype=torch.float32, auto_kernel_selection=True)
+        ori_res = model(test_val1)
+        model_jit = torch.jit.trace(model, (test_val1))
+        graph_ori = str(model_jit.graph_for(test_val1))
+        linear_count_ori = check_op_count(graph_ori, ["ipex_prepack::linear_run"])
+        self.assertEqual(linear_count_ori, 4)
+        model_jit = torch.jit.freeze(model_jit)
+        jit_res = model_jit(test_val1)
+        self.assertEqual(ori_res, jit_res)
+        graph_opt = str(model_jit.graph_for(test_val1))
+        linear_count_ori = check_op_count(graph_opt, ["ipex_prepack::linear_run"])
+        self.assertEqual(linear_count_ori, 2)
+
+        model = ipex.optimize(origin_model, dtype=torch.bfloat16)
+        test_val1 = test_val1.bfloat16()
+        with torch.cpu.amp.autocast(), torch.no_grad():
+            ori_res = model(test_val1)
+            model_jit = torch.jit.trace(model, (test_val1))
+            graph_ori = str(model_jit.graph_for(test_val1))
+            linear_count_ori = check_op_count(graph_ori, ["ipex_prepack::linear_run"])
+            self.assertEqual(linear_count_ori, 4)
+            model_jit = torch.jit.freeze(model_jit)
+            model_jit(test_val1)
+            graph_opt = str(model_jit.graph_for(test_val1))
+            jit_res = model_jit(test_val1)
+            self.assertEqual(ori_res[1], jit_res[1])
+            linear_count_ori = check_op_count(graph_opt, ["ipex_prepack::linear_run"])
+            self.assertEqual(linear_count_ori, 2)
+
    def test_add_layernorm(self):
        bs = 56
        seq_len = 384
@@ -647,7 +717,7 @@ def test_add_layernorm(self):
        self.assertEqual(jit_res, ori_res)
        node = "ipex::add_layernorm"
        self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))
-
+
        a_bf16 = a.to(torch.bfloat16)
        b_bf16 = b.to(torch.bfloat16)
        with torch.cpu.amp.autocast():
@@ -669,7 +739,6 @@ def test_add_layernorm(self):
            node = "ipex::add_layernorm"
            self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))

-
    def test_mha_scores_calculation(self):
        def _test_pure_bf16(model, trace_model, mat1, mat2, bias, prec=3e-2):
            mat1_bf16 = mat1.to(torch.bfloat16)
@@ -1027,7 +1096,7 @@ def _deconv_params_list():
        "output_padding": [0],  # TODO: fix output_padding > 1.
        "groups": [1, 2],
        "dilation": [1, 2],
-    }
+    }

    params_list = []
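
The new test_concat_linear case expects the four linear ops (aten::linear on the MKL path, ipex_prepack::linear_run on the oneDNN path) to collapse to two after torch.jit.freeze. For context, here is a minimal standalone sketch of the rewrite that fusion relies on; it is not part of the patch, and the 5/50/60 dimensions are simply the ones ModMultLinear uses. Two Linear layers fed by the same input are equivalent to one wider Linear whose output is split afterwards:

import torch
import torch.nn as nn

# Two independent Linear layers that consume the same input...
linear1 = nn.Linear(5, 50)
linear2 = nn.Linear(5, 60)

# ...can be folded into a single wider Linear by concatenating their
# weights and biases along the output dimension. (Illustration only;
# the JIT pass performs the equivalent rewrite on the frozen graph.)
merged = nn.Linear(5, 50 + 60)
with torch.no_grad():
    merged.weight.copy_(torch.cat([linear1.weight, linear2.weight], dim=0))
    merged.bias.copy_(torch.cat([linear1.bias, linear2.bias], dim=0))

x = torch.rand(50, 5)
res1, res2 = merged(x).split([50, 60], dim=1)
assert torch.allclose(res1, linear1(x), atol=1e-6)
assert torch.allclose(res2, linear2(x), atol=1e-6)

The same argument applies to linear3 and linear4, which both consume res1, which is why the test asserts exactly two remaining linear ops after freezing.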