Commit ff231fb

quantization: support dynamic linear and lstm (#787)
1 parent 940f189 commit ff231fb

7 files changed: +186 −15 lines

intel_extension_for_pytorch/ao/quantization/README.md

Lines changed: 40 additions & 3 deletions
@@ -46,7 +46,7 @@ for data in calibration_data_set:
 # prepared_model.load_qconf_summary(qconf_summary = "configure.json")
 ```
 
-### Convert to Quantized Model and Deploy
+### Convert to Static Quantized Model and Deploy
 
 ```python
 # make sure the example_inputs's size is same as the real input's size
@@ -63,9 +63,46 @@ y = traced_model(x)
 # quantized_model = torch.jit.load("quantized_model.pt")
 # quantized_model = torch.jit.freeze(quantized_model.eval())
 # ...
-
 ```
 
 ## Dynamic Quantization
 
-TODO(future PR):
+```python
+import intel_extension_for_pytorch as ipex
+from intel_extension_for_pytorch.quantization import prepare, convert
+```
+
+### Define QConfig
+
+```python
+from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig
+dynamic_qconfig = QConfig(
+    activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
+    weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
+```
+
+Note: the weight observer only supports dtype **torch.qint8**, and the qscheme can only be **torch.per_tensor_symmetric**.
+
+### Prepare Model
+
+```python
+prepared_model = prepare(user_model, dynamic_qconfig, example_inputs=example_inputs, inplace=False)
+```
+
+### Convert to Dynamic Quantized Model and Deploy
+
+```python
+# make sure example_inputs has the same size as the real input
+convert_model = convert(prepared_model)
+with torch.no_grad():
+    traced_model = torch.jit.trace(convert_model, example_inputs)
+    traced_model = torch.jit.freeze(traced_model)
+# for inference
+y = traced_model(x)
+
+# or save the model to deploy
+# traced_model.save("quantized_model.pt")
+# quantized_model = torch.jit.load("quantized_model.pt")
+# quantized_model = torch.jit.freeze(quantized_model.eval())
+# ...
+```
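For reference, the dynamic path above needs no calibration loop. Below is a minimal end-to-end sketch that strings together the README steps, using a toy module and input shape borrowed from the tests added in this commit; the model definition and shapes are illustrative, not part of the README itself.

```python
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig

# dynamic qconfig exactly as defined in the README section above
dynamic_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

# toy model mirroring the new test case (illustrative only)
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

user_model = M().eval()
example_inputs = torch.randn(1, 3)

# prepare + convert; no calibration data is required for dynamic quantization
prepared_model = prepare(user_model, dynamic_qconfig, example_inputs=example_inputs, inplace=False)
convert_model = convert(prepared_model)

with torch.no_grad():
    traced_model = torch.jit.trace(convert_model, example_inputs)
    traced_model = torch.jit.freeze(traced_model)
    y = traced_model(example_inputs)
```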
intel_extension_for_pytorch/ao/quantization/_module_swap_utils.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+from typing import Dict, Callable, Any, Optional
+
+import torch
+import torch.nn as nn
+
+from torch.ao.quantization import swap_module
+import torch.nn.quantized.dynamic as nnqd
+
+
+# Default map for swapping dynamic modules
+DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
+    nn.Linear: nnqd.Linear,
+    nn.LSTM: nnqd.LSTM,
+    # TODO: support more RNN module
+    #nn.GRUCell: nnqd.GRUCell,
+    #nn.GRU: nnqd.GRU,
+    #nn.LSTMCell: nnqd.LSTMCell,
+    #nn.RNNCell: nnqd.RNNCell,
+}
+
+def _get_qconfig_dtypes(qconfig):
+    r"""
+    Returns the qconfig tuple for qconfig:
+    (activation_dtype, weight_dtype, activation_compute_dtype)
+    """
+    assert qconfig is not None
+    activation = qconfig.activation()
+    weight = qconfig.weight()
+    compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None
+    return (activation.dtype, weight.dtype, compute_dtype)
+
+def _op_is_int8_dynamically_quantized(qconfig) -> bool:
+    r"""
+    Given a qconfig, returns True if this op is using int8 dynamic
+    quantization
+    """
+    activation_dtype, weight_dtype, activation_compute_dtype = \
+        _get_qconfig_dtypes(qconfig)
+    return (
+        activation_dtype is torch.float and
+        # for now, the lines below assume fbgemm or qnnpack
+        weight_dtype is torch.qint8 and
+        activation_compute_dtype is torch.quint8
+    )
+
+
+def swap_child_modules(
+    module: torch.nn.Module,
+    dynamic_mappings: Dict[Callable, Any] = DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
+    parent_fqn: Optional[str] = None,
+) -> None:
+    """
+    For each direct child of `module`, swaps it using `dynamic_mappings`
+    if the qconfig for that child is using int8 dynamic quantization,
+    and the module type is in the mapping.
+    Recursively calls itself on each child.
+    """
+
+    if hasattr(module, '_auto_quant_state'):
+        qstate = module._auto_quant_state
+        for _, qopinfo in qstate.idx_to_seen_q_op_infos.items():
+            qconfig = qopinfo.qconfig
+            if not qconfig:
+                continue
+            fqn = qopinfo.fqn
+            if not fqn:
+                continue
+            op_int8_dynamically_quantized = _op_is_int8_dynamically_quantized(qconfig)
+
+            if op_int8_dynamically_quantized:
+                mod = module._modules[fqn]
+                if not type(mod) in dynamic_mappings:
+                    continue
+                mod.qconfig = qconfig
+                module._modules[fqn] = swap_module(mod, dynamic_mappings, {})
+
+    for _, child in module.named_children():
+        swap_child_modules(child)
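To make the selection rule concrete: the sketch below (not part of the commit) applies the same dtype test to the `dynamic_qconfig` defined in the README section above, inlining the logic of `_op_is_int8_dynamically_quantized` rather than importing the private helper.

```python
import torch
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig

dynamic_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

activation = dynamic_qconfig.activation()
weight = dynamic_qconfig.weight()
compute_dtype = getattr(activation, 'compute_dtype', None)

# same condition as _op_is_int8_dynamically_quantized: fp32 activations,
# int8 weights, and quint8 as the dynamic compute dtype
is_dynamic = (activation.dtype is torch.float
              and weight.dtype is torch.qint8
              and compute_dtype is torch.quint8)
print(is_dynamic)  # True, so Linear/LSTM children are eligible for the nnqd swap
```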

intel_extension_for_pytorch/ao/quantization/_quantization_state.py

Lines changed: 10 additions & 4 deletions
@@ -416,6 +416,10 @@ def op_weight_convert_before_hook(
         if op.bias:
             new_args.append(weights[tensor_arg_idx + 2])
             new_args.append(weights[tensor_arg_idx + 3])
+        else:
+            for s in range(step):
+                new_args.append(weights[tensor_arg_idx + s])
+
         return new_args
 
     def op_convert_after_hook(
@@ -713,7 +717,8 @@ def _maybe_insert_input_observers(self, seen_q_op_info: SeenQOpInfo):
             # always add observer if the op can be quantized.
             tensor_id = tensor_info.id # type: ignore[attr-defined]
             weight_arg_idx = get_weight_arg_idx(seen_q_op_info.type)
-            if idx == weight_arg_idx:
+            # avoid adding a weight observer for dynamic quantization.
+            if idx == weight_arg_idx and not isinstance(qconfig.activation(), torch.ao.quantization.PlaceholderObserver):
                 # conv_transpose weight is iohw or iodhw, so we change the observer axis to 1.
                 if seen_q_op_info.type in [str(F.conv_transpose2d), str(F.conv_transpose3d)] and \
                     isinstance(qconfig.weight(), torch.ao.quantization.PerChannelMinMaxObserver):
@@ -736,17 +741,18 @@ def _maybe_insert_input_observers(self, seen_q_op_info: SeenQOpInfo):
             tensor_id = tensor_info.id # type: ignore[attr-defined]
             if seen_q_op_info.type == str(torch.nn.EmbeddingBag):
                 obs = qconfig.activation()
-            else:
+                self.weight_tensor_id_to_observer[str(seen_q_op_info.idx) + "_" + str(tensor_id)] = obs
+            elif not isinstance(qconfig.activation(), torch.ao.quantization.PlaceholderObserver):
                 if seen_q_op_info.type in [str(torch.nn.ConvTranspose2d), str(torch.nn.ConvTranspose3d)] and \
                     isinstance(qconfig.weight(), torch.ao.quantization.PerChannelMinMaxObserver):
                     obs = qconfig.weight.with_args(ch_axis=1)()
                 else:
                     obs = qconfig.weight()
-            self.weight_tensor_id_to_observer[str(seen_q_op_info.idx) + "_" + str(tensor_id)] = obs
+                self.weight_tensor_id_to_observer[str(seen_q_op_info.idx) + "_" + str(tensor_id)] = obs
         # LSTM, we don't know whether has bais or not, so we add observer for all them, but will not use them at convert step.
         # w_ih, w_hh share same observe, and b_ih, b_hh also share same observer
         if seen_q_op_info.type == str(torch.nn.LSTM):
-            if qconfig is not None:
+            if qconfig is not None and not isinstance(qconfig.activation(), torch.ao.quantization.PlaceholderObserver):
                 for i in range(0, len(seen_q_op_info.weight_tensor_infos), 2):
                     tensor_id = seen_q_op_info.weight_tensor_infos[i].id
                     obs = qconfig.weight()
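The new guards key off the type of the activation observer. A small illustrative sketch (not in the commit; the `static_qconfig` here is just a typical static configuration) of how that check separates dynamic from static qconfigs:

```python
import torch
from torch.ao.quantization import MinMaxObserver, PlaceholderObserver, QConfig

dynamic_qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
static_qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8, qscheme=torch.per_tensor_affine),
    weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))

# the added condition in _maybe_insert_input_observers: weight observers are
# inserted only when the activation observer is NOT a PlaceholderObserver
print(isinstance(dynamic_qconfig.activation(), PlaceholderObserver))  # True  -> skip weight observer
print(isinstance(static_qconfig.activation(), PlaceholderObserver))   # False -> insert weight observer
```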

intel_extension_for_pytorch/ao/quantization/_quantization_state_utils.py

Lines changed: 12 additions & 2 deletions
@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.nn.quantized.dynamic as nnqd
 from intel_extension_for_pytorch.nn.functional import interaction
 import intel_extension_for_pytorch._C as core
 
@@ -63,14 +64,23 @@
     torch.nn.EmbeddingBag,
     torch.nn.Flatten,
     torch.nn.LSTM,
+    # dynamic quantization module
+    nnqd.Linear,
+    nnqd.LSTM,
 ])
 
 may_inplace_module = set([
     torch.nn.ReLU,
 ])
 
-binary_related_ops = (
+
+a_related_to_b = (
     (str(torch.add), str(torch.Tensor.add)),
+    (str(torch.Tensor.add), str(torch.add)),
+    (str(torch.nn.Linear), str(nnqd.Linear)),
+    (str(nnqd.Linear), str(torch.nn.Linear)),
+    (str(torch.nn.LSTM), str(nnqd.LSTM)),
+    (str(nnqd.LSTM), str(torch.nn.LSTM)),
 )
 
 conv_linear_ops = [
@@ -123,7 +133,7 @@ def ops_are_related(
     if type_is_module:
         cur_op = type(cur_op)
     return str(cur_op) == expected_op_type or \
-        (str(cur_op), expected_op_type) in binary_related_ops
+        (str(cur_op), expected_op_type) in a_related_to_b
 
 def _raise_obs_not_found_error(func):
     raise RuntimeError(
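A short illustration (not part of the commit) of why the symmetric Linear/LSTM pairs are added: after the module swap, the live child is an `nnqd.Linear` while the op type recorded at prepare time is still `torch.nn.Linear`, and the membership test below is what lets `ops_are_related` accept that pairing.

```python
import torch
import torch.nn.quantized.dynamic as nnqd

# mirror of the pairs added to a_related_to_b
a_related_to_b = (
    (str(torch.nn.Linear), str(nnqd.Linear)),
    (str(nnqd.Linear), str(torch.nn.Linear)),
    (str(torch.nn.LSTM), str(nnqd.LSTM)),
    (str(nnqd.LSTM), str(torch.nn.LSTM)),
)

cur_op = nnqd.Linear(3, 3)               # module seen at run time, after swapping
expected_op_type = str(torch.nn.Linear)  # op type string recorded during prepare
cur_op = type(cur_op)                    # same normalization ops_are_related applies
related = str(cur_op) == expected_op_type or (str(cur_op), expected_op_type) in a_related_to_b
print(related)  # True
```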

intel_extension_for_pytorch/ao/quantization/_quantize_utils.py

Lines changed: 3 additions & 2 deletions
@@ -10,7 +10,7 @@
     sync_pool_input_output_scale_zp, module_call_to_function_call, quantized_modules_has_weights, load_qconf_summary_to_model
 from ._quantization_state import AutoQuantizationState, AutoQuantizationStateModuleDict, init_model_quant_state
 from ._recipe import get_defaut_recipe
-
+from ._module_swap_utils import swap_child_modules
 
 # AutoQuantizationState lives in parent module's _modules.
 # Currently, `torch.nn.Sequential`'s forward iterates over all
@@ -540,7 +540,8 @@ def unwrap_proxy(a):
     for _, v in module._fqn_to_auto_quant_state_map.items():
         v.tensor_id_to_observer.clear()
         v.weight_tensor_id_to_observer.clear()
-    # Attach quan_info to parent each module
+    # Attach quant_info to each parent module
     attach_op_convert_info_to_model(module)
+    swap_child_modules(module)
     module.__class__ = QuantizationDispatchModule
     return module
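Because `swap_child_modules` now runs at the end of convert, the dynamically quantized children should be visible on the converted module. A hedged sketch, reusing `prepared_model` and `dynamic_qconfig` from the README example above; the attribute access is illustrative and assumes the eligible child is named `linear`:

```python
import torch.nn.quantized.dynamic as nnqd

convert_model = convert(prepared_model)   # prepared with dynamic_qconfig as above
# the eager nn.Linear child should have been replaced by its dynamic counterpart
print(type(convert_model.linear))                      # expected: torch.nn.quantized.dynamic Linear
print(isinstance(convert_model.linear, nnqd.Linear))   # expected: True
```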

intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp

Lines changed: 4 additions & 2 deletions
@@ -455,7 +455,9 @@ bool checkQuantization(Block* block) {
 
     if (node->kind() == Symbol::aten("quantize_per_tensor") ||
         node->kind() == Symbol::aten("dequantize") ||
-        node->kind() == Symbol::aten("quantize_per_channel")) {
+        node->kind() == Symbol::aten("quantize_per_channel") ||
+        node->kind() == Symbol::aten("quantized_lstm") ||
+        node->kind() == Symbol::fromQualString("quantized::linear_dynamic")) {
       return true;
     }
   }
@@ -476,11 +478,11 @@ void FusionPass(std::shared_ptr<Graph>& graph) {
   // remove BailOut and BailoutTemplate
   RemoveBailOutNodesAndSpecializeTypes(graph->block());
   RemoveBailoutTemplateNodes(graph->block());
-
   // LLGA fusion pass for int8
   GRAPH_DUMP(
       "After RemoveProfileNodesAndSpecializeTypes. Before LLGA fusion pass",
       graph);
+
   if (isQuantized(graph) || torch_ipex::autocast::is_llga_fp32_bf16_enabled()) {
     RemoveRedundantAliases(graph);
     fuser::onednn::fuseGraph(graph);
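The practical effect is that a graph whose only quantized ops are the dynamic ones still takes the LLGA fusion path. An illustrative way to spot those ops from Python, reusing `traced_model` from the README example above (the op strings are the ones `checkQuantization` now matches):

```python
graph_str = str(traced_model.graph)
# after convert + trace + freeze with the dynamic qconfig, at least one of the
# ops that checkQuantization() now recognizes should appear in the graph
print("quantized::linear_dynamic" in graph_str or "aten::quantized_lstm" in graph_str)  # expected: True
```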

tests/cpu/test_ao_jit_ipex_quantization.py

Lines changed: 39 additions & 2 deletions
@@ -15,7 +15,8 @@
 from torch.testing._internal.common_utils import TEST_SCIPY, TemporaryFileName
 
 import intel_extension_for_pytorch as ipex
-from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, HistogramObserver, QConfig
+from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, HistogramObserver, \
+    QConfig, PlaceholderObserver
 
 default_weight_observer = PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
 
@@ -34,6 +35,9 @@
         weight = default_weight_observer),
 ]
 
+dynamic_qconfig = QConfig(
+    activation = PlaceholderObserver.with_args(dtype=torch.float, compute_dtype=torch.quint8),
+    weight = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))
 
 class TestIpexOps(JitLlgaTestCase):
     def test_adaptive_avg_pool2d(self):
@@ -304,6 +308,39 @@ def forward(self, x):
         graph, _, _ = self.prepareModel(m, [x])
         FileCheck().check_not("aten::mul_").check("aten::mul").run(graph)
 
+class TestDynamicQuantization(JitLlgaTestCase):
+    def test_linear_dynamic(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.linear = torch.nn.Linear(3, 3)
+
+            def forward(self, x):
+                x = self.linear(x)
+                return x
+
+        m = M().eval()
+        x = torch.randn(1, 3)
+        graph = self.checkQuantizeTrace(m, [x], atol=2e-1, qconfig=dynamic_qconfig)
+        FileCheck().check_not("aten::linear").check("quantized::linear_dynamic").run(graph)
+
+    def test_lstm_dynamic(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.lstm = torch.nn.LSTM(10, 20, 2)
+
+            def forward(self, x, hx, cx):
+                x, h_xs = self.lstm(x, (hx, cx))
+                return x, h_xs
+
+        m = M().eval()
+        x = torch.randn(5, 3, 10)
+        h = torch.randn(2, 3, 20)
+        c = torch.randn(2, 3, 20)
+        graph = self.checkQuantizeTrace(m, [x, h, c], atol=2e-1, qconfig=dynamic_qconfig)
+        FileCheck().check_not("aten::lstm").check("aten::quantized_lstm").run(graph)
+
 
 if __name__ == '__main__':
-    run_tests()
+    run_tests()
