@@ -108,7 +108,6 @@ def insert_observer(
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    # put observer instance activation_post_process map
-    assert activation_post_process_map is not None
    activation_post_process_map[node.name].append(observer_name)
    # initialize index map for activation_post_process
    if node.name not in activation_post_process_indexes:
@@ -154,7 +153,7 @@ def maybe_insert_observer_for_special_module(
        observed_standalone_module = \
            prepare(standalone_module, sm_qconfig_dict, sm_prepare_config_dict)
        standalone_module_input_idxs = observed_standalone_module.\
-            _standalone_module_input_quantized_idxs.int().tolist()
+            _standalone_module_input_quantized_idxs.int().tolist()  # type: ignore
        observed_standalone_module = ObservedStandaloneGraphModule(
            observed_standalone_module, observed_standalone_module.graph)
        parent_name, name = _parent_name(node.target)
@@ -210,15 +209,14 @@ def insert_observer_for_output_of_the_node(
            inserted_observer = True
        elif (isinstance(quantize_handler,
                         FixedQParamsOpQuantizeHandler) and
-              not model.training) or \
-                isinstance(quantize_handler, CopyNodeQuantizeHandler):
+              not model.training):
            # inserting observers for output of observed module, or
            # mark the output as observed
            assert node.op in [
                'call_module',
                'call_function',
                'call_method'], \
-                'CopyNodeQuantizeHandler of type ' + node.op + ' is not handled'
+                'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'

            def is_observed(input_arg):
                if isinstance(input_arg, Node):
@@ -327,6 +325,80 @@ def insert_observer_for_input_arg_of_observed_node(
                activation_post_process_indexes,
                env, observed_graph, load_arg, observed_node_names_set, quants)

+def handle_copy_nodes(
+        observed_graph: Graph, matches: Dict[str, MatchResult],
+        quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]],
+        qconfig_map: Dict[str, QConfigAny],
+        activation_post_process_map: Dict[str, List[str]],
+        modules: Dict[str, torch.nn.Module]):
+    # map from node name to whether it is observed or not
+    observed_nodes: Set[Node] = set()
+    copy_nodes: Set[Node] = set()
+    non_tensor_input_binary_op_nodes: Set[Node] = set()
+    app_to_remove: Set[Node] = set()
+    env: Dict[Any, Any] = {}
+
+    def load_arg(a: Argument) -> Argument:
+        return map_arg(a, lambda node: env[node.name])
+
+    def in_nodes(a: Argument, nodes: Set[Node]) -> bool:
+        if isinstance(a, Node):
+            return a in nodes
+        elif isinstance(a, list) or isinstance(a, tuple):
+            return all([in_nodes(arg, nodes) for arg in a])
+        return False
+
+    result_graph = Graph()
+    cache_for_no_tensor_check: Dict[Node, bool] = dict()
+    for node in observed_graph.nodes:
+        root_node, matched_nodes, pattern, quantize_handler, qconfig = matches.get(
+            node.name, (None, None, None, None, None))
+
+        if node.op == "call_module" and is_activation_post_process(modules[node.target]):
+            # rule 1: if the input of a copy node is observed, we won't need to
+            # insert observer for the output of copy node
+            if in_nodes(node.args[0], copy_nodes) and in_nodes(node.args[0], observed_nodes):
+                # we'll remove the activation_post_process if the previous node is
+                # an observed copy node
+                app_to_remove.add(node)
+
+            # rule 2: if the previous node is a binary op without tensor input, we can remove the observer
+            if in_nodes(node.args[0], non_tensor_input_binary_op_nodes):
+                app_to_remove.add(node)
+            observed_nodes.add(node)
+
+        if root_node is node and qconfig is not None:
+            if isinstance(quantize_handler, CopyNodeQuantizeHandler):
+                copy_nodes.add(node)
+                # if previous node is observed, the copy node will be observed as well
+                if in_nodes(node.args[0], observed_nodes):
+                    observed_nodes.add(node)
+            if all_node_args_have_no_tensors(node, modules, cache_for_no_tensor_check):
+                non_tensor_input_binary_op_nodes.add(node)
+
+        # rule 3: for special node, we'll just remove observer for its input
+        special_nodes = [
+            ("call_function", operator.getitem),
+        ]
+        if (node.op, node.target) in special_nodes:
+            if in_nodes(node.args[0], observed_nodes):
+                prev_node = node.args[0].args[0]
+                if prev_node.name not in qconfig_map or qconfig_map[prev_node.name] is None:
+                    app_to_remove.add(node.args[0])
+                    # if the previous node is not quantized, remove node from copy nodes
+                    if node in copy_nodes:
+                        copy_nodes.remove(node)
+
+    for node in observed_graph.nodes:
+        if node.op == "output":
+            result_graph.output(map_arg(node.args[0], load_arg))
+        elif node in app_to_remove:
+            env[node.name] = env[node.args[0].name]
+        else:
+            env[node.name] = result_graph.node_copy(node, load_arg)
+
+    return result_graph
+

# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
    torch.nn.functional.conv1d: [1],
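Note on the rules above: a "copy node" here is an op such as flatten or operator.getitem that forwards its input values unchanged, so it inherits the observed/quantized state of its input and a second observer on its output adds nothing. A minimal torch.fx sketch of a graph containing such a node (illustrative only, not part of this patch):

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x):
            x = torch.nn.functional.relu(x)
            # flatten only reshapes: its output holds the same values as its
            # input, so observing the relu output once covers both nodes
            return torch.flatten(x, 1)

    # printing the traced graph shows relu followed by the flatten "copy" node
    print(torch.fx.symbolic_trace(M()).graph)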
@@ -376,16 +448,15 @@ class Quantizer:
    def __init__(self):
        # mapping from matched node to full qualified path of activation_post_process
        # must be filled before convert
-        self.activation_post_process_map: Optional[
-            Dict[str, List[str]]] = None
+        self.activation_post_process_map: Dict[str, List[str]] = {}

        # mapping from matched node to the index of activation_post_process that we are
        # using currently
        self.activation_post_process_indexes: Dict[str, int] = {}

        # mapping from node name to qconfig that should be used for that node
        # filled out for a model during _generate_qconfig_map
-        self.qconfig_map: Optional[Dict[str, QConfigAny]] = None
+        self.qconfig_map: Dict[str, QConfigAny] = {}
        # mapping from fully qualified module name to module instance
        # for example,
        # {
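Typing these attributes as plain dicts initialized to empty containers, rather than Optional[...] = None, is what lets the `assert ... is not None` guards removed elsewhere in this diff be dropped. A small sketch of the difference for the type checker (illustrative, not from the patch):

    from typing import Dict, List, Optional

    class Before:
        def __init__(self) -> None:
            # Optional attribute: every use site needs a None check or an assert
            self.activation_post_process_map: Optional[Dict[str, List[str]]] = None

    class After:
        def __init__(self) -> None:
            # non-Optional attribute: callers may index and append directly
            self.activation_post_process_map: Dict[str, List[str]] = {}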
@@ -504,7 +575,7 @@ def _prepare(

        self.modules = dict(model.named_modules())

-        # map from node name to qconfig, used in _find_matches
+        # fill self.qconfig_map, a map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict, node_name_to_scope)

        # match the patterns that will get quantized
@@ -526,7 +597,7 @@ def _prepare(
        # have to be quantized, which requires measuring stats,
        # initialize an DefaultQuantizeHandler object for each
        quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]] = \
-            self._find_quants(model.graph, matches)
+            self._find_quants(model.graph, self.modules, matches)

        self.activation_post_process_map = defaultdict(list)
        env: Dict[Any, Any] = {}
@@ -619,6 +690,17 @@ def load_arg(a):
                    env,
                    observed_graph, load_arg)

+        self.modules = dict(model.named_modules())
+
+        # TODO: refactor this to a separate function
+        matches = self._find_matches(
+            observed_graph, self.modules, self.patterns, standalone_module_names,
+            standalone_module_classes, custom_module_classes)
+        quants = self._find_quants(observed_graph, self.modules, matches)
+
+        observed_graph = handle_copy_nodes(
+            observed_graph, matches, quants, self.qconfig_map,
+            self.activation_post_process_map, self.modules)

        self.save_state(model)
        model = ObservedGraphModule(model, observed_graph)
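The matching pass is repeated here because observer insertion has added new call_module nodes, so matches and quants computed on the original graph are stale; handle_copy_nodes then rebuilds the observed graph with the redundant observers dropped. Rough shape of the updated flow (illustrative outline, not the patch's exact code):

    # 1. walk the original graph and insert observers as before
    # 2. refresh self.modules, matches and quants against the observed graph,
    #    which now contains the observer call_module nodes
    # 3. rebuild the graph via handle_copy_nodes, skipping activation_post_process
    #    nodes that follow an observed copy node or a tensor-free binary op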
@@ -726,7 +808,7 @@ def _convert(self, model: GraphModule, is_reference: bool = False,
            custom_module_classes=custom_module_classes)

        quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]] = \
-            self._find_quants(model.graph, matches)
+            self._find_quants(model.graph, self.modules, matches)

        self.quantized_graph = Graph()
        env: Dict[str, Node] = {}
@@ -845,7 +927,9 @@ def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
            quantized = True

            # Need to get correct quantized/non-quantized state forn the output
-            # of CopyNodeQuantizeHandler
+            # of FixedQParamsQuantizeHandler
+            # TODO: we may want to try to remove the special case here
+            # as well
            if type(obj) in [
                    CopyNodeQuantizeHandler,
                    FixedQParamsOpQuantizeHandler
@@ -854,14 +938,14 @@ def is_output_quantized(node: Node, obj: QuantizeHandler) -> bool:
                    'call_module',
                    'call_function',
                    'call_method'], \
-                    'CopyNodeQuantizeHandler of type ' + node.op + ' is not handled'
+                    'FixedQParamsQuantizeHandler of type ' + node.op + ' is not handled'
                # TODO: need to extend this to consider all relevant args instead of just arg[0]
                quantized = node_arg_is_quantized(node.args[0])

            # the output is unquantized if the node is not a CopyNode
            # and activation is fp16 (since we will output fp32 currently for fp16
            # converter
-            if (not isinstance(obj, CopyNodeQuantizeHandler) and not activation_is_int8_quantized(qconfig)) or \
+            if not activation_is_int8_quantized(qconfig) or \
                    not input_output_observed(obj):
                quantized = False
            if node_return_type_is_int(node):
@@ -1155,14 +1239,14 @@ def record_match(pattern, node, matched):
            else:
                matched.append(node)

-        assert self.qconfig_map is not None
+        cache_for_no_tensor_check: Dict[Node, bool] = dict()
        for node in reversed(graph.nodes):
            if node.name not in match_map and node.name not in all_matched:
                for pattern, value in patterns.items():
                    if is_match(modules, node, pattern):
                        skip_this_match = False
                        if value is BinaryOpQuantizeHandler:
-                            use_copy_node = all_node_args_have_no_tensors(node)
+                            use_copy_node = all_node_args_have_no_tensors(node, modules, cache_for_no_tensor_check)
                            if use_copy_node:
                                # TODO(future PR): update the pattern to quantize
                                # handler logic to take this into account.
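all_node_args_have_no_tensors now takes the modules map and a per-call cache so the recursive walk over node args is memoized across the matching loop. A minimal sketch of that caching pattern (hypothetical helper, not the real implementation, which also resolves call_module targets through `modules`):

    from typing import Any, Dict

    def tensor_free(node: Any, cache: Dict[Any, bool]) -> bool:
        # memoize per node: the same arg is reachable from many user nodes
        if node in cache:
            return cache[node]
        args = getattr(node, "args", ())
        # hypothetical leaf check: treat anything without a shape as tensor-free
        result = all(tensor_free(a, cache) for a in args) if args else not hasattr(node, "shape")
        cache[node] = result
        return result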
@@ -1220,14 +1304,16 @@ def is_standalone_module(node_target):

        return match_map

-    def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
-                     ) -> Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]]:
+    def _find_quants(
+            self, graph: Graph, modules: Dict[str, torch.nn.Module],
+            matches: Dict[str, MatchResult]) -> Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]]:
        """
        Takes the nodes in the input graph and pending matches, and finds and
        returns the input and output nodes which need to be quantized.

        Inputs:
          - graph: an fx.Graph object
+          - modules: a dictionary from module path to module
          - matches: output of self._find_matches function

        Outputs a map of
@@ -1241,13 +1327,14 @@ def _find_quants(self, graph: Graph, matches: Dict[str, MatchResult],
            int8 and then float16
        """
        quants: Dict[str, List[Tuple[DefaultQuantizeHandler, Callable]]] = defaultdict(list)
+        cache_for_no_tensor_check: Dict[Node, bool] = dict()

        def visit(node, matched_pattern, qconfig):
            def visit_arg(arg):
                is_weight = node_arg_is_weight(node, arg)
                is_bias = node_arg_is_bias(node, arg)
                is_activation = not (is_weight or is_bias)
-                no_tensors = all_node_args_have_no_tensors(arg)
+                no_tensors = all_node_args_have_no_tensors(arg, modules, cache_for_no_tensor_check)
                # bias needs to be quantized if activation is fp16 and weight is fp16
                # this is the case for glow
                should_add_handler = qconfig is not None and (