Commit d8723df
fix output strides in conv/deconv meta backend (#1508) (#1590)
* fix output strides in conv/deconv meta backend
* add UT
* add UT
* fix UT
* fix clang format
* fix backward
* fix UT
1 parent 023c104 commit d8723df
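
Background on the title: the "output strides" are the per-dimension step sizes that distinguish a contiguous (NCHW) output from a channels-last (NHWC) one, and the meta backend computes sizes and strides without any real storage, so whatever layout it reports is what torch.compile-style tracing assumes the real CPU kernel produces. A minimal standalone ATen sketch (toy shapes, not from this patch) of a meta tensor carrying channels-last strides:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A meta tensor has sizes, strides and a dtype, but no data behind it.
  at::Tensor meta_out = at::empty(
      {1, 4, 8, 8},
      at::TensorOptions().device(at::kMeta).memory_format(
          at::MemoryFormat::ChannelsLast));
  std::cout << meta_out.strides() << "\n"; // [256, 1, 32, 4], channels-last
  // The contiguous layout for the same sizes would be [256, 64, 8, 1];
  // reporting the wrong one is the mismatch this commit fixes.
  return 0;
}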

File tree

7 files changed: +301 -53 lines changed

csrc/cpu/aten/Conv.cpp

Lines changed: 52 additions & 18 deletions
@@ -116,7 +116,8 @@ at::Tensor convolution_kernel(
     at::IntArrayRef padding,
     at::IntArrayRef dilation,
     int64_t groups,
-    const ideep::attr_t& attr) {
+    const ideep::attr_t& attr,
+    at::MemoryFormat memory_format) {
   // Base convolution kernel. This base kernel will not change the input's
   // format, so make sure the input's format has been processed before calling
   // this function; the output will have the same format as the input.
@@ -132,9 +133,8 @@ at::Tensor convolution_kernel(
 
   at::Tensor output;
   if (input.dim() != 3) {
-    output = at::empty(
-        output_sizes,
-        input.options().memory_format(input.suggest_memory_format()));
+    output =
+        at::empty(output_sizes, input.options().memory_format(memory_format));
   } else {
     // This is a temporary workaround before channels-last 1D is formally
     // supported in PyTorch. We will force the output to be nwc.
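
Why convolution_kernel now takes an explicit memory_format: suggest_memory_format() consults only the input tensor, so a contiguous input used to force a contiguous output even when the prepacked weight was channels-last and oneDNN would actually produce NHWC. A minimal ATen sketch (toy shapes, plain at::empty, nothing IPEX-specific) of the difference between the two allocations:

#include <ATen/ATen.h>

int main() {
  at::Tensor input = at::randn({1, 3, 8, 8}); // contiguous NCHW input
  // Old behavior: derive the output format from the input alone.
  auto fmt_from_input = input.suggest_memory_format(); // Contiguous here
  // New behavior: the caller passes the format it actually wants, e.g.
  // channels-last because the prepacked weight is channels-last.
  auto fmt_from_caller = at::MemoryFormat::ChannelsLast;
  at::Tensor out_old =
      at::empty({1, 4, 8, 8}, input.options().memory_format(fmt_from_input));
  at::Tensor out_new =
      at::empty({1, 4, 8, 8}, input.options().memory_format(fmt_from_caller));
  // Same sizes, different strides: [256, 64, 8, 1] vs [256, 1, 32, 4].
  TORCH_CHECK(!out_old.strides().equals(out_new.strides()));
  return 0;
}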
@@ -164,7 +164,8 @@ at::Tensor convolution_forward_impl(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
 #if defined(IPEX_DISP_OP)
   printf("torch_ipex::convolution_forward_impl\n");
 #endif
@@ -385,7 +386,8 @@ at::Tensor IPEXConvolutionOp::_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   at::AutoDispatchBelowADInplaceOrView g;
   RECORD_FUNCTION(
       "IPEXConvolutionOp::_forward", c10::ArrayRef<c10::IValue>({}));
@@ -401,7 +403,8 @@ at::Tensor IPEXConvolutionOp::_forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 at::Tensor IPEXConvolutionOp::forward(
@@ -413,7 +416,8 @@ at::Tensor IPEXConvolutionOp::forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   RECORD_FUNCTION("IPEXConvolutionOp::forward", c10::ArrayRef<c10::IValue>({}));
 
   at::AutoDispatchBelowADInplaceOrView g;
@@ -432,7 +436,8 @@ at::Tensor IPEXConvolutionOp::forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 torch::autograd::variable_list IPEXConvolutionOp::backward(
@@ -463,6 +468,7 @@ torch::autograd::variable_list IPEXConvolutionOp::backward(
       at::Tensor(),
       at::Tensor(),
       at::Tensor(),
+      at::Tensor(),
       at::Tensor()};
 }
 
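The extra at::Tensor() in backward is arity bookkeeping: a torch::autograd::Function must return exactly one gradient slot per forward() argument (excluding the context), and non-differentiable arguments such as the new weight_channels_last flag get an undefined tensor. A toy sketch of the pattern (ScaleOp is hypothetical, not from this patch):

#include <torch/torch.h>

struct ScaleOp : public torch::autograd::Function<ScaleOp> {
  static at::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& x,
      c10::optional<bool> flag) { // non-tensor, non-differentiable argument
    return x * 2.0;
  }
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    // One slot for x, plus an undefined placeholder for the bool flag.
    return {grad_outputs[0] * 2.0, at::Tensor()};
  }
};

// Usage: at::Tensor y = ScaleOp::apply(at::randn({2, 2}).requires_grad_(), true);
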
@@ -474,7 +480,8 @@ at::Tensor convolution_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   if (at::GradMode::is_enabled()) {
     return IPEXConvolutionOp::apply(
         input,
@@ -484,7 +491,8 @@ at::Tensor convolution_forward(
         kernel_size,
         padding,
         stride,
-        dilation);
+        dilation,
+        weight_channels_last);
   }
   return IPEXConvolutionOp::_forward(
       input,
@@ -494,7 +502,8 @@ at::Tensor convolution_forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 at::Tensor convolution_forward_meta(
@@ -505,11 +514,12 @@ at::Tensor convolution_forward_meta(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   TORCH_CHECK(
       kernel_size.has_value() && padding.has_value() && stride.has_value() &&
-          dilation.has_value(),
-      "kernel_size, padding, stride and dilation must have value for convolution_forward_meta");
+          dilation.has_value() && weight_channels_last.has_value(),
+      "kernel_size, padding, stride, dilation and weight_channels_last must have value for convolution_forward_meta");
   auto input_size = input.sym_sizes();
   c10::SymDimVector output_sizes = calc_conv_output_size(
       input_size,
@@ -518,6 +528,28 @@ at::Tensor convolution_forward_meta(
       stride.value(),
       dilation.value());
   auto output = at::empty_symint(output_sizes, input.options());
+
+  bool use_channels_last =
+      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
+      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d ||
+      weight_channels_last.value();
+
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (use_channels_last) {
+    if (input.dim() == 4) {
+      memory_format = at::MemoryFormat::ChannelsLast;
+    } else if (input.dim() == 5) {
+      memory_format = at::MemoryFormat::ChannelsLast3d;
+    }
+  }
+
+  if (!is_channels_last_1d(output)) {
+    output = output.contiguous(memory_format);
+    if (input.dim() == 3) {
+      output = to_channels_last_1d(output);
+    }
+  }
+
   return output;
 }
 
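The is_channels_last_1d / to_channels_last_1d branch exists because PyTorch has no MemoryFormat::ChannelsLast1d, so a 3-D (N, C, W) output cannot be tagged channels-last through contiguous(). A sketch of how an nwc layout can be emulated with stock ATen (to_nwc_like is a made-up name for illustration):

#include <ATen/ATen.h>

// Materialize the data of an (N, C, W) tensor in N, W, C order and view it
// back: the result has sizes (N, C, W) but strides (W*C, 1, C), i.e.
// channels-last 1D (nwc).
at::Tensor to_nwc_like(const at::Tensor& t) {
  return t.permute({0, 2, 1}).contiguous().permute({0, 2, 1});
}
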
@@ -535,7 +567,8 @@ at::Tensor convolution_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torch_ipex::convolution_forward", "")
@@ -551,7 +584,8 @@ at::Tensor convolution_forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 } // namespace autocast
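
For reference, the autocast wrapper above does not call the CPU function directly; it redispatches through the dispatcher so the remaining dispatch keys still run. A minimal sketch of that lookup-and-call pattern (toy op myops::toy_forward, whose registration is sketched after the next hunk; c10::Dispatcher is the same singleton the code above reaches via torch::Dispatcher):

#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor call_toy_forward(
    const at::Tensor& input,
    c10::optional<bool> weight_channels_last) {
  // Look the op up once by schema name, assert its full C++ signature with
  // typed<>(), then call through the dispatcher.
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("myops::toy_forward", "")
          .typed<at::Tensor(const at::Tensor&, c10::optional<bool>)>();
  return op.call(input, weight_channels_last);
}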
@@ -562,7 +596,7 @@ namespace {
 TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
   m.def(
       "convolution_forward(Tensor input, Tensor weight, Tensor? bias, "
-      "Tensor W_prepack, int[]? kernel_size, int[]? padding, int[]? stride, int[]? dilation) -> Tensor");
+      "Tensor W_prepack, int[]? kernel_size, int[]? padding, int[]? stride, int[]? dilation, bool? weight_channels_last) -> Tensor");
   m.impl(
       "convolution_forward",
       c10::DispatchKey::Autograd,
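
In dispatcher schema strings, bool? surfaces as c10::optional<bool> on the C++ side and may arrive unset, which is why convolution_forward_meta now checks weight_channels_last.has_value(). A toy registration sketch (namespace myops, matching the dispatcher sketch above; the real code uses TORCH_LIBRARY_FRAGMENT on torch_ipex):

#include <ATen/ATen.h>
#include <torch/library.h>

// Toy op with an optional bool in its schema. value_or(false) supplies a
// default when callers leave it unset; assumes a 4-D NCHW input so that
// ChannelsLast is a legal memory format.
at::Tensor toy_forward(
    const at::Tensor& input,
    c10::optional<bool> weight_channels_last) {
  bool wcl = weight_channels_last.value_or(false);
  return input.contiguous(
      wcl ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous);
}

TORCH_LIBRARY(myops, m) {
  m.def("toy_forward(Tensor input, bool? weight_channels_last) -> Tensor");
  m.impl("toy_forward", c10::DispatchKey::CPU, toy_forward);
}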

csrc/cpu/aten/Conv.h

Lines changed: 10 additions & 5 deletions
@@ -28,7 +28,8 @@ at::Tensor convolution_kernel(
     at::IntArrayRef padding,
     at::IntArrayRef dilation,
     int64_t groups,
-    const ideep::attr_t& attr);
+    const ideep::attr_t& attr,
+    at::MemoryFormat memory_format);
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor> convolution_backward_kernel(
     const at::Tensor& input,
@@ -70,7 +71,8 @@ class IPEXConvolutionOp : public torch::autograd::Function<IPEXConvolutionOp> {
       c10::optional<at::IntArrayRef> kernel_size,
       c10::optional<at::IntArrayRef> padding,
       c10::optional<at::IntArrayRef> stride,
-      c10::optional<at::IntArrayRef> dilation);
+      c10::optional<at::IntArrayRef> dilation,
+      c10::optional<bool> weight_channels_last);
 
   static at::Tensor forward(
       torch::autograd::AutogradContext* ctx,
@@ -81,7 +83,8 @@ class IPEXConvolutionOp : public torch::autograd::Function<IPEXConvolutionOp> {
       c10::optional<at::IntArrayRef> kernel_size,
       c10::optional<at::IntArrayRef> padding,
       c10::optional<at::IntArrayRef> stride,
-      c10::optional<at::IntArrayRef> dilation);
+      c10::optional<at::IntArrayRef> dilation,
+      c10::optional<bool> weight_channels_last);
 
   static torch::autograd::variable_list backward(
       torch::autograd::AutogradContext* ctx,
@@ -96,7 +99,8 @@ at::Tensor convolution_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation);
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last);
 
 at::Tensor convolution_forward_meta(
     const at::Tensor& input,
@@ -106,7 +110,8 @@ at::Tensor convolution_forward_meta(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation);
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last);
 
 } // namespace cpu
 } // namespace torch_ipex
