
Commit 4d449f9

jerryzh168 authored and facebook-github-bot committed
[quant][graphmode][fx] Separate handling Copy operator to a helper function (#54644) (#55429)
Summary: Pull Request resolved: #55429
Previously we special-cased the copy operator in the normal insert-observer code path; this PR splits that special-case logic into a separate helper function and keeps the rest of the code clean.
Test Plan: Imported from OSS
Reviewed By: vkuzo
Differential Revision: D27609972
fbshipit-source-id: 378f6aa70f18c0b477b62b6efe236648748aae7e
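For context, a "copy" operator in this pass is an op such as reshape or contiguous that passes its input through unchanged, so its output can inherit the input's observed/quantized state rather than receiving its own observer. A minimal sketch of the kind of model this affects, using the public FX quantization entry points (illustrative only, not part of the diff):

    import torch
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            x = self.linear(x)
            # reshape is a copy node: it changes neither values nor dtype,
            # so the pass can propagate the observed state through it instead
            # of inserting a fresh observer (and later an extra quant/dequant)
            return x.reshape(1, -1)

    m = prepare_fx(M().eval(), {"": default_qconfig})
    m(torch.randn(1, 4))  # calibrate
    m = convert_fx(m)

The updated test expectation around line 3643 below reflects exactly this: the dequantize count after a reshape drops from 3 to 2 once the copy node stops getting its own quant/dequant pair.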
1 parent 4248696 commit 4d449f9

File tree

4 files changed: +157 −58 lines

test/quantization/test_quantize_fx.py

Lines changed: 6 additions & 10 deletions
@@ -645,7 +645,7 @@ def forward(self, x):
 
         dict_input = {"input": torch.randn(1, 1, 1, 1)}
         m = M().eval()
-        qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
+        qconfig_dict = {"": default_qconfig}
         m = prepare_fx(m, qconfig_dict)
         m(dict_input)
         m = convert_fx(m)
@@ -2296,15 +2296,15 @@ def forward(self, x):
         model = FuncLinear(use_bias, has_relu, f_relu)
         linear_fun = ns.call_function(torch.nn.functional.linear)
         prepare_node_occurrence = {
-            # activation, weight, bias, output
-            ns.call_module(torch.quantization.PlaceholderObserver): 4 if use_bias else 3
+            # activation, weight, bias and output
+            ns.call_module(torch.quantization.PlaceholderObserver): 3 + int(use_bias)
         }
         convert_node_occurrence = {
             # we don't support static fp16 ops, so the linear functino
             # is unfused
             linear_fun: 1,
-            # activation, weight, bias, output
-            ns.call_method("to"): 4 if use_bias else 3
+            # activation, weight, bias and output
+            ns.call_method("to"): 3 + int(use_bias)
         }
         self.checkGraphModeFxOp(
             model, data, QuantType.DYNAMIC, linear_fun,
@@ -3643,12 +3643,8 @@ def forward(self, x):
         # make sure it runs
         m = convert_fx(m)
         expected_occurrence = {
-            # we have extra quant/dequant after reshape since currently we do not
-            # propagate the information about the dtype of the output
-            # of CopyNode, we may improve this later and remove the
-            # extra quant/dequant
             ns.call_function(torch.quantize_per_tensor): 2,
-            ns.call_method("dequantize"): 3,
+            ns.call_method("dequantize"): 2,
             ns.call_method("to"): 1,
             ns.call_function(torch.ops.quantized.linear): 2
         }
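As an aside, the first hunk above switches the test from a per-type qconfig to the global one. For reference, both qconfig_dict layouts accepted by prepare_fx look like this (a sketch using the stock default_qconfig):

    import torch
    from torch.quantization import default_qconfig

    # global key "": apply default_qconfig to every quantizable op
    qconfig_dict_global = {"": default_qconfig}

    # object_type entries: apply it only to matching module/function types
    qconfig_dict_by_type = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}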

torch/quantization/fx/quantization_patterns.py

Lines changed: 7 additions & 3 deletions
@@ -137,9 +137,10 @@ def __init__(self, quantizer: QuantizerCls, node: Node):
         # determine how many of the first two args are Tensors (versus scalars)
         # this distinguishes things like "x + y" from "x + 2" or "2 + x"
         self.num_tensor_args = 0
+        cache_for_no_tensor_check: Dict[Node, bool] = dict()
         for arg_idx in range(len(self.binary_op_node.args)):
             arg = self.binary_op_node.args[arg_idx]
-            if isinstance(arg, Node) and (not all_node_args_have_no_tensors(arg)):
+            if isinstance(arg, Node) and (not all_node_args_have_no_tensors(arg, quantizer.modules, cache_for_no_tensor_check)):
                 self.num_tensor_args += 1
         self.all_node_args_are_tensors = \
             (self.num_tensor_args == len(self.binary_op_node.args))
@@ -190,7 +191,10 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
         if self.num_tensor_args == 1:
             # add/mul scalar
             first_arg = self.binary_op_node.args[0]
-            if isinstance(first_arg, Node) and (not all_node_args_have_no_tensors(first_arg)):
+            cache_for_no_tensor_check: Dict[Node, bool] = dict()
+            if isinstance(first_arg, Node) and (
+                    not all_node_args_have_no_tensors(
+                        first_arg, quantizer.modules, cache_for_no_tensor_check)):
                 quantized_index = 0
             else:
                 quantized_index = 1
@@ -958,8 +962,8 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
 @register_quant_pattern(torch.squeeze)
 @register_quant_pattern(torch.stack)
 @register_quant_pattern(torch.unsqueeze)
-@register_quant_pattern(operator.getitem)
 @register_quant_pattern(operator.floordiv)
+@register_quant_pattern(operator.getitem)
 @register_quant_pattern('chunk')
 @register_quant_pattern('clamp')
 @register_quant_pattern('contiguous')
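Both hunks above thread a cache_for_no_tensor_check dict into all_node_args_have_no_tensors, memoizing the recursive walk so shared subgraphs are not re-traversed. A simplified sketch of that memoization pattern, assuming a per-node predicate check_leaf (the real helper also consults modules; memoized_no_tensor_check is a hypothetical name):

    from typing import Callable, Dict
    from torch.fx import Node

    def memoized_no_tensor_check(
            node: Node,
            check_leaf: Callable[[Node], bool],  # assumed per-node predicate
            cache: Dict[Node, bool]) -> bool:
        # reuse the answer if this node was already visited
        if node in cache:
            return cache[node]
        # a node is tensor-free if it passes the leaf check and so do
        # all of its Node arguments, recursively
        result = check_leaf(node) and all(
            memoized_no_tensor_check(arg, check_leaf, cache)
            for arg in node.args if isinstance(arg, Node))
        cache[node] = result
        return result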

torch/quantization/fx/quantize.py

Lines changed: 106 additions & 19 deletions
@@ -108,7 +108,6 @@ def insert_observer(
     observer_name = get_new_observer_name(model)
     setattr(model, observer_name, observer)
     # put observer instance activation_post_process map
-    assert activation_post_process_map is not None
     activation_post_process_map[node.name].append(observer_name)
     # initialize index map for activation_post_process
     if node.name not in activation_post_process_indexes:
@@ -154,7 +153,7 @@ def maybe_insert_observer_for_special_module(
         observed_standalone_module = \
             prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict)
         standalone_module_input_idxs = observed_standalone_module.\
-            _standalone_module_input_quantized_idxs.int().tolist()
+            _standalone_module_input_quantized_idxs.int().tolist()  # type: ignore
         observed_standalone_module = ObservedStandaloneGraphModule(
             observed_standalone_module, observed_standalone_module.graph)
         parent_name, name = _parent_name(node.target)
@@ -210,15 +209,14 @@ def insert_observer_for_output_of_the_node(
         inserted_observer = True
     elif (isinstance(quantize_handler,
                      FixedQParamsOpQuantizeHandler) and
-          not model.training) or \
-            isinstance(quantize_handler, CopyNodeQuantizeHandler):
+          not model.training):
         # inserting observers for output of observed module, or
         # mark the output as observed
         assert node.op in [
             'call_module',
             'call_function',
             'call_method'], \
-            'CopyNodeQuantizeHandler of type ' + node.op + ' is not handled'
+            'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
 
         def is_observed(input_arg):
             if isinstance(input_arg, Node):
@@ -327,6 +325,80 @@ def insert_observer_for_input_arg_of_observed_node(
         activation_post_process_indexes,
         env, observed_graph, load_arg, observed_node_names_set, quants)
 
+def handle_copy_nodes(
+        observed_graph: Graph, matches: Dict[str, MatchResult],
+        quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]],
+        qconfig_map: Dict[str, QConfigAny],
+        activation_post_process_map: Dict[str, List[str]],
+        modules: Dict[str, torch.nn.Module]):
+    # map from node name to whether it is observed or not
+    observed_nodes: Set[Node] = set()
+    copy_nodes: Set[Node] = set()
+    non_tensor_input_binary_op_nodes: Set[Node] = set()
+    app_to_remove: Set[Node] = set()
+    env: Dict[Any, Any] = {}
+
+    def load_arg(a: Argument) -> Argument:
+        return map_arg(a, lambda node: env[node.name])
+
+    def in_nodes(a: Argument, nodes: Set[Node]) -> bool:
+        if isinstance(a, Node):
+            return a in nodes
+        elif isinstance(a, list) or isinstance(a, tuple):
+            return all([in_nodes(arg, nodes) for arg in a])
+        return False
+
+    result_graph = Graph()
+    cache_for_no_tensor_check: Dict[Node, bool] = dict()
+    for node in observed_graph.nodes:
+        root_node, matched_nodes, pattern, quantize_handler, qconfig = matches.get(
+            node.name, (None, None, None, None, None))
+
+        if node.op == "call_module" and is_activation_post_process(modules[node.target]):
+            # rule 1: if the input of a copy node is observed, we won't need to
+            # insert observer for the output of copy node
+            if in_nodes(node.args[0], copy_nodes) and in_nodes(node.args[0], observed_nodes):
+                # we'll remove the activation_post_process if the previous node is
+                # an observed copy node
+                app_to_remove.add(node)
+
+            # rule 2: if the previous node is a binary op without tensor input, we can remove the observer
+            if in_nodes(node.args[0], non_tensor_input_binary_op_nodes):
+                app_to_remove.add(node)
+            observed_nodes.add(node)
+
+        if root_node is node and qconfig is not None:
+            if isinstance(quantize_handler, CopyNodeQuantizeHandler):
+                copy_nodes.add(node)
+                # if previous node is observed, the copy node will be observed as well
+                if in_nodes(node.args[0], observed_nodes):
+                    observed_nodes.add(node)
+            if all_node_args_have_no_tensors(node, modules, cache_for_no_tensor_check):
+                non_tensor_input_binary_op_nodes.add(node)
+
+        # rule 3: for special node, we'll just remove observer for its input
+        special_nodes = [
+            ("call_function", operator.getitem),
+        ]
+        if (node.op, node.target) in special_nodes:
+            if in_nodes(node.args[0], observed_nodes):
+                prev_node = node.args[0].args[0]
+                if prev_node.name not in qconfig_map or qconfig_map[prev_node.name] is None:
+                    app_to_remove.add(node.args[0])
+                    # if the previous node is not quantized, remove node from copy nodes
+                    if node in copy_nodes:
+                        copy_nodes.remove(node)
+
+    for node in observed_graph.nodes:
+        if node.op == "output":
+            result_graph.output(map_arg(node.args[0], load_arg))
+        elif node in app_to_remove:
+            env[node.name] = env[node.args[0].name]
+        else:
+            env[node.name] = result_graph.node_copy(node, load_arg)
+
+    return result_graph
+
 # A dictionary for querying the weight index for a given op
 WEIGHT_INDEX_DICT = {
     torch.nn.functional.conv1d : [1],
@@ -376,16 +448,15 @@ class Quantizer:
     def __init__(self):
         # mapping from matched node to full qualified path of activation_post_process
         # must be filled before convert
-        self.activation_post_process_map: Optional[
-            Dict[str, List[str]]] = None
+        self.activation_post_process_map: Dict[str, List[str]] = {}
 
         # mapping from matched node to the index of activation_post_process that we are
         # using currently
         self.activation_post_process_indexes: Dict[str, int] = {}
 
         # mapping from node name to qconfig that should be used for that node
         # filled out for a model during _generate_qconfig_map
-        self.qconfig_map: Optional[Dict[str, QConfigAny]] = None
+        self.qconfig_map: Dict[str, QConfigAny] = {}
         # mapping from fully qualified module name to module instance
         # for example,
         # {
@@ -504,7 +575,7 @@ def _prepare(
 
         self.modules = dict(model.named_modules())
 
-        # map from node name to qconfig, used in _find_matches
+        # fill self.qconfig_map, a map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope)
 
         # match the patterns that will get quantized
@@ -526,7 +597,7 @@ def _prepare(
         # have to be quantized, which requires measuring stats,
         # initialize an DefaultQuantizeHandler object for each
         quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]] = \
-            self._find_quants(model.graph, matches)
+            self._find_quants(model.graph, self.modules, matches)
 
         self.activation_post_process_map = defaultdict(list)
         env: Dict[Any, Any] = {}
@@ -619,6 +690,17 @@ def load_arg(a):
                     env,
                     observed_graph, load_arg)
 
+        self.modules = dict(model.named_modules())
+
+        # TODO: refactor this to a separate function
+        matches = self._find_matches(
+            observed_graph, self.modules, self.patterns, standalone_module_names,
+            standalone_module_classes, custom_module_classes)
+        quants = self._find_quants(observed_graph, self.modules, matches)
+
+        observed_graph = handle_copy_nodes(
+            observed_graph, matches, quants, self.qconfig_map,
+            self.activation_post_process_map, self.modules)
 
         self.save_state(model)
         model = ObservedGraphModule(model, observed_graph)
@@ -726,7 +808,7 @@ def _convert(self, model: GraphModule, is_reference: bool = False,
             custom_module_classes=custom_module_classes)
 
         quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]] = \
-            self._find_quants(model.graph, matches)
+            self._find_quants(model.graph, self.modules, matches)
 
         self.quantized_graph = Graph()
         env: Dict[str, Node] = {}
@@ -845,7 +927,9 @@ def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
             quantized = True
 
         # Need to get correct quantized/non-quantized state forn the output
-        # of CopyNodeQuantizeHandler
+        # of FixedQParamsQuantizeHandler
+        # TODO: we may want to try to remove the special case here
+        # as well
         if type(obj) in [
             CopyNodeQuantizeHandler,
             FixedQParamsOpQuantizeHandler
@@ -854,14 +938,14 @@ def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
                 'call_module',
                 'call_function',
                 'call_method'], \
-                'CopyNodeQuantizeHandler of type ' + node.op + ' is not handled'
+                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
             # TODO: need to extend this to consider all relevant args instead of just arg[0]
             quantized = node_arg_is_quantized(node.args[0])
 
         # the output is unquantized if the node is not a CopyNode
         # and activation is fp16 (since we will output fp32 currently for fp16
         # converter
-        if (not isinstance(obj, CopyNodeQuantizeHandler) and not activation_is_int8_quantized(qconfig)) or \
+        if not activation_is_int8_quantized(qconfig) or \
                 not input_output_observed(obj):
             quantized = False
         if node_return_type_is_int(node):
@@ -1155,14 +1239,14 @@ def record_match(pattern, node, matched):
             else:
                 matched.append(node)
 
-        assert self.qconfig_map is not None
+        cache_for_no_tensor_check: Dict[Node, bool] = dict()
         for node in reversed(graph.nodes):
             if node.name not in match_map and node.name not in all_matched:
                 for pattern, value in patterns.items():
                     if is_match(modules, node, pattern):
                         skip_this_match = False
                         if value is BinaryOpQuantizeHandler:
-                            use_copy_node = all_node_args_have_no_tensors(node)
+                            use_copy_node = all_node_args_have_no_tensors(node, modules, cache_for_no_tensor_check)
                             if use_copy_node:
                                 # TODO(future PR): update the pattern to quantize
                                 # handler logic to take this into account.
@@ -1220,14 +1304,16 @@ def is_standalone_module(node_target):
 
         return match_map
 
-    def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
-                     ) -> Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]]:
+    def _find_quants(
+            self, graph: Graph, modules: Dict[str, torch.nn.Module],
+            matches: Dict[str, MatchResult]) -> Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]]:
         """
         Takes the nodes in the input graph and pending matches, and finds and
         returns the input and output nodes which need to be quantized.
 
         Inputs:
          - graph: an fx.Graph object
+          - modules: a dictionary from module path to module
          - matches: output of self._find_matches function
 
         Outputs a map of
@@ -1241,13 +1327,14 @@ def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
             int8 and then float16
         """
         quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]] = defaultdict(list)
+        cache_for_no_tensor_check: Dict[Node, bool] = dict()
 
         def visit(node, matched_pattern, qconfig):
             def visit_arg(arg):
                 is_weight = node_arg_is_weight(node, arg)
                 is_bias = node_arg_is_bias(node, arg)
                 is_activation = not (is_weight or is_bias)
-                no_tensors = all_node_args_have_no_tensors(arg)
+                no_tensors = all_node_args_have_no_tensors(arg, modules, cache_for_no_tensor_check)
                 # bias needs to be quantized if activation is fp16 and weight is fp16
                 # this is the case for glow
                 should_add_handler = qconfig is not None and (
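The new handle_copy_nodes above removes observers by rebuilding the graph with the standard torch.fx env/load_arg idiom: every surviving node is copied into a fresh Graph, while a removed node's name is aliased to the copy of its input, so all consumers are rewired past it automatically. A self-contained sketch of that idiom; rewrite_without and the should_drop predicate are hypothetical stand-ins for the pass's rules 1-3:

    import torch
    from torch.fx import Graph, GraphModule, symbolic_trace
    from torch.fx.node import map_arg

    def rewrite_without(gm: GraphModule, should_drop) -> GraphModule:
        env = {}  # name in the old graph -> corresponding node in the new graph
        new_graph = Graph()

        def load_arg(a):
            return map_arg(a, lambda n: env[n.name])

        for node in gm.graph.nodes:
            if node.op == "output":
                new_graph.output(load_arg(node.args[0]))
            elif should_drop(node):
                # alias the dropped node to (the copy of) its first input,
                # rewiring every consumer past it
                env[node.name] = env[node.args[0].name]
            else:
                env[node.name] = new_graph.node_copy(node, load_arg)
        return GraphModule(gm, new_graph)

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x).contiguous()

    m = rewrite_without(symbolic_trace(M()),
                        lambda n: n.op == "call_method" and n.target == "contiguous")
    print(m.graph)  # contiguous() is gone; relu feeds the output directly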
