2
2
#include < ideep.hpp>
3
3
#include " passes/utils.h"
4
4
5
+ #include " auto_opt_config.h"
5
6
#include " graph_rewrite.h"
6
7
#include " graph_rewrite_utils.h"
7
8
@@ -93,6 +94,101 @@ void replaceFrozenIPEXLinearWithAtenLinear(
93
94
EliminateDeadCode (graph);
94
95
}
95
96
97
+ void replaceAtenLinearWithPrepackNode (
98
+ Node* n,
99
+ std::unordered_set<Node*>& aten_linear,
100
+ const bool & use_mkl_sgemm) {
101
+ WithInsertPoint guard (n);
102
+ auto graph = n->owningGraph ();
103
+ auto input_size_option =
104
+ n->inputs ().at (0 )->type ()->cast <TensorType>()->sizes ().concrete_sizes ();
105
+ if (!(input_size_option.has_value () &&
106
+ input_size_option.value ().size () >= 2 )) {
107
+ return ;
108
+ }
109
+ auto input_size = input_size_option.value ();
110
+ int64_t b_size =
111
+ std::accumulate (
112
+ input_size.begin (), input_size.end (), 1 , std::multiplies<double >()) /
113
+ input_size[input_size.size () - 1 ];
114
+ IValue batch_size_value (b_size);
115
+ auto batch_size = graph->insertConstant (batch_size_value);
116
+ auto tt = n->inputs ().at (1 )->type ()->cast <TensorType>();
117
+ auto weight_size_option = tt->sizes ().concrete_sizes ();
118
+ if (!(weight_size_option.has_value () &&
119
+ weight_size_option.value ().size () == 2 )) {
120
+ return ;
121
+ }
122
+ auto weight_dtype_option = tt->scalarType ();
123
+ bool should_repack = aten_linear.find (n) == aten_linear.end () &&
124
+ AutoOptConfig::singleton ().get_jit_repack_for_linear ();
125
+
126
+ // we should pack aten linear to ipex prepack linear for 2 cases:
127
+ // (1): Repack case, this aten linear is created by ipex linear
128
+ // (2) BF16 case, we believe IPEX BF16 prepack linear always better than aten
129
+ // BF16 linear
130
+ bool should_pack_for_bf16 = weight_dtype_option.has_value () &&
131
+ weight_dtype_option.value () == at::ScalarType::BFloat16 &&
132
+ ideep::has_bf16_type_support ();
133
+ bool should_pack = should_repack || should_pack_for_bf16;
134
+ if (!(should_pack))
135
+ return ;
136
+
137
+ auto weight_size = weight_size_option.value ();
138
+
139
+ // Note that once creating a graph node, make sure it is also inserted into
140
+ // the graph, for: PyTorch (when disabled TE) has a check on the graph node,
141
+ // pointing out that every mutable value in the system has a corresponding
142
+ // element. So if creating a graph node but not inserted, it will not pass
143
+ // the check since its graph element is not initialized. Details please
144
+ // refer to
145
+ // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/alias_analysis.cpp#L1956
146
+ auto use_mkl_sgemm_ =
147
+ use_mkl_sgemm && weight_dtype_option.value () != at::ScalarType::BFloat16;
148
+ auto prepack_node = graph->create (
149
+ use_mkl_sgemm_ ? Symbol::fromQualString (" ipex_prepack::mkl_sgemm_prepack" )
150
+ : Symbol::fromQualString (" ipex_prepack::linear_prepack" ),
151
+ 1 );
152
+ for (auto i = 1 ; i < n->inputs ().size (); ++i) {
153
+ Value* v = n->inputs ().at (i);
154
+ prepack_node->addInput (v);
155
+ }
156
+ prepack_node->addInput (batch_size);
157
+ prepack_node->output ()->setType (
158
+ use_mkl_sgemm_
159
+ ? getCustomClass (" __torch__.torch.classes.ipex_prepack.MKLOpContext" )
160
+ : getCustomClass (
161
+ " __torch__.torch.classes.ipex_prepack.LinearOpContext" ));
162
+ graph->insertNode (prepack_node);
163
+ auto prepack_linear = graph->insertNode (graph->create (
164
+ use_mkl_sgemm_ ? Symbol::fromQualString (" ipex_prepack::mkl_sgemm_run" )
165
+ : Symbol::fromQualString (" ipex_prepack::linear_run" ),
166
+ 1 ));
167
+ prepack_linear->addInput (n->inputs ().at (0 ));
168
+ prepack_linear->addInput (prepack_node->output ());
169
+ prepack_linear->output ()->setType (n->output ()->type ()->cast <TensorType>());
170
+ auto v = n->outputs ().at (0 );
171
+ n->output ()->replaceAllUsesWith (prepack_linear->output ());
172
+ }
173
+
174
+ void replaceIpexLinearWithLinearRunNode (Node* n) {
175
+ WithInsertPoint guard (n);
176
+ auto graph = n->owningGraph ();
177
+ auto use_mkl_sgemm =
178
+ n->kind () == Symbol::fromQualString (" torch_ipex::ipex_MKLSGEMM" );
179
+ auto get_data_handle_node = n->inputs ().at (3 )->node ();
180
+ auto linear_ctx = get_data_handle_node->inputs ().at (0 );
181
+ auto linear_run = graph->insertNode (graph->create (
182
+ use_mkl_sgemm ? Symbol::fromQualString (" ipex_prepack::mkl_sgemm_run" )
183
+ : Symbol::fromQualString (" ipex_prepack::linear_run" ),
184
+ 1 ));
185
+ linear_run->addInput (n->inputs ().at (0 ));
186
+ linear_run->addInput (linear_ctx);
187
+ linear_run->output ()->setType (n->output ()->type ()->cast <TensorType>());
188
+ n->output ()->replaceAllUsesWith (linear_run->output ());
189
+ return ;
190
+ }
191
+
96
192
void insertPrePackedLinearOp (
97
193
Block* b,
98
194
std::unordered_set<Node*>& aten_linear,
@@ -101,75 +197,15 @@ void insertPrePackedLinearOp(
101
197
for (Block* block : n->blocks ()) {
102
198
insertPrePackedLinearOp (block, aten_linear, use_mkl_sgemm);
103
199
}
104
- if (n->kind () != aten::linear)
105
- continue ;
106
- WithInsertPoint guard (n);
107
- auto graph = n->owningGraph ();
108
- auto input_size_option =
109
- n->inputs ().at (0 )->type ()->cast <TensorType>()->sizes ().concrete_sizes ();
110
- if (!(input_size_option.has_value () &&
111
- input_size_option.value ().size () >= 2 )) {
112
- continue ;
113
- }
114
- auto input_size = input_size_option.value ();
115
- int64_t b_size = std::accumulate (
116
- input_size.begin (),
117
- input_size.end (),
118
- 1 ,
119
- std::multiplies<double >()) /
120
- input_size[input_size.size () - 1 ];
121
- IValue batch_size_value (b_size);
122
- auto batch_size = graph->insertConstant (batch_size_value);
123
- auto tt = n->inputs ().at (1 )->type ()->cast <TensorType>();
124
- auto weight_size_option = tt->sizes ().concrete_sizes ();
125
- if (!(weight_size_option.has_value () &&
126
- weight_size_option.value ().size () == 2 )) {
127
- continue ;
128
- }
129
- auto weight_dtype_option = tt->scalarType ();
130
- if (!(weight_dtype_option.has_value () &&
131
- (weight_dtype_option.value () == at::ScalarType::BFloat16) &&
132
- ideep::has_bf16_type_support () ||
133
- aten_linear.find (n) == aten_linear.end ())) {
200
+ if (n->kind () == aten::linear) {
201
+ replaceAtenLinearWithPrepackNode (n, aten_linear, use_mkl_sgemm);
202
+ } else if (
203
+ n->kind () == Symbol::fromQualString (" torch_ipex::ipex_linear" ) ||
204
+ n->kind () == Symbol::fromQualString (" torch_ipex::ipex_MKLSGEMM" )) {
205
+ replaceIpexLinearWithLinearRunNode (n);
206
+ } else {
134
207
continue ;
135
208
}
136
- auto weight_size = weight_size_option.value ();
137
-
138
- // Note that once creating a graph node, make sure it is also inserted into
139
- // the graph, for: PyTorch (when disabled TE) has a check on the graph node,
140
- // pointing out that every mutable value in the system has a corresponding
141
- // element. So if creating a graph node but not inserted, it will not pass
142
- // the check since its graph element is not initialized. Details please
143
- // refer to
144
- // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/alias_analysis.cpp#L1956
145
- auto use_mkl_sgemm_ = use_mkl_sgemm &&
146
- weight_dtype_option.value () != at::ScalarType::BFloat16;
147
- auto prepack_node = graph->create (
148
- use_mkl_sgemm_
149
- ? Symbol::fromQualString (" ipex_prepack::mkl_sgemm_prepack" )
150
- : Symbol::fromQualString (" ipex_prepack::linear_prepack" ),
151
- 1 );
152
- for (auto i = 1 ; i < n->inputs ().size (); ++i) {
153
- Value* v = n->inputs ().at (i);
154
- prepack_node->addInput (v);
155
- }
156
- prepack_node->addInput (batch_size);
157
- prepack_node->output ()->setType (
158
- use_mkl_sgemm_
159
- ? getCustomClass (
160
- " __torch__.torch.classes.ipex_prepack.MKLOpContext" )
161
- : getCustomClass (
162
- " __torch__.torch.classes.ipex_prepack.LinearOpContext" ));
163
- graph->insertNode (prepack_node);
164
- auto prepack_linear = graph->insertNode (graph->create (
165
- use_mkl_sgemm_ ? Symbol::fromQualString (" ipex_prepack::mkl_sgemm_run" )
166
- : Symbol::fromQualString (" ipex_prepack::linear_run" ),
167
- 1 ));
168
- prepack_linear->addInput (n->inputs ().at (0 ));
169
- prepack_linear->addInput (prepack_node->output ());
170
- prepack_linear->output ()->setType (n->output ()->type ()->cast <TensorType>());
171
- auto v = n->outputs ().at (0 );
172
- n->output ()->replaceAllUsesWith (prepack_linear->output ());
173
209
}
174
210
EliminateDeadCode (b);
175
211
}
0 commit comments