Commit 30b70e4

Support dict input for quantization prepare (#1682)
* Enable dict input for ipex quantization prepare
* code format
* add UT
* code format
* code clean up
1 parent c0daaa5 commit 30b70e4
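Taken together, the changes below thread an example_kwarg_inputs path through prepare, calibration, convert, and tracing. A minimal end-to-end sketch of that flow, assuming an illustrative two-input ToyModel and ipex.quantization.default_static_qconfig (both are assumptions, not part of this commit); the keyword-argument calls mirror what the new unit test exercises:

import torch
import intel_extension_for_pytorch as ipex

# Illustrative two-input model (not part of the commit).
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(3, 3)
        self.linear2 = torch.nn.Linear(3, 3)

    def forward(self, x1, x2):
        return self.linear1(x1) + self.linear2(x2)

model = ToyModel().eval()
qconfig = ipex.quantization.default_static_qconfig  # assumed static qconfig
kwargs = {"x1": torch.randn(3, 3), "x2": torch.randn(3, 3)}

# New in this commit: example inputs may be supplied as a dict of keyword arguments.
prepared = ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=kwargs)
prepared(**kwargs)  # calibration run, driven by keyword arguments

with torch.no_grad():
    converted = ipex.quantization.convert(prepared)
    traced = torch.jit.trace(converted, example_kwarg_inputs=kwargs)
    traced = torch.jit.freeze(traced)
    out = traced(**kwargs)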

File tree

4 files changed: +184, -21 lines changed


intel_extension_for_pytorch/quantization/_quantize.py

Lines changed: 27 additions & 7 deletions
@@ -18,20 +18,32 @@
 )
 from ._quantize_utils import auto_prepare, auto_convert, copy_prepared_model
 from .. import nn
+from typing import Dict
 
 
-def prepare(model, configure, example_inputs=None, inplace=False, bn_folding=True):
+def prepare(
+    model,
+    configure,
+    example_inputs=None,
+    inplace=False,
+    bn_folding=True,
+    example_kwarg_inputs=None,
+):
     r"""
     Prepare an FP32 torch.nn.Module model to do calibration or to convert to quantized model.
 
     Args:
         model (torch.nn.Module): The FP32 model to be prepared.
         configure (torch.quantization.qconfig.QConfig): The observer settings about activation and weight.
         example_inputs (tuple or torch.Tensor): A tuple of example inputs that
-            will be passed to the function while running to init quantization state.
+            will be passed to the function while running to init quantization state. Only one of this
+            argument or ``example_kwarg_inputs`` should be specified.
         inplace: (bool): It will change the given model in-place if True. The default value is ``False``.
         bn_folding: (bool): whether to perform ``conv_bn`` and ``linear_bn`` folding.
-            The default value is ``True``.
+            The default value is ``True``.
+        example_kwarg_inputs (dict): A dict of example inputs that will be passed to the function while
+            running to init quantization state. Only one of this argument or ``example_inputs`` should be
+            specified.
 
     Returns:
         torch.nn.Module
@@ -52,9 +64,10 @@ def prepare(model, configure, example_inputs=None, inplace=False, bn_folding=Tru
     if isinstance(configure, QConfigMapping):
         configure = configure.global_qconfig
     if not isinstance(configure.activation(), PlaceholderObserver):
-        assert (
-            example_inputs is not None
-        ), "IPEX quantization.prepare: example inputs cannot be None for static quantization"
+        assert example_inputs is not None or example_kwarg_inputs is not None, (
+            "IPEX quantization.prepare: example_inputs and example_kwarg_inputs cannot be none at same time "
+            "for static quantization."
+        )
     # auto model channels_last memory format conversion
     from ..frontend import (
         auto_channels_last,
@@ -81,12 +94,19 @@ def prepare(model, configure, example_inputs=None, inplace=False, bn_folding=Tru
 
     # replace dropout with identity to enable more fusion pattern.
     nn.utils._model_convert.replace_dropout_with_identity(prepare_model)
+    assert (
+        example_inputs is None or example_kwarg_inputs is None
+    ), "IPEX quantization.prepare: example_inputs and example_kwarg_inputs cannot be set at same time."
     # Special case for common case of passing a single Tensor
     if isinstance(example_inputs, (torch.Tensor, dict)):
         example_inputs = (example_inputs,)
     elif not isinstance(example_inputs, tuple) and example_inputs is not None:
         example_inputs = tuple(example_inputs)
-    return auto_prepare(prepare_model, configure, example_inputs)
+    if example_kwarg_inputs is not None:
+        assert isinstance(
+            example_kwarg_inputs, Dict
+        ), "IPEX quantization.prepare: example_kwarg_inputs must be type of Dict."
+    return auto_prepare(prepare_model, configure, example_inputs, example_kwarg_inputs)
 
 
 @functools.lru_cache(None)
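As the assertions in this diff spell out, example_inputs and example_kwarg_inputs are mutually exclusive, and at least one must be supplied for static quantization. Continuing the illustrative ToyModel sketch above, the caller contract looks roughly like this (model, qconfig, and kwargs come from that sketch):

# Exactly one input style is accepted by prepare().
ipex.quantization.prepare(model, qconfig, example_inputs=(torch.randn(3, 3), torch.randn(3, 3)))
ipex.quantization.prepare(model, qconfig, example_kwarg_inputs=kwargs)

# Passing both trips the new assert; passing neither fails the static-quantization check.
try:
    ipex.quantization.prepare(
        model,
        qconfig,
        example_inputs=(torch.randn(3, 3), torch.randn(3, 3)),
        example_kwarg_inputs=kwargs,
    )
except AssertionError as err:
    print(err)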

intel_extension_for_pytorch/quantization/_quantize_utils.py

Lines changed: 15 additions & 4 deletions
@@ -73,6 +73,7 @@ def auto_prepare(
     model: torch.nn.Module,
     configure: QConfig,
     example_inputs: Optional[Tuple[Any]],
+    example_kwarg_inputs: Optional[Dict[Any, Any]],
 ) -> torch.nn.Module:
     def convert_to_interception_proxy(x):
         if isinstance(x, torch.Tensor):
@@ -486,10 +487,20 @@ def load_qconf_summary(self, qconf_summary):
     if not isinstance(configure.activation(), PlaceholderObserver):
         model.__class__ = QuantizationInterceptionModule
         # init model quantization state using example_inputs
-        assert (
-            example_inputs is not None
-        ), "IPEX: example inputs cannot be None for static quantization"
-        model(*example_inputs)
+        assert example_inputs is not None or example_kwarg_inputs is not None, (
+            "IPEX: example_inputs and example_kwarg_inputs cannot be None at same time "
+            "for static quantization."
+        )
+        if example_kwarg_inputs is None:
+            model(*example_inputs)
+        elif example_inputs is None:
+            model(**example_kwarg_inputs)
+        else:
+            raise AssertionError(
+                "IPEX quantization.prepare: example_inputs and example_kwarg_inputs cannot be set at same time "
+                "for static quantization."
+            )
     return model
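Inside auto_prepare, the calibration run now dispatches on which input style was provided. A standalone sketch of that dispatch (the run_calibration helper name is made up for illustration and is not part of the source):

import torch

def run_calibration(model, example_inputs=None, example_kwarg_inputs=None):
    # Mirrors the branch added above: exactly one input style drives the forward pass.
    if example_inputs is not None and example_kwarg_inputs is not None:
        raise AssertionError("example_inputs and example_kwarg_inputs cannot be set at same time")
    if example_kwarg_inputs is not None:
        return model(**example_kwarg_inputs)  # keyword-style forward
    if example_inputs is not None:
        return model(*example_inputs)  # positional-style forward
    raise AssertionError("one of example_inputs or example_kwarg_inputs is required")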

tests/cpu/test_ao_jit_ipex_quantization.py

Lines changed: 85 additions & 0 deletions
@@ -25,6 +25,7 @@
     QConfig,
     PlaceholderObserver,
 )
+from torch.testing._internal.common_utils import run_tests
 
 default_weight_observer = PerChannelMinMaxObserver.with_args(
     dtype=torch.qint8, qscheme=torch.per_channel_symmetric
@@ -748,5 +749,89 @@ def forward(self, x, hx, cx):
         FileCheck().check_not("aten:lstm").check("aten::quantized_lstm").run(graph)
 
 
+class TestDictInput(JitLlgaTestCase):
+    def test_only_dict_input(self):
+        class SubModule(nn.Module):
+            def __init__(self):
+                super(SubModule, self).__init__()
+                self.linear = nn.Linear(3, 3)
+
+            def forward(self, x):
+                x = self.linear(x)
+                return x
+
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.linear1 = nn.Sequential(nn.Linear(3, 3))
+                self.linear2 = SubModule()
+                self.linear3 = nn.Linear(3, 3)
+
+            def forward(self, x1, x2, x3):
+                x1 = self.linear1(x1)
+                x2 = self.linear2(x2)
+                x3 = self.linear3(x3)
+                return x1 + x2 + x3
+
+        int8_bf16_list = [True, False]
+        for qconfig, int8_bf16 in itertools.product(static_qconfig, int8_bf16_list):
+            # Step1: Test model with tuple(x1, x2, x3) input.
+            m = M().eval()
+            m2 = copy.deepcopy(m).eval()
+            x1 = torch.randn(3, 3)
+            x2 = torch.randn(3, 3)
+            x3 = torch.randn(3, 3)
+            graph = self.checkQuantizeTrace(
+                m, [x1, x2, x3], atol=2e-1, qconfig=qconfig, int8_bf16=int8_bf16
+            )
+            FileCheck().check("aten::linear").run(graph)
+            patterns = [
+                [
+                    "aten::dequantize",
+                    "aten::linear",
+                ],
+                [
+                    "aten::dequantize",
+                    "aten::linear",
+                    "aten::add",
+                ],
+                [
+                    "aten::dequantize",
+                    "aten::linear",
+                    "aten::add",
+                ],
+            ]
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.checkPatterns(graph, patterns)
+
+            # Step2: Test model with Dict{"x1": x1, "x2": x2, "x3": x3} input.
+            graph = self.checkQuantizeTrace(
+                m2,
+                atol=2e-1,
+                qconfig=qconfig,
+                int8_bf16=int8_bf16,
+                x_kwarg={"x1": x1, "x2": x2, "x3": x3},
+            )
+            FileCheck().check("aten::linear").run(graph)
+            patterns = [
+                [
+                    "aten::dequantize",
+                    "aten::linear",
+                ],
+                [
+                    "aten::dequantize",
+                    "aten::linear",
+                    "aten::add",
+                ],
+                [
+                    "aten::dequantize",
+                    "aten::linear",
+                    "aten::add",
+                ],
+            ]
+            self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3)
+            self.checkPatterns(graph, patterns)
+
+
 if __name__ == "__main__":
     run_tests()

tests/cpu/test_ao_jit_llga_utils.py

Lines changed: 57 additions & 10 deletions
@@ -123,28 +123,58 @@ def assertFused(self, graph, fused_patterns):
         for pat in fused_patterns:
             self.assertGraphContainsExactly(graph, pat, 0)
 
+    def model_forward_helper(
+        self,
+        model,
+        x=None,
+        x_kwarg=None,
+    ):
+        if x is None and x_kwarg is None:
+            raise AssertionError(
+                "x and x_kwarg cannot be none at same time for model_forward_helper."
+            )
+        if x_kwarg is None:
+            return model(*x)
+        elif x is None:
+            return model(**x_kwarg)
+        else:
+            raise AssertionError(
+                "x and x_kwarg cannot be set at same time for model_forward_helper."
+            )
+
     def checkQuantizeTrace(
         self,
         model,
-        x,
+        x=None,
         atol=1e-3,
         rtol=1e-2,
         x_var=None,
         qconfig=default_static_qconfig,
         int8_bf16=False,
         freeze=True,
+        x_kwarg=None,
     ):
+        if x is None and x_kwarg is None:
+            raise AssertionError(
+                "x and x_kwarg cannot be none at same time for checkQuantizeTrace."
+            )
+        elif x is not None and x_kwarg is not None:
+            raise AssertionError(
+                "x and x_kwarg cannot be set at same time for checkQuantizeTrace."
+            )
+
         graph, traced_model, fp32_model = self.prepareModel(
-            model, x, qconfig, int8_bf16, freeze=freeze
+            model, x, qconfig, int8_bf16, freeze=freeze, x_kwarg=x_kwarg
         )
         with torch.no_grad():
-            y = fp32_model(*x)
+            y = self.model_forward_helper(fp32_model, x, x_kwarg)
             y = y.to(torch.bfloat16) if int8_bf16 else y
-            y_llga = traced_model(*x)
+            y_llga = self.model_forward_helper(traced_model, x, x_kwarg)
             self.assertEqual(y, y_llga, atol=atol, rtol=rtol)
 
         # test Fallback when input shape changes:
         if x_var:
+            assert x_kwarg is None, "x_kwarg input doesn't support use with x_var"
             y_var = fp32_model(*x_var)
             y_var = y_var.to(torch.bfloat16) if int8_bf16 else y_var
             y_var_llga = traced_model(*x_var)
@@ -161,35 +191,52 @@ def prepareModel(
         prepare_inplace=True,
         convert_inplace=True,
         freeze=True,
+        x_kwarg=None,
     ):
         model.eval()
         fp32_model = copy.deepcopy(model)
         with torch.no_grad(), torch._jit_internal._disable_emit_hooks():
             ipex.nn.utils._model_convert.replace_dropout_with_identity(model)
             model = ipex.quantization.prepare(
-                model, qconfig, x, inplace=prepare_inplace
+                model, qconfig, x, inplace=prepare_inplace, example_kwarg_inputs=x_kwarg
            )
             # do calibration
-            y = model(*x)
+            y = self.model_forward_helper(model, x, x_kwarg)
             # jit trace to insert quant/dequant
+
+            def jit_trace_helper(convert_model, x, x_kwarg):
+                if x_kwarg is None:
+                    return torch.jit.trace(convert_model, x)
+                elif x is None:
+                    return torch.jit.trace(convert_model, example_kwarg_inputs=x_kwarg)
+                else:
+                    raise AssertionError(
+                        "Can't set x and x_kwarg at same time for jit trace."
+                    )
+
             if int8_bf16:
                 with torch.cpu.amp.autocast():
                     convert_model = ipex.quantization.convert(
                         model, inplace=convert_inplace
                     )
-                    traced_model = torch.jit.trace(convert_model, x)
+                    traced_model = jit_trace_helper(convert_model, x, x_kwarg)
             else:
                 convert_model = ipex.quantization.convert(
                     model, inplace=convert_inplace
                 )
-                traced_model = torch.jit.trace(convert_model, x)
+                traced_model = jit_trace_helper(convert_model, x, x_kwarg)
             if freeze:
                 traced_model = torch.jit.freeze(traced_model)
 
             # warm up run
-            y0 = traced_model(*x)
+            y0 = self.model_forward_helper(traced_model, x, x_kwarg)
             # get the graph at the second run after freezing
-            graph = traced_model.graph_for(*x)
+            if x_kwarg is None:
+                graph = traced_model.graph_for(*x)
+            elif x is None:
+                graph = traced_model.graph_for(**x_kwarg)
+            else:
+                raise AssertionError("Can't set x and x_kwarg at same time")
             return graph, traced_model, fp32_model
 
     def checkPatterns(self, graph, patterns):
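The test-utility changes above all follow the same pattern: each call site that used to unpack x positionally now chooses between positional and keyword unpacking. A condensed sketch of the trace-and-capture path (trace_and_capture is an illustrative name; torch.jit.trace with example_kwarg_inputs, torch.jit.freeze, and graph_for are the calls the diff itself relies on):

import torch

def trace_and_capture(converted_model, x=None, x_kwarg=None):
    # Trace with whichever example-input style was supplied.
    if x_kwarg is None:
        traced = torch.jit.trace(converted_model, x)
    else:
        traced = torch.jit.trace(converted_model, example_kwarg_inputs=x_kwarg)
    traced = torch.jit.freeze(traced)

    # Warm-up run, then capture the optimized graph on the next run after freezing.
    if x_kwarg is None:
        traced(*x)
        return traced.graph_for(*x), traced
    traced(**x_kwarg)
    return traced.graph_for(**x_kwarg), traced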
