16
16
conv_gemm_ops = [str (F .conv2d ), str (nn .Conv2d ), str (F .conv3d ), str (nn .Conv3d ), str (torch .conv2d ), str (torch .conv3d ), \
17
17
str (F .conv_transpose2d ), str (torch .nn .ConvTranspose2d ), str (F .conv_transpose3d ), str (torch .nn .ConvTranspose3d ),
18
18
str (torch .conv_transpose2d ), str (torch .conv_transpose2d ), str (F .linear ), str (nn .Linear ), str (torch .matmul ), str (torch .Tensor .matmul )]
19
+ conv_ops = [str (F .conv2d ), str (nn .Conv2d ), str (F .conv3d ), str (nn .Conv3d ), str (torch .conv2d ), str (torch .conv3d ), \
20
+ str (F .conv_transpose2d ), str (torch .nn .ConvTranspose2d ), str (F .conv_transpose3d ), str (torch .nn .ConvTranspose3d ),
21
+ str (torch .conv_transpose2d ), str (torch .conv_transpose2d )]
19
22
rnn_ops = [str (torch .nn .LSTM )]
20
23
21
24
# These ops only support the s8->s8 path, and also require that the qscheme is per_tensor_symmetric.
@@ -233,6 +236,7 @@ def reset_input_inf_dtype_to_orig_dtype(node, input_idx):
233
236
node .input_tensor_force_inf_dtype [input_idx ] = node .input_tensor_infos [input_idx ].inf_dtype
234
237
235
238
conv_gemm_node = _find_fused_node_with_cur_add (node , conv_gemm_ops )
239
+ conv_node = _find_fused_node_with_cur_add (node , conv_ops )
236
240
if conv_gemm_node is None :
237
241
# If pre_nodes don't have gemm node, need to check whether have quantizable node before it,
238
242
# if it doesn't have a quantizable node before it, we will not insert fake quant before add.
@@ -255,13 +259,17 @@ def reset_input_inf_dtype_to_orig_dtype(node, input_idx):
255
259
if node .input_tensor_infos [0 ] is not None and node .input_tensor_infos [0 ] in conv_gemm_node .output_tensor_infos :
256
260
node .input_tensor_infos [0 ].inf_dtype = node .input_tensor_infos [0 ].orig_dtype
257
261
node .input_tensor_force_inf_dtype [0 ] = node .input_tensor_infos [0 ].inf_dtype
258
- # set another input's dtype, if another's input is from non-quantizable op, we can remove the fake quant.
259
- reset_input_inf_dtype_to_orig_dtype (node , 1 )
262
+ # TODO: set another input's dtype for conv nodes when oneDNN is ready.
263
+ if conv_node is None :
264
+ # set another input's dtype, if another's input is from non-quantizable op, we can remove the fake quant.
265
+ reset_input_inf_dtype_to_orig_dtype (node , 1 )
260
266
elif node .input_tensor_infos [1 ] is not None and node .input_tensor_infos [1 ] in conv_gemm_node .output_tensor_infos :
261
267
node .input_tensor_infos [1 ].inf_dtype = node .input_tensor_infos [1 ].orig_dtype
262
268
node .input_tensor_force_inf_dtype [1 ] = node .input_tensor_infos [1 ].inf_dtype
263
- # set another input's dtype, if another's input is from non-quantizable op, we can remove the fake quant.
264
- reset_input_inf_dtype_to_orig_dtype (node , 0 )
269
+ # TODO: set another input's dtype for conv nodes when oneDNN is ready.
270
+ if conv_node is None :
271
+ # set another input's dtype, if another's input is from non-quantizable op, we can remove the fake quant.
272
+ reset_input_inf_dtype_to_orig_dtype (node , 0 )
265
273
266
274
# get a default recipe
267
275
def get_default_recipe (nodes ):
0 commit comments