Skip to content

Commit 02449cc

Browse files
fix ipex conv issue when padding mode is not zero (#1580)
1 parent 9dd123a commit 02449cc

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

intel_extension_for_pytorch/nn/utils/_weight_prepack.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import torch.nn as nn
3+
import torch.nn.functional as F
34
import copy
45
import logging
56

@@ -72,10 +73,13 @@ def __init__(self, dense_module):
7273
self.padding = dense_module.padding
7374
self.dilation = dense_module.dilation
7475
self.groups = dense_module.groups
76+
self.padding_mode = dense_module.padding_mode
77+
self._reversed_padding_repeated_twice = dense_module._reversed_padding_repeated_twice
7578
self.prepack_input_shape = dense_module.input_shape if hasattr(dense_module, "input_shape") else []
7679
self.weight_channels_last = dense_module.weight.is_contiguous(memory_format=torch.channels_last) \
7780
or dense_module.weight.is_contiguous(memory_format=torch.channels_last_3d)
7881
self.weight_size = dense_module.weight.size()
82+
self._real_padding = self.padding if self.padding_mode == 'zeros' else tuple([0] * (len(self.weight_size) - 2 ))
7983

8084
# TODO: ".clone()" will make weight shared by multiple module not shared anymore
8185
# related issues: https://github.com/intel-innersource/frameworks.ai.pytorch.ipex-cpu/issues/65
@@ -91,7 +95,7 @@ def __init__(self, dense_module):
9195
self.register_parameter('bias', None)
9296
# create conv op context
9397
self.ctx = torch.ops.ipex_prepack.convolution_prepack(
94-
dense_module.weight, self.bias, self.stride, self.padding,
98+
dense_module.weight, self.bias, self.stride, self._real_padding,
9599
self.dilation, self.groups,
96100
self.weight_channels_last, self.prepack_input_shape
97101
)
@@ -117,14 +121,32 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
117121
with torch.no_grad():
118122
loaded_weight, loaded_bias, fp32_loaded_weight, weight_trail = _load_from_state_dict_pre_hook(self, state_dict, prefix)
119123
loaded_ctx = torch.ops.ipex_prepack.convolution_prepack(
120-
loaded_weight, loaded_bias, self.stride, self.padding,
124+
loaded_weight, loaded_bias, self.stride, self._real_padding,
121125
self.dilation, self.groups,
122126
self.weight_channels_last, self.prepack_input_shape
123127
)
124128
_load_from_state_dict_post_hook(self, loaded_ctx, fp32_loaded_weight, weight_trail)
125129

126130
def forward(self, x):
127-
return torch.ops.torch_ipex.convolution_forward(x, self.weight, self.bias, self.ctx.get_data_handle(), self.weight_size, self.padding, self.stride, self.dilation)
131+
if self.padding_mode != 'zeros':
132+
return torch.ops.torch_ipex.convolution_forward(
133+
F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode),
134+
self.weight,
135+
self.bias,
136+
self.ctx.get_data_handle(),
137+
self.weight_size,
138+
self._real_padding,
139+
self.stride,
140+
self.dilation)
141+
return torch.ops.torch_ipex.convolution_forward(
142+
x,
143+
self.weight,
144+
self.bias,
145+
self.ctx.get_data_handle(),
146+
self.weight_size,
147+
self._real_padding,
148+
self.stride,
149+
self.dilation)
128150

129151
class _IPEXConv1d(_IPEXConvNd):
130152
def __init__(self, dense_module):
@@ -457,10 +479,13 @@ def record_input_shape_for_prepack(module, sample_input):
457479

458480
def hook_function(self, input):
459481
# input for linear/conv/transpose conv received here will be Tuple[Tensor]
460-
self.input_shape = input[0].shape
482+
if self in [torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d] and self.padding_mode != 'zeros':
483+
self.input_shape = F.pad(input[0], self._reversed_padding_repeated_twice, mode=self.padding_mode).shape
484+
else:
485+
self.input_shape = input[0].shape
461486

462487
def register_hook_function(module):
463-
if type(module) in [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
488+
if type(module) in [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose2d]:
464489
module.register_forward_pre_hook(hook_function)
465490

466491
def register_hook_function_rec(module):

tests/cpu/test_weight_prepack.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,33 @@ def get_rand_seed():
3737
class TestPrepackCases(TestCase):
3838
def _test_convolution_inference_base(self, dim):
3939
class ConvNd(torch.nn.Module):
40-
def __init__(self, dim, in_channels, out_channels, kernel_size, stride, padding, dilation, bias, groups):
40+
def __init__(self, dim, in_channels, out_channels, kernel_size, stride, padding, dilation, bias, groups, padding_mode):
4141
super(ConvNd, self).__init__()
42-
self.conv = conv_module[dim](in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=groups)
42+
self.conv = conv_module[dim](
43+
in_channels,
44+
out_channels,
45+
kernel_size=kernel_size,
46+
stride=stride,
47+
padding=padding,
48+
dilation=dilation,
49+
bias=bias,
50+
groups=groups,
51+
padding_mode=padding_mode)
4352

4453
def forward(self, x):
4554
return self.conv(x)
4655
input_shapes = {1: (224,), 2: (224, 224), 3: (55, 55, 55)}
56+
padding_modes = ['zeros', 'reflect']
4757
if dim == 2:
4858
channels_last = torch.channels_last
4959
elif dim == 3:
5060
channels_last = torch.channels_last_3d
5161
if dim == 1:
52-
options = itertools.product([True, False], [1, 2], [1, 4], [True, False], [torch.contiguous_format])
62+
options = itertools.product([True, False], [1, 2], [1, 4], [True, False], [torch.contiguous_format], padding_modes)
5363
else:
54-
options = itertools.product([True, False], [1, 2], [1, 4], [True, False], [torch.contiguous_format, channels_last])
64+
options = itertools.product([True, False], [1, 2], [1, 4], [True, False], [torch.contiguous_format, channels_last], padding_modes)
5565

56-
for bias, dilation, groups, feed_sample_input, memory_format in options:
66+
for bias, dilation, groups, feed_sample_input, memory_format, padding_mode in options:
5767
N = torch.randint(1, 10, (1,)).item()
5868
M = torch.randint(1, 3, (1,)).item() * groups
5969
C = torch.randint(1, 3, (1,)).item() * groups
@@ -68,7 +78,8 @@ def forward(self, x):
6878
padding=1,
6979
dilation=dilation,
7080
bias=bias,
71-
groups=groups).float().eval()
81+
groups=groups,
82+
padding_mode=padding_mode).float().eval()
7283
model = model.to(memory_format=memory_format)
7384
x = x.to(memory_format=memory_format)
7485
if dim == 1:

0 commit comments

Comments
 (0)