@@ -191,6 +191,43 @@ def test_auto_channels_last_resnet50(self):
191
191
model_ipex_channels_last_modules = self .get_channels_last_modules (model_ipex )
192
192
193
193
self .assertEqual (model_channels_last , model_ipex_channels_last_modules )
194
+
195
+ def test_auto_channels_last_for_int8 (self ):
196
+ conv_module = {1 : torch .nn .Conv1d , 2 : torch .nn .Conv2d , 3 : torch .nn .Conv3d }
197
+ class ConvNd (torch .nn .Module ):
198
+ def __init__ (self , dim , in_channels , out_channels , kernel_size , stride ):
199
+ super (ConvNd , self ).__init__ ()
200
+ self .conv = conv_module [dim ](in_channels , out_channels , kernel_size = kernel_size , stride = stride )
201
+
202
+ def forward (self , x ):
203
+ return self .conv (x )
204
+
205
+ def _test_conv (dim ):
206
+ input_shapes = {1 : (224 ,), 2 : (224 , 224 ), 3 : (55 , 55 , 55 )}
207
+ x_shape = (2 , 3 ) + input_shapes [dim ]
208
+ x = torch .randn (x_shape , dtype = torch .float32 )
209
+ model = ConvNd (dim , 3 , 4 , 3 , 2 ).eval ()
210
+ qconfig = ipex .quantization .default_static_qconfig
211
+ prepared_model = ipex .quantization .prepare (model , qconfig , x )
212
+ # do calibration
213
+ y = prepared_model (x )
214
+ convert_model = ipex .quantization .convert (prepared_model )
215
+ with torch .no_grad ():
216
+ traced_model = torch .jit .trace (convert_model , x )
217
+ traced_model = torch .jit .freeze (traced_model )
218
+ for _ in range (3 ):
219
+ y = traced_model (x )
220
+ return y
221
+
222
+ # disable auto channels_last
223
+ ipex .disable_auto_channels_last ()
224
+ self .assertTrue (_test_conv (2 ).is_contiguous (memory_format = torch .contiguous_format ))
225
+ self .assertTrue (_test_conv (3 ).is_contiguous (memory_format = torch .contiguous_format ))
226
+
227
+ # enable auto channels_last
228
+ ipex .enable_auto_channels_last ()
229
+ self .assertTrue (_test_conv (2 ).is_contiguous (memory_format = torch .channels_last ))
230
+ self .assertTrue (_test_conv (3 ).is_contiguous (memory_format = torch .channels_last_3d ))
194
231
195
232
if __name__ == '__main__' :
196
233
test = unittest .main ()
0 commit comments