Commit a821c0a

Disable repack (#1563)
* allow disable re-pack by global flag (#1522)
* enable linear fusion without jit repack
* fix bug and add ut for re-pack flag
* fix ut & add ut for linear fusion without repack
* fix linear schema in concat_linear test
* fix format
* enable concat linear on ipex_linear and mkl_linear
* Revert "enable concat linear on ipex_linear and mkl_linear"
  This reverts commit 68dd4561545be81bf9f9e07c065c9f17fefbc46c.
* fix ut
* add comments for why we still repack linear by default
* format change
1 parent 1f1ee89 commit a821c0a

File tree

5 files changed: +228 -78 lines changed


csrc/jit/auto_opt_config.h

Lines changed: 18 additions & 0 deletions
@@ -17,9 +17,26 @@ class TORCH_API AutoOptConfig {
     return jit_fuse_;
   }
 
+  inline void set_jit_repack_for_linear(bool jit_repack_for_linear) {
+    jit_repack_for_linear_ = jit_repack_for_linear;
+  }
+
+  inline bool get_jit_repack_for_linear() {
+    return jit_repack_for_linear_;
+  }
+
  private:
   AutoOptConfig()
       : jit_fuse_(true),
+        // JIT repack (ipex linear -> aten linear -> ipex linear) uses extra
+        // memory because the original graph is always held by design:
+        // https://github.com/pytorch/pytorch/blob/8e2a86c2a54719fd66a3e612fe8b433fbb1d4522/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp#L668
+        // We use this flag to let customers disable repack to save memory.
+        // Repack stays enabled by default for 2 reasons:
+        // (1) the JIT repack stage sees a real input, so the block format
+        // will be the best format; (2) linear + binary cannot be folded if
+        // we do not repack, since that folding is implemented on aten::linear.
+        jit_repack_for_linear_(true),
         calibration_step_(false),
         qscheme_(at::QScheme::PER_TENSOR_AFFINE) {}
 
@@ -28,6 +45,7 @@ class TORCH_API AutoOptConfig {
   AutoOptConfig& operator=(const AutoOptConfig&) = default;
 
   bool jit_fuse_;
+  bool jit_repack_for_linear_;
   // the flag for one iteration of calibration step whether end or not.
   bool calibration_step_;
   at::QScheme qscheme_;
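
The second reason in the comment above refers to patterns like the minimal sketch below (illustrative only, not part of this commit): a Linear followed by a binary add. With repack enabled, the fusion pass can rewrite the frozen ipex linear back to aten::linear, where the linear + binary folding is implemented; with repack disabled, that folding opportunity is given up in exchange for the memory the repack round-trip would otherwise use.

    import torch
    import torch.nn as nn

    class LinearAdd(nn.Module):
        # Hypothetical module used only to illustrate the linear + binary pattern.
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(768, 768)

        def forward(self, x, y):
            # Linear followed by a binary op: per the comment above, foldable only
            # when the pass sees aten::linear, i.e. when repack is left enabled.
            return self.fc(x) + y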

csrc/jit/fusion_pass.cpp

Lines changed: 5 additions & 2 deletions
@@ -1,5 +1,6 @@
 #include "fusion_pass.h"
 #include <string>
+#include "auto_opt_config.h"
 #include "codegen/onednn/interface.h"
 #include "cpu/kernels/Matmul.h"
 #include "passes/concat_linear.h"
@@ -132,8 +133,10 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
   // up fusion pass, will further abstract this as a class method.
   auto aten_linear_recorder = ATenLinearRecorder(graph);
   // linear folding
-  graph_rewrite::replaceFrozenIPEXLinearWithAtenLinear(
-      graph, aten_linear_recorder.use_mkl());
+  if (AutoOptConfig::singleton().get_jit_repack_for_linear()) {
+    graph_rewrite::replaceFrozenIPEXLinearWithAtenLinear(
+        graph, aten_linear_recorder.use_mkl());
+  }
   // concat multi-linear with same input
   torch_ipex::jit::FrozenConcatLinear(
       graph, aten_linear_recorder.get_records());
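
For context, a minimal sketch of where this pass runs in a typical IPEX workflow (assuming the usual ipex.optimize + trace + freeze flow; the model and shapes below are placeholders, not part of this diff). The guarded replaceFrozenIPEXLinearWithAtenLinear call happens inside IPEXFusionPass when the frozen graph is first executed, so the repack flag must be set before that first run.

    import torch
    import intel_extension_for_pytorch as ipex

    model = torch.nn.Linear(768, 768).eval()            # stand-in model with a Linear layer
    example_input = torch.randn(8, 768)
    model = ipex.optimize(model, dtype=torch.bfloat16)   # prepacks weights, swaps in ipex linear
    with torch.no_grad(), torch.cpu.amp.autocast():
        traced = torch.jit.trace(model, example_input)
        traced = torch.jit.freeze(traced)
        traced(example_input)   # profiling + IPEXFusionPass run on the frozen graph here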

csrc/jit/passes/graph_rewrite_linear.cpp

Lines changed: 103 additions & 67 deletions
@@ -2,6 +2,7 @@
 #include <ideep.hpp>
 #include "passes/utils.h"
 
+#include "auto_opt_config.h"
 #include "graph_rewrite.h"
 #include "graph_rewrite_utils.h"
 
@@ -93,6 +94,101 @@ void replaceFrozenIPEXLinearWithAtenLinear(
   EliminateDeadCode(graph);
 }
 
+void replaceAtenLinearWithPrepackNode(
+    Node* n,
+    std::unordered_set<Node*>& aten_linear,
+    const bool& use_mkl_sgemm) {
+  WithInsertPoint guard(n);
+  auto graph = n->owningGraph();
+  auto input_size_option =
+      n->inputs().at(0)->type()->cast<TensorType>()->sizes().concrete_sizes();
+  if (!(input_size_option.has_value() &&
+        input_size_option.value().size() >= 2)) {
+    return;
+  }
+  auto input_size = input_size_option.value();
+  int64_t b_size =
+      std::accumulate(
+          input_size.begin(), input_size.end(), 1, std::multiplies<double>()) /
+      input_size[input_size.size() - 1];
+  IValue batch_size_value(b_size);
+  auto batch_size = graph->insertConstant(batch_size_value);
+  auto tt = n->inputs().at(1)->type()->cast<TensorType>();
+  auto weight_size_option = tt->sizes().concrete_sizes();
+  if (!(weight_size_option.has_value() &&
+        weight_size_option.value().size() == 2)) {
+    return;
+  }
+  auto weight_dtype_option = tt->scalarType();
+  bool should_repack = aten_linear.find(n) == aten_linear.end() &&
+      AutoOptConfig::singleton().get_jit_repack_for_linear();
+
+  // We should pack an aten linear into an ipex prepack linear for 2 cases:
+  // (1) repack case: this aten linear was created from an ipex linear;
+  // (2) BF16 case: we believe an IPEX BF16 prepack linear is always better
+  // than an aten BF16 linear.
+  bool should_pack_for_bf16 = weight_dtype_option.has_value() &&
+      weight_dtype_option.value() == at::ScalarType::BFloat16 &&
+      ideep::has_bf16_type_support();
+  bool should_pack = should_repack || should_pack_for_bf16;
+  if (!(should_pack))
+    return;
+
+  auto weight_size = weight_size_option.value();
+
+  // Note that once creating a graph node, make sure it is also inserted into
+  // the graph, for: PyTorch (when disabled TE) has a check on the graph node,
+  // pointing out that every mutable value in the system has a corresponding
+  // element. So if creating a graph node but not inserted, it will not pass
+  // the check since its graph element is not initialized. Details please
+  // refer to
+  // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/alias_analysis.cpp#L1956
+  auto use_mkl_sgemm_ =
+      use_mkl_sgemm && weight_dtype_option.value() != at::ScalarType::BFloat16;
+  auto prepack_node = graph->create(
+      use_mkl_sgemm_ ? Symbol::fromQualString("ipex_prepack::mkl_sgemm_prepack")
+                     : Symbol::fromQualString("ipex_prepack::linear_prepack"),
+      1);
+  for (auto i = 1; i < n->inputs().size(); ++i) {
+    Value* v = n->inputs().at(i);
+    prepack_node->addInput(v);
+  }
+  prepack_node->addInput(batch_size);
+  prepack_node->output()->setType(
+      use_mkl_sgemm_
+          ? getCustomClass("__torch__.torch.classes.ipex_prepack.MKLOpContext")
+          : getCustomClass(
+                "__torch__.torch.classes.ipex_prepack.LinearOpContext"));
+  graph->insertNode(prepack_node);
+  auto prepack_linear = graph->insertNode(graph->create(
+      use_mkl_sgemm_ ? Symbol::fromQualString("ipex_prepack::mkl_sgemm_run")
+                     : Symbol::fromQualString("ipex_prepack::linear_run"),
+      1));
+  prepack_linear->addInput(n->inputs().at(0));
+  prepack_linear->addInput(prepack_node->output());
+  prepack_linear->output()->setType(n->output()->type()->cast<TensorType>());
+  auto v = n->outputs().at(0);
+  n->output()->replaceAllUsesWith(prepack_linear->output());
+}
+
+void replaceIpexLinearWithLinearRunNode(Node* n) {
+  WithInsertPoint guard(n);
+  auto graph = n->owningGraph();
+  auto use_mkl_sgemm =
+      n->kind() == Symbol::fromQualString("torch_ipex::ipex_MKLSGEMM");
+  auto get_data_handle_node = n->inputs().at(3)->node();
+  auto linear_ctx = get_data_handle_node->inputs().at(0);
+  auto linear_run = graph->insertNode(graph->create(
+      use_mkl_sgemm ? Symbol::fromQualString("ipex_prepack::mkl_sgemm_run")
                    : Symbol::fromQualString("ipex_prepack::linear_run"),
+      1));
+  linear_run->addInput(n->inputs().at(0));
+  linear_run->addInput(linear_ctx);
+  linear_run->output()->setType(n->output()->type()->cast<TensorType>());
+  n->output()->replaceAllUsesWith(linear_run->output());
+  return;
+}
+
 void insertPrePackedLinearOp(
     Block* b,
     std::unordered_set<Node*>& aten_linear,
@@ -101,75 +197,15 @@ void insertPrePackedLinearOp(
     for (Block* block : n->blocks()) {
       insertPrePackedLinearOp(block, aten_linear, use_mkl_sgemm);
     }
-    if (n->kind() != aten::linear)
-      continue;
-    WithInsertPoint guard(n);
-    auto graph = n->owningGraph();
-    auto input_size_option =
-        n->inputs().at(0)->type()->cast<TensorType>()->sizes().concrete_sizes();
-    if (!(input_size_option.has_value() &&
-          input_size_option.value().size() >= 2)) {
-      continue;
-    }
-    auto input_size = input_size_option.value();
-    int64_t b_size = std::accumulate(
-                         input_size.begin(),
-                         input_size.end(),
-                         1,
-                         std::multiplies<double>()) /
-        input_size[input_size.size() - 1];
-    IValue batch_size_value(b_size);
-    auto batch_size = graph->insertConstant(batch_size_value);
-    auto tt = n->inputs().at(1)->type()->cast<TensorType>();
-    auto weight_size_option = tt->sizes().concrete_sizes();
-    if (!(weight_size_option.has_value() &&
-          weight_size_option.value().size() == 2)) {
-      continue;
-    }
-    auto weight_dtype_option = tt->scalarType();
-    if (!(weight_dtype_option.has_value() &&
-              (weight_dtype_option.value() == at::ScalarType::BFloat16) &&
-              ideep::has_bf16_type_support() ||
-          aten_linear.find(n) == aten_linear.end())) {
+    if (n->kind() == aten::linear) {
+      replaceAtenLinearWithPrepackNode(n, aten_linear, use_mkl_sgemm);
+    } else if (
+        n->kind() == Symbol::fromQualString("torch_ipex::ipex_linear") ||
+        n->kind() == Symbol::fromQualString("torch_ipex::ipex_MKLSGEMM")) {
+      replaceIpexLinearWithLinearRunNode(n);
+    } else {
       continue;
     }
-    auto weight_size = weight_size_option.value();
-
-    // Note that once creating a graph node, make sure it is also inserted into
-    // the graph, for: PyTorch (when disabled TE) has a check on the graph node,
-    // pointing out that every mutable value in the system has a corresponding
-    // element. So if creating a graph node but not inserted, it will not pass
-    // the check since its graph element is not initialized. Details please
-    // refer to
-    // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/alias_analysis.cpp#L1956
-    auto use_mkl_sgemm_ = use_mkl_sgemm &&
-        weight_dtype_option.value() != at::ScalarType::BFloat16;
-    auto prepack_node = graph->create(
-        use_mkl_sgemm_
-            ? Symbol::fromQualString("ipex_prepack::mkl_sgemm_prepack")
-            : Symbol::fromQualString("ipex_prepack::linear_prepack"),
-        1);
-    for (auto i = 1; i < n->inputs().size(); ++i) {
-      Value* v = n->inputs().at(i);
-      prepack_node->addInput(v);
-    }
-    prepack_node->addInput(batch_size);
-    prepack_node->output()->setType(
-        use_mkl_sgemm_
-            ? getCustomClass(
-                  "__torch__.torch.classes.ipex_prepack.MKLOpContext")
-            : getCustomClass(
-                  "__torch__.torch.classes.ipex_prepack.LinearOpContext"));
-    graph->insertNode(prepack_node);
-    auto prepack_linear = graph->insertNode(graph->create(
-        use_mkl_sgemm_ ? Symbol::fromQualString("ipex_prepack::mkl_sgemm_run")
-                       : Symbol::fromQualString("ipex_prepack::linear_run"),
-        1));
-    prepack_linear->addInput(n->inputs().at(0));
-    prepack_linear->addInput(prepack_node->output());
-    prepack_linear->output()->setType(n->output()->type()->cast<TensorType>());
-    auto v = n->outputs().at(0);
-    n->output()->replaceAllUsesWith(prepack_linear->output());
   }
   EliminateDeadCode(b);
 }
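
For reference, the batch_size constant handed to the prepack node above is the product of all input dims divided by the last dim, i.e. the number of rows the packed 2-D GEMM will see. A quick hypothetical example:

    import math

    input_size = [8, 32, 768]                        # hypothetical [batch, seq, features] shape
    b_size = math.prod(input_size) // input_size[-1]
    print(b_size)                                    # 256 = 8 * 32 rows for the packed linear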

intel_extension_for_pytorch/csrc/cpu/Module.cpp

Lines changed: 10 additions & 0 deletions
@@ -141,6 +141,16 @@ void InitIpexModuleBindings(py::module m) {
     return AutoOptConfig::singleton().get_jit_fuse();
   });
 
+  m.def("enable_jit_linear_repack", []() {
+    AutoOptConfig::singleton().set_jit_repack_for_linear(true);
+  });
+  m.def("disable_jit_linear_repack", []() {
+    AutoOptConfig::singleton().set_jit_repack_for_linear(false);
+  });
+  m.def("get_jit_linear_repack", []() {
+    return AutoOptConfig::singleton().get_jit_repack_for_linear();
+  });
+
   // BF32
   py::enum_<FP32MathMode>(m, "FP32MathMode")
       .value("FP32", FP32MathMode::FP32)
