@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import copy
 import logging
 
@@ -72,10 +73,13 @@ def __init__(self, dense_module):
         self.padding = dense_module.padding
         self.dilation = dense_module.dilation
         self.groups = dense_module.groups
+        self.padding_mode = dense_module.padding_mode
+        self._reversed_padding_repeated_twice = dense_module._reversed_padding_repeated_twice
         self.prepack_input_shape = dense_module.input_shape if hasattr(dense_module, "input_shape") else []
         self.weight_channels_last = dense_module.weight.is_contiguous(memory_format=torch.channels_last) \
             or dense_module.weight.is_contiguous(memory_format=torch.channels_last_3d)
         self.weight_size = dense_module.weight.size()
+        self._real_padding = self.padding if self.padding_mode == 'zeros' else tuple([0] * (len(self.weight_size) - 2))
 
         # TODO: ".clone()" will make weight shared by multiple module not shared anymore
         # related issues: https://github.com/intel-innersource/frameworks.ai.pytorch.ipex-cpu/issues/65
@@ -91,7 +95,7 @@ def __init__(self, dense_module):
         self.register_parameter('bias', None)
         # create conv op context
         self.ctx = torch.ops.ipex_prepack.convolution_prepack(
-            dense_module.weight, self.bias, self.stride, self.padding,
+            dense_module.weight, self.bias, self.stride, self._real_padding,
             self.dilation, self.groups,
             self.weight_channels_last, self.prepack_input_shape
         )
@@ -117,14 +121,32 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
         with torch.no_grad():
             loaded_weight, loaded_bias, fp32_loaded_weight, weight_trail = _load_from_state_dict_pre_hook(self, state_dict, prefix)
             loaded_ctx = torch.ops.ipex_prepack.convolution_prepack(
-                loaded_weight, loaded_bias, self.stride, self.padding,
+                loaded_weight, loaded_bias, self.stride, self._real_padding,
                 self.dilation, self.groups,
                 self.weight_channels_last, self.prepack_input_shape
             )
             _load_from_state_dict_post_hook(self, loaded_ctx, fp32_loaded_weight, weight_trail)
 
     def forward(self, x):
-        return torch.ops.torch_ipex.convolution_forward(x, self.weight, self.bias, self.ctx.get_data_handle(), self.weight_size, self.padding, self.stride, self.dilation)
+        if self.padding_mode != 'zeros':
+            return torch.ops.torch_ipex.convolution_forward(
+                F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                self.weight,
+                self.bias,
+                self.ctx.get_data_handle(),
+                self.weight_size,
+                self._real_padding,
+                self.stride,
+                self.dilation)
+        return torch.ops.torch_ipex.convolution_forward(
+            x,
+            self.weight,
+            self.bias,
+            self.ctx.get_data_handle(),
+            self.weight_size,
+            self._real_padding,
+            self.stride,
+            self.dilation)
 
 class _IPEXConv1d(_IPEXConvNd):
     def __init__(self, dense_module):
@@ -457,10 +479,13 @@ def record_input_shape_for_prepack(module, sample_input):
 
     def hook_function(self, input):
         # input for linear/conv/transpose conv received here will be Tuple[Tensor]
-        self.input_shape = input[0].shape
+        if type(self) in [torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d] and self.padding_mode != 'zeros':
+            self.input_shape = F.pad(input[0], self._reversed_padding_repeated_twice, mode=self.padding_mode).shape
+        else:
+            self.input_shape = input[0].shape
 
     def register_hook_function(module):
-        if type(module) in [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
+        if type(module) in [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose2d]:
             module.register_forward_pre_hook(hook_function)
 
     def register_hook_function_rec(module):
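
Why forward() above pairs an explicit F.pad with self._real_padding (all zeros) when padding_mode != 'zeros': this mirrors how PyTorch's own _ConvNd handles non-zero padding modes, padding the input first and then convolving with zero padding. Below is a minimal sketch with stock PyTorch only (no IPEX ops; conv, x and real_padding are names made up for this example) showing that the two formulations agree.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sketch only, not the IPEX code path.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode='reflect')
x = torch.randn(1, 3, 16, 16)

# Reference: let the module apply the reflect padding itself.
reference = conv(x)

# Same computation, split the way the patch splits it: explicit F.pad with the
# module's padding_mode, then a convolution whose own padding is all zeros
# (the analogue of self._real_padding above).
padded = F.pad(x, conv._reversed_padding_repeated_twice, mode=conv.padding_mode)
real_padding = (0,) * (conv.weight.dim() - 2)
manual = F.conv2d(padded, conv.weight, conv.bias, stride=conv.stride,
                  padding=real_padding, dilation=conv.dilation, groups=conv.groups)

assert torch.allclose(reference, manual, atol=1e-6)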
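
And a small sketch of the shape-recording change in record_input_shape_for_prepack: a forward pre-hook stores the shape the convolution kernel will actually receive, which for non-zero padding modes is the shape after the explicit pad. Stock PyTorch modules only; the model and tensor sizes are made up for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

def hook_function(self, input):
    # Record the shape the conv kernel will actually see: for non-zero
    # padding modes that is the shape after the explicit F.pad.
    if type(self) in [nn.Conv1d, nn.Conv2d, nn.Conv3d] and self.padding_mode != 'zeros':
        self.input_shape = F.pad(input[0], self._reversed_padding_repeated_twice,
                                 mode=self.padding_mode).shape
    else:
        self.input_shape = input[0].shape

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=2, padding_mode='circular'))
for m in model.modules():
    if type(m) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d]:
        m.register_forward_pre_hook(hook_function)

model(torch.randn(1, 3, 16, 16))
print(model[0].input_shape)  # torch.Size([1, 3, 20, 20]): the padded shape is recorded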