File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed
intel_extension_for_pytorch/quantization Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -231,6 +231,10 @@ def _maybe_observe(arg, tensor_info):
231
231
# TODO: do not run this twice on input and output
232
232
if str (tensor_id ) in self .tensor_id_to_observer :
233
233
observer = self .tensor_id_to_observer [str (tensor_id )]
234
+ if isinstance (arg , torch .Tensor ) and arg .dtype != torch .float32 :
235
+ dtype = arg .dtype
236
+ out = observer (arg .float ())
237
+ return out .to (dtype )
234
238
return observer (arg )
235
239
else :
236
240
return arg
@@ -290,7 +294,7 @@ def _observer_output(output, tensor_info):
290
294
tensor_id = tensor_info .id
291
295
if str (tensor_id ) in self .tensor_id_to_observer :
292
296
obs = self .tensor_id_to_observer [str (tensor_id )]
293
- obs (output )
297
+ obs (output . float () )
294
298
if isinstance (outputs , torch .Tensor ):
295
299
tensor_info = seen_q_op_info .output_tensor_infos [0 ]
296
300
_observer_output (outputs , tensor_info )
You can’t perform that action at this time.
0 commit comments