Skip to content

Commit 046f7df

Browse files
authored
Convert observer input to float data type if it is not (#1529) (#1564)
1 parent fed42b1 commit 046f7df

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

intel_extension_for_pytorch/quantization/_quantization_state.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,10 @@ def _maybe_observe(arg, tensor_info):
231231
# TODO: do not run this twice on input and output
232232
if str(tensor_id) in self.tensor_id_to_observer:
233233
observer = self.tensor_id_to_observer[str(tensor_id)]
234+
if isinstance(arg, torch.Tensor) and arg.dtype != torch.float32:
235+
dtype = arg.dtype
236+
out = observer(arg.float())
237+
return out.to(dtype)
234238
return observer(arg)
235239
else:
236240
return arg
@@ -290,7 +294,7 @@ def _observer_output(output, tensor_info):
290294
tensor_id = tensor_info.id
291295
if str(tensor_id) in self.tensor_id_to_observer:
292296
obs = self.tensor_id_to_observer[str(tensor_id)]
293-
obs(output)
297+
obs(output.float())
294298
if isinstance(outputs, torch.Tensor):
295299
tensor_info = seen_q_op_info.output_tensor_infos[0]
296300
_observer_output(outputs, tensor_info)

0 commit comments

Comments
 (0)