
Commit 7076524

EikanWang, Wei-Lin-Intel, and jianan-gu authored
Enable Transpose-free BF16 MHA based on the mha_calc fusions (#992) (#1048)
* Enable Transpose-free BF16 MHA based on the mha_calc fusions (#992)
* Enable Transpose-free BF16 MHA based on the mha_calc fusions
* Enable transfree MHA & BF16 Matmul
* Add test cases for Transfree MHA and OutTransfree BF16 BMM
* Fall back to BF16 MHA based on the mha_calc fusions
* Merge cpu-device llga
* Combine the Matmul kernel with contiguous checks
* Revise the descriptions for some kernels

Co-authored-by: Wang Weihan <[email protected]>

* Remove aligned restriction using loadu/storeu and add UT (#938)

Co-authored-by: Wei Lin <[email protected]>
Co-authored-by: jianan-gu <[email protected]>
1 parent d866108 commit 7076524

15 files changed (+1110 lines, -180 lines)


intel_extension_for_pytorch/csrc/aten/cpu/AddLayerNorm.cpp

Lines changed: 1 addition & 4 deletions
@@ -48,12 +48,9 @@ at::Tensor dil_add_layernorm(
       break;
     }
   }
-  // Only support 64byte aligned
-  bool aligned_64_bytes = a.size(a.ndimension() - 1) % 16 == 0 &&
-      b.size(b.ndimension() - 1) % 16 == 0;
   // Only support contiguous tensor
   bool is_contiguous = a.is_contiguous() && b.is_contiguous();
-  if (no_broadcast && aligned_64_bytes && is_contiguous && alpha == 1.0f) {
+  if (no_broadcast && is_contiguous && alpha == 1.0f) {
     return AddLayerNorm(
         a, b, alpha, normalized_shape, weight_opt, bias_opt, eps);
   } else {

intel_extension_for_pytorch/csrc/aten/cpu/kernels/DivSoftmaxKrnl.cpp

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@ using namespace torch_ipex::cpu::kernel;
  * There are some assumptions for this operator.
  * - The reduce dimension for softmax is the last dimension
  * - The reduce dimension for softmax is the leading dimension
- * - The elements number of the reduce dimension for softmax is n*16
  * - The input tensors are contiguous
  * - The number of the input tensor dimension should be >=2
  * - The mask b can be expand_as a with the mask_reshape (bs :: seq_length),

intel_extension_for_pytorch/csrc/cpu/ideep/ideep/operators/matmul.hpp

Lines changed: 2 additions & 0 deletions
@@ -315,6 +315,8 @@ struct matmul_forward : public dnnl::matmul,

     dst_data_type = dst_type == data_type::undef ? dst_data_type : dst_type;
     tensor::desc dst_desc(dst_dims, dst_data_type, tag::any);
+    if (!dst.is_empty())
+      dst_desc = dst.get_desc().to_type(dst_data_type);
     auto key = utils::create_key(
         src_desc,
         weights_desc,
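
The two added lines make ideep reuse the layout of a caller-provided destination tensor instead of always letting oneDNN pick one via tag::any. The snippet below is a hedged, standalone oneDNN-only sketch of that trade-off; the helper name make_dst_desc and the row-major assumption are illustrative and not part of ideep or this commit.

// Sketch only: with a concrete tag the matmul primitive writes straight into
// the caller's buffer; with format_tag::any oneDNN may pick a blocked layout
// and a reorder into the user buffer would be needed afterwards.
#include <oneapi/dnnl/dnnl.hpp>

dnnl::memory::desc make_dst_desc(const dnnl::memory::dims& dims,  // 3-D dims assumed
                                 bool have_user_dst) {
  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;
  if (have_user_dst) {
    // e.g. a plain row-major (abc) destination buffer owned by the caller
    return dnnl::memory::desc(dims, dt::bf16, tag::abc);
  }
  // no user buffer: let the primitive choose its preferred layout
  return dnnl::memory::desc(dims, dt::bf16, tag::any);
}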

intel_extension_for_pytorch/csrc/cpu/vec/vec512/perf_kernel/add_layernorm.h

Lines changed: 24 additions & 23 deletions
@@ -21,32 +21,33 @@ std::pair<float, float> _add_and_compute_mean_var(
     float* out) {
   // compute add and mean/var of the value after add
   // we should firstly store add value
-  auto vec_a = _load_f32_data(a_ptr);
-  auto vec_b = _load_f32_data(b_ptr);
+  auto vec_a = _loadu(a_ptr);
+  auto vec_b = _loadu(b_ptr);
   auto vec_add = _mm512_add_ps(vec_a, vec_b);
   auto vec_acc_mean = vec_add;
   auto vec_acc_pow = _mm512_mul_ps(vec_add, vec_add);
-  _mm512_store_ps(out, vec_add);
+  _mm512_storeu_ps(out, vec_add);

   int i = 16;
   for (; i <= size - 16; i += 16) {
-    vec_a = _load_f32_data(a_ptr + i);
-    vec_b = _load_f32_data(b_ptr + i);
+    vec_a = _loadu(a_ptr + i);
+    vec_b = _loadu(b_ptr + i);
     vec_add = _mm512_add_ps(vec_a, vec_b);
     vec_acc_mean = _mm512_add_ps(vec_add, vec_acc_mean);
-    _mm512_store_ps(out + i, vec_add);
+    _mm512_storeu_ps(out + i, vec_add);
     vec_acc_pow = _mm512_fmadd_ps(vec_add, vec_add, vec_acc_pow);
   }

   if (i < size) {
     __mmask16 mask = (1 << (size - i)) - 1;
-    vec_a = _maskz_load_f32_data(a_ptr + i, mask);
-    vec_b = _maskz_load_f32_data(b_ptr + i, mask);
+    vec_a = _maskz_loadu(a_ptr + i, mask);
+    vec_b = _maskz_loadu(b_ptr + i, mask);
     vec_add = _mm512_add_ps(vec_a, vec_b);
     auto vec_zero = _mm512_set1_ps(0);
-    _mm512_mask_store_ps(out + i, mask, vec_add);
-    vec_acc_mean = _mm512_maskz_add_ps(mask, vec_add, vec_acc_mean);
-    vec_acc_pow = _mm512_maskz_fmadd_ps(mask, vec_add, vec_add, vec_acc_pow);
+
+    vec_acc_mean = _mm512_add_ps(vec_add, vec_acc_mean);
+    _mm512_mask_storeu_ps(out + i, mask, vec_add);
+    vec_acc_pow = _mm512_fmadd_ps(vec_add, vec_add, vec_acc_pow);
   }
   float mean_var = _mm512_reduce_add_ps(vec_acc_mean) / float(size);
   float var_val = _mm512_reduce_add_ps(vec_acc_pow);
@@ -68,35 +69,35 @@ void _normalize_kernel(
   auto vec_bias = _mm512_set1_ps(bias);
   int i = 0;
   for (; i <= size - 16; i += 16) {
-    auto vec_input = _load_f32_data(input_ptr + i);
+    auto vec_input = _loadu(input_ptr + i);
     auto vec_gamma = vec_one;
     auto vec_beta = vec_zero;
     if (gamma_ptr) {
-      vec_gamma = _load_f32_data(gamma_ptr + i);
+      vec_gamma = _loadu(gamma_ptr + i);
     }
     if (beta_ptr) {
-      vec_beta = _load_f32_data(beta_ptr + i);
+      vec_beta = _loadu(beta_ptr + i);
     }
     //(a_ptr[i] * scale + bias) * gamma + beta;
     auto vec_norm = _mm512_fmadd_ps(vec_input, vec_scale, vec_bias);
     auto vec_res = _mm512_fmadd_ps(vec_norm, vec_gamma, vec_beta);
-    _store_data(out_ptr + i, vec_res);
+    _storeu(out_ptr + i, vec_res);
   }
   if (i < size) {
     __mmask16 mask = (1 << (size - i)) - 1;
-    auto vec_input = _maskz_load_f32_data(input_ptr + i, mask);
+    auto vec_input = _maskz_loadu(input_ptr + i, mask);
     auto vec_gamma = vec_one;
     auto vec_beta = vec_zero;
-    if (!gamma_ptr) {
-      vec_gamma = _maskz_load_f32_data(gamma_ptr + i, mask);
+    if (gamma_ptr) {
+      vec_gamma = _maskz_loadu(gamma_ptr + i, mask);
     }
-    if (!beta_ptr) {
-      vec_beta = _maskz_load_f32_data(beta_ptr + i, mask);
+    if (beta_ptr) {
+      vec_beta = _maskz_loadu(beta_ptr + i, mask);
     }
     //(a_ptr[i] * scale + bias) * gamma + beta;
-    auto vec_norm = _mm512_maskz_fmadd_ps(mask, vec_input, vec_scale, vec_bias);
-    auto vec_res = _mm512_maskz_fmadd_ps(mask, vec_norm, vec_gamma, vec_beta);
-    _mask_store_data(out_ptr + i, vec_res, mask);
+    auto vec_norm = _mm512_fmadd_ps(vec_input, vec_scale, vec_bias);
+    auto vec_res = _mm512_fmadd_ps(vec_norm, vec_gamma, vec_beta);
+    _mask_storeu(out_ptr + i, vec_res, mask);
   }
 }
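
The change above swaps the aligned _mm512_store_ps / _load_f32_data helpers for unaligned _loadu / _storeu variants plus a masked tail, which is what allows the 64-byte-alignment and size % 16 assumptions to be dropped in AddLayerNorm.cpp and DivSoftmaxKrnl.cpp. A minimal standalone sketch of the same pattern with raw AVX-512F intrinsics (the function name and the plain "add" payload are illustrative, not IPEX code):

#include <immintrin.h>

// Element-wise add of two float buffers with no alignment or size % 16 requirement.
void add_f32(const float* a, const float* b, float* out, int size) {
  int i = 0;
  // Main loop: unaligned loads/stores are valid for any pointer, so callers
  // no longer need 64-byte aligned buffers.
  for (; i <= size - 16; i += 16) {
    __m512 va = _mm512_loadu_ps(a + i);
    __m512 vb = _mm512_loadu_ps(b + i);
    _mm512_storeu_ps(out + i, _mm512_add_ps(va, vb));
  }
  // Masked tail: handles the remaining size % 16 elements without reading or
  // writing past the end of the buffers.
  if (i < size) {
    __mmask16 tail = (1 << (size - i)) - 1;          // one bit per remaining lane
    __m512 va = _mm512_maskz_loadu_ps(tail, a + i);  // masked-out lanes read as 0
    __m512 vb = _mm512_maskz_loadu_ps(tail, b + i);
    _mm512_mask_storeu_ps(out + i, tail, _mm512_add_ps(va, vb));
  }
}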

intel_extension_for_pytorch/csrc/cpu/vec/vec512/perf_kernel/add_softmax.h

Lines changed: 8 additions & 8 deletions
@@ -98,7 +98,7 @@ inline void _dil_div_add_reduce_max_fusion_kernel(
     vec_b = _loadu(b + i);
     vec_out = _mm512_fmadd_ps(vec_a, vec_r_dim_per_head, vec_b);
     vec_ps_min = _mm512_max_ps(vec_ps_min, vec_out);
-    _mm512_store_ps(out + i, vec_out);
+    _mm512_storeu_ps(out + i, vec_out);
   }

   if (i < size) {
@@ -107,7 +107,7 @@ inline void _dil_div_add_reduce_max_fusion_kernel(
     vec_b = _maskz_loadu(b + i, mask);
     vec_out = _mm512_fmadd_ps(vec_a, vec_r_dim_per_head, vec_b);
     vec_ps_min = _mm512_mask_max_ps(vec_ps_min, mask, vec_out, vec_ps_min);
-    _mm512_mask_store_ps(out + i, mask, vec_out);
+    _mm512_mask_storeu_ps(out + i, mask, vec_out);
   }

   // NOTE: _mm512_reduce_max_ps is sequence instruction
@@ -134,22 +134,22 @@ inline void _dil_maskedfill_div_max_fusion_kernel(

   int i = 0;
   for (; i <= size - 16; i += 16) {
-    vec_a = _load_f32_data(a + i);
-    vec_b = _load_f32_data(b + i);
+    vec_a = _loadu(a + i);
+    vec_b = _loadu(b + i);
     __mmask16 fill_mask = _mm512_cmp_ps_mask(vec_b, mask_c, 12);
     vec_out = _mm512_mask_div_ps(vec_fill, fill_mask, vec_a, vec_dim_per_head);
     vec_ps_min = _mm512_max_ps(vec_ps_min, vec_out);
-    _mm512_store_ps(out + i, vec_out);
+    _mm512_storeu_ps(out + i, vec_out);
   }

   if (i < size) {
     __mmask16 mask = (1 << (size - i)) - 1;
-    vec_a = _maskz_load_f32_data(a + i, mask);
-    vec_b = _maskz_load_f32_data(b + i, mask);
+    vec_a = _maskz_loadu(a + i, mask);
+    vec_b = _maskz_loadu(b + i, mask);
     __mmask16 fill_mask = _mm512_cmp_ps_mask(vec_b, mask_c, 12);
     vec_out = _mm512_mask_div_ps(vec_fill, fill_mask, vec_a, vec_dim_per_head);
     vec_ps_min = _mm512_max_ps(vec_ps_min, vec_out);
-    _mm512_mask_store_ps(out + i, mask, vec_out);
+    _mm512_mask_storeu_ps(out + i, mask, vec_out);
   }

   // NOTE: _mm512_reduce_max_ps is sequence instruction

intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.cpp

Lines changed: 29 additions & 3 deletions
@@ -23,16 +23,42 @@ namespace cpu {
  * @param out Optinal output provided by user for matmul
  * @attr Attribute for matmul oneDNN primitive
  * @return output Tensor.
- */
+ * Since oneDNN 2.6.0, AMX and AVX512 brgemm are enabled for the DNNL MATMUL
+ * primitive if the input tensors are with the following tags:
+ * 3-dim - abc, acb; 4-dim - abcd, acbd, adbc, abdc.
+ * If the input tensor has one of the above layouts, the contiguous should NOT
+ * be applied to avoid unnecessary transpose (copy).
+ **/
 at::Tensor bmm_impl(
     const at::Tensor& tensor1,
     const at::Tensor& tensor2,
     at::Tensor out,
     const ideep::attr_t& attr,
     const std::vector<ideep::tensor>& postop_tensors,
     const float dst_coeff = 1.0f) {
-  auto tensor1_ = tensor1.is_contiguous() ? tensor1 : tensor1.contiguous();
-  auto tensor2_ = tensor2.is_contiguous() ? tensor2 : tensor2.contiguous();
+  // The following conditions are strict to exclude some extreme cases when the
+  // tensors have the undefined stride values. For the sake of reliability of
+  // transpose-free Matmul kernel, contiguous will be applied to these tensors.
+  auto check_tensor_layout = [](at::Tensor tensor) {
+    // Check if the Tensor is 3-dim or 4-dim
+    if (tensor.dim() != 3 && tensor.dim() != 4)
+      return false;
+    // Check if 'a' is the first dim
+    for (int64_t i = 1; i < tensor.dim(); ++i) {
+      if (tensor.stride(0) < tensor.stride(i))
+        return false;
+    }
+    // Check if the tensor has one of the above memory tags:
+    // The strides of the tensor should not be out of the tensor's ranges.
+    // 4-dim: 'b' should not be the last dim.
+    if (tensor.stride(0) * tensor.size(0) != tensor.numel() ||
+        (tensor.dim() == 4 && tensor.stride(1) == 1))
+      return false;
+    return true;
+  };
+  auto tensor1_ = check_tensor_layout(tensor1) ? tensor1 : tensor1.contiguous();
+  auto tensor2_ = check_tensor_layout(tensor2) ? tensor2 : tensor2.contiguous();
+
   const int64_t dim = tensor1.dim();
   const ideep::tensor mkldnn_input = itensor_view_from_dense(tensor1_);
   const ideep::tensor mkldnn_tensor2 = itensor_view_from_dense(tensor2_);