
Commit 2b898a9

add auto channels_last conversion in prepare (#1366) (#1367)

* add auto channels_last conversion in prepare
* modify UT

1 parent d211115
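
For context, "auto channels_last" means converting the weights of ConvNd modules from the default contiguous (NCHW) layout to a channels-last (NHWC/NDHWC) layout, which oneDNN convolution kernels generally prefer. Below is a minimal sketch of what a helper like `_convert_convNd_weight_memory_format` might do; the body is an assumption for illustration, not the actual IPEX implementation:

```python
import torch

def convert_convNd_weight_sketch(model: torch.nn.Module) -> None:
    """Hypothetical stand-in for _convert_convNd_weight_memory_format."""
    for m in model.modules():
        if isinstance(m, torch.nn.Conv2d):
            # 4-D weight: move to NHWC-style strides
            m.weight.data = m.weight.data.to(memory_format=torch.channels_last)
        elif isinstance(m, torch.nn.Conv3d):
            # 5-D weight: move to NDHWC-style strides
            m.weight.data = m.weight.data.to(memory_format=torch.channels_last_3d)
```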

File tree: 2 files changed (+41 lines, −0 lines)


intel_extension_for_pytorch/quantization/_quantize.py (4 additions, 0 deletions)
```diff
@@ -30,6 +30,10 @@ def prepare(
         torch.nn.Module
     """
     assert isinstance(model, torch.nn.Module), "Only support nn.Module prepare for quantization path"
+    # auto model channels_last memory format conversion
+    from ..frontend import auto_channels_last, _convert_convNd_weight_memory_format
+    if auto_channels_last:
+        _convert_convNd_weight_memory_format(model)
     try:
         prepare_model = optimization.fuse(model, inplace=inplace)
         prepare_model = linear_bn_fuse(prepare_model, inplace=inplace)
```
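
With this change, `prepare()` honors the same auto channels_last switch exercised by the test below, which toggles it via `ipex.enable_auto_channels_last()` and `ipex.disable_auto_channels_last()`. A minimal usage sketch (the toy model and input shape are illustrative, not from the commit):

```python
import torch
import intel_extension_for_pytorch as ipex

# Toy FP32 model and calibration input, for illustration only.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3)).eval()
x = torch.randn(1, 3, 224, 224)

ipex.enable_auto_channels_last()  # opt in; ipex.disable_auto_channels_last() opts out
qconfig = ipex.quantization.default_static_qconfig
# prepare() now also converts ConvNd weights to channels_last before calibration
prepared_model = ipex.quantization.prepare(model, qconfig, x)
```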

tests/cpu/test_auto_channels_last.py (37 additions, 0 deletions)
```diff
@@ -191,6 +191,43 @@ def test_auto_channels_last_resnet50(self):
         model_ipex_channels_last_modules = self.get_channels_last_modules(model_ipex)

         self.assertEqual(model_channels_last, model_ipex_channels_last_modules)
+
+    def test_auto_channels_last_for_int8(self):
+        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        class ConvNd(torch.nn.Module):
+            def __init__(self, dim, in_channels, out_channels, kernel_size, stride):
+                super(ConvNd, self).__init__()
+                self.conv = conv_module[dim](in_channels, out_channels, kernel_size=kernel_size, stride=stride)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        def _test_conv(dim):
+            input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
+            x_shape = (2, 3) + input_shapes[dim]
+            x = torch.randn(x_shape, dtype=torch.float32)
+            model = ConvNd(dim, 3, 4, 3, 2).eval()
+            qconfig = ipex.quantization.default_static_qconfig
+            prepared_model = ipex.quantization.prepare(model, qconfig, x)
+            # do calibration
+            y = prepared_model(x)
+            convert_model = ipex.quantization.convert(prepared_model)
+            with torch.no_grad():
+                traced_model = torch.jit.trace(convert_model, x)
+                traced_model = torch.jit.freeze(traced_model)
+                for _ in range(3):
+                    y = traced_model(x)
+            return y
+
+        # disable auto channels_last
+        ipex.disable_auto_channels_last()
+        self.assertTrue(_test_conv(2).is_contiguous(memory_format=torch.contiguous_format))
+        self.assertTrue(_test_conv(3).is_contiguous(memory_format=torch.contiguous_format))
+
+        # enable auto channels_last
+        ipex.enable_auto_channels_last()
+        self.assertTrue(_test_conv(2).is_contiguous(memory_format=torch.channels_last))
+        self.assertTrue(_test_conv(3).is_contiguous(memory_format=torch.channels_last_3d))

 if __name__ == '__main__':
     test = unittest.main()
```
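
The assertions hinge on `Tensor.is_contiguous(memory_format=...)`, which inspects strides rather than values, so it cleanly distinguishes the default NCHW layout from channels_last. A standalone illustration (not part of the commit):

```python
import torch

t = torch.randn(2, 3, 4, 4)  # default contiguous (NCHW) layout
print(t.is_contiguous(memory_format=torch.contiguous_format))     # True
nhwc = t.to(memory_format=torch.channels_last)  # same values, NHWC strides
print(nhwc.is_contiguous(memory_format=torch.channels_last))      # True
print(nhwc.is_contiguous(memory_format=torch.contiguous_format))  # False
```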
