@@ -116,7 +116,8 @@ at::Tensor convolution_kernel(
     at::IntArrayRef padding,
     at::IntArrayRef dilation,
     int64_t groups,
-    const ideep::attr_t& attr) {
+    const ideep::attr_t& attr,
+    at::MemoryFormat memory_format) {
   // Base convolution kernel. This base kernel will not change the input's
   // format, so make sure you have processed the input's format before
   // calling this function; the output will have the same format as the input.
@@ -132,9 +133,8 @@ at::Tensor convolution_kernel(
 
   at::Tensor output;
   if (input.dim() != 3) {
-    output = at::empty(
-        output_sizes,
-        input.options().memory_format(input.suggest_memory_format()));
+    output =
+        at::empty(output_sizes, input.options().memory_format(memory_format));
   } else {
     // This is a temporary workaround before channels last 1D is formally
     // supported in PyTorch. We will force to return nwc output.
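Note: convolution_kernel now honors an explicitly passed memory_format instead of re-deriving one from input.suggest_memory_format(). As a rough sketch of how a caller can resolve that format, mirroring the logic this patch adds to convolution_forward_meta further down (the helper name resolve_conv_memory_format is hypothetical, not part of the patch):

#include <ATen/ATen.h>

// Hypothetical helper, not in the patch: resolve the memory format a caller
// would hand to convolution_kernel(), following the same rules as the
// convolution_forward_meta logic added below.
at::MemoryFormat resolve_conv_memory_format(
    const at::Tensor& input,
    bool weight_channels_last) {
  bool use_channels_last =
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d ||
      weight_channels_last;
  auto memory_format = at::MemoryFormat::Contiguous;
  if (use_channels_last) {
    if (input.dim() == 4) {
      memory_format = at::MemoryFormat::ChannelsLast; // NHWC
    } else if (input.dim() == 5) {
      memory_format = at::MemoryFormat::ChannelsLast3d; // NDHWC
    }
  }
  return memory_format;
}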
@@ -164,7 +164,8 @@ at::Tensor convolution_forward_impl(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
 #if defined(IPEX_DISP_OP)
   printf("torch_ipex::convolution_forward_impl\n");
 #endif
@@ -385,7 +386,8 @@ at::Tensor IPEXConvolutionOp::_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   at::AutoDispatchBelowADInplaceOrView g;
   RECORD_FUNCTION(
       "IPEXConvolutionOp::_forward", c10::ArrayRef<c10::IValue>({}));
@@ -401,7 +403,8 @@ at::Tensor IPEXConvolutionOp::_forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 at::Tensor IPEXConvolutionOp::forward(
@@ -413,7 +416,8 @@ at::Tensor IPEXConvolutionOp::forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   RECORD_FUNCTION("IPEXConvolutionOp::forward", c10::ArrayRef<c10::IValue>({}));
 
   at::AutoDispatchBelowADInplaceOrView g;
@@ -432,7 +436,8 @@ at::Tensor IPEXConvolutionOp::forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 torch::autograd::variable_list IPEXConvolutionOp::backward(
@@ -463,6 +468,7 @@ torch::autograd::variable_list IPEXConvolutionOp::backward(
       at::Tensor(),
       at::Tensor(),
       at::Tensor(),
+      at::Tensor(),
       at::Tensor()};
 }
 
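Note: backward() of a torch::autograd::Function must return exactly one entry per argument of forward(), so the new non-differentiable weight_channels_last argument gets an extra undefined at::Tensor() in the gradient list above. A minimal self-contained illustration of that pattern (ScaleOp is illustrative only, not the IPEX operator):

#include <torch/torch.h>

// Illustrative only: forward() takes one tensor and one non-tensor flag, so
// backward() must return two entries -- a real gradient for the tensor and
// an undefined Tensor for the flag's slot.
struct ScaleOp : public torch::autograd::Function<ScaleOp> {
  static at::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const at::Tensor& x,
      bool double_it) {
    ctx->saved_data["double_it"] = double_it;
    return double_it ? x * 2 : x.clone();
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    bool double_it = ctx->saved_data["double_it"].toBool();
    auto gx = double_it ? grad_outputs[0] * 2 : grad_outputs[0];
    // One slot per forward input: gradient w.r.t. x, undefined for the flag.
    return {gx, at::Tensor()};
  }
};

// Usage: auto y = ScaleOp::apply(x, /*double_it=*/true);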
@@ -474,7 +480,8 @@ at::Tensor convolution_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   if (at::GradMode::is_enabled()) {
     return IPEXConvolutionOp::apply(
         input,
@@ -484,7 +491,8 @@ at::Tensor convolution_forward(
         kernel_size,
         padding,
         stride,
-        dilation);
+        dilation,
+        weight_channels_last);
   }
   return IPEXConvolutionOp::_forward(
       input,
@@ -494,7 +502,8 @@ at::Tensor convolution_forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 at::Tensor convolution_forward_meta(
@@ -505,11 +514,12 @@ at::Tensor convolution_forward_meta(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   TORCH_CHECK(
       kernel_size.has_value() && padding.has_value() && stride.has_value() &&
-          dilation.has_value(),
-      "kernel_size, padding, stride and dilation must have value for convolution_forward_meta");
+          dilation.has_value() && weight_channels_last.has_value(),
+      "kernel_size, padding, stride, dilation and weight_channels_last must have value for convolution_forward_meta");
   auto input_size = input.sym_sizes();
   c10::SymDimVector output_sizes = calc_conv_output_size(
       input_size,
@@ -518,6 +528,28 @@ at::Tensor convolution_forward_meta(
       stride.value(),
       dilation.value());
   auto output = at::empty_symint(output_sizes, input.options());
+
+  bool use_channels_last =
+      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
+      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d ||
+      weight_channels_last.value();
+
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (use_channels_last) {
+    if (input.dim() == 4) {
+      memory_format = at::MemoryFormat::ChannelsLast;
+    } else if (input.dim() == 5) {
+      memory_format = at::MemoryFormat::ChannelsLast3d;
+    }
+  }
+
+  if (!is_channels_last_1d(output)) {
+    output = output.contiguous(memory_format);
+    if (input.dim() == 3) {
+      output = to_channels_last_1d(output);
+    }
+  }
+
   return output;
 }
 
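Note: the meta implementation never touches real storage; it only has to report the sizes and strides the real kernel would produce, which is why the memory-format fix-up added above is needed at all. A small sketch of what contiguous(memory_format) does to a meta tensor's metadata (the shape here is arbitrary, chosen for illustration):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Meta tensors carry sizes/strides but no storage; a meta kernel only has
  // to get this metadata right.
  auto t = at::empty({2, 3, 4, 4}, at::TensorOptions().device(at::kMeta));
  t = t.contiguous(at::MemoryFormat::ChannelsLast);
  std::cout << t.strides() << std::endl; // [48, 1, 12, 3] -- NHWC strides
  return 0;
}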
@@ -535,7 +567,8 @@ at::Tensor convolution_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torch_ipex::convolution_forward", "")
@@ -551,7 +584,8 @@ at::Tensor convolution_forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 } // namespace autocast
@@ -562,7 +596,7 @@ namespace {
 TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
   m.def(
       "convolution_forward(Tensor input, Tensor weight, Tensor? bias, "
-      "Tensor W_prepack, int[]? kernel_size, int[]? padding, int[]? stride, int[]? dilation) -> Tensor");
+      "Tensor W_prepack, int[]? kernel_size, int[]? padding, int[]? stride, int[]? dilation, bool? weight_channels_last) -> Tensor");
   m.impl(
       "convolution_forward",
       c10::DispatchKey::Autograd,
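Note: once the registered schema gains bool? weight_channels_last, every dispatcher-based call site has to pass the new argument as well, as the autocast wrapper above does. A sketch of such a call; the typed<> signature below is an assumption that must mirror the registered schema exactly, and the real file's functional types may differ in const/ref qualifiers:

#include <torch/library.h>

// Sketch under assumptions: call the updated torch_ipex::convolution_forward
// schema through the dispatcher, following the findSchemaOrThrow/typed
// pattern used in the autocast wrapper above.
at::Tensor call_convolution_forward(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias,
    const at::Tensor& w_prepack,
    c10::optional<at::IntArrayRef> kernel_size,
    c10::optional<at::IntArrayRef> padding,
    c10::optional<at::IntArrayRef> stride,
    c10::optional<at::IntArrayRef> dilation,
    c10::optional<bool> weight_channels_last) {
  static auto op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("torch_ipex::convolution_forward", "")
          .typed<at::Tensor(
              const at::Tensor&,
              const at::Tensor&,
              const c10::optional<at::Tensor>&,
              const at::Tensor&,
              c10::optional<at::IntArrayRef>,
              c10::optional<at::IntArrayRef>,
              c10::optional<at::IntArrayRef>,
              c10::optional<at::IntArrayRef>,
              c10::optional<bool>)>();
  return op.call(
      input, weight, bias, w_prepack, kernel_size, padding, stride, dilation,
      weight_channels_last);
}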