|
4 | 4 | #include <torch/all.h>
|
5 | 5 | #include <torch/csrc/autograd/function.h>
|
6 | 6 | #include <limits>
|
7 |
| -#include "vec/vec.h" |
8 | 7 | #include "../../utils/isa_utils.h"
|
| 8 | +#include "vec/vec.h" |
9 | 9 |
|
10 | 10 | namespace torch_ipex {
|
11 | 11 | namespace cpu {
|
@@ -1346,7 +1346,8 @@ first_token_masked_mha(
|
1346 | 1346 | auto attn_outputs = at::Tensor();
|
1347 | 1347 | auto attn_weights = at::Tensor();
|
1348 | 1348 | if ((key.scalar_type() == at::kFloat || key.scalar_type() == at::kBFloat16 ||
|
1349 |
| - (key.scalar_type() == at::kHalf && utils::isa_has_avx512_fp16_support())) && |
| 1349 | + (key.scalar_type() == at::kHalf && |
| 1350 | + utils::isa_has_avx512_fp16_support())) && |
1350 | 1351 | attention_mask.stride(-1) == 1) {
|
1351 | 1352 | query = query.transpose(1, 2);
|
1352 | 1353 | key = key.transpose(1, 2);
|
@@ -1447,27 +1448,26 @@ masked_multihead_self_attention_kernel_impl(
|
1447 | 1448 | query.size(0); // record the prompt bs info
|
1448 | 1449 |
|
1449 | 1450 | } else if (offset > 0 && offset + cur_len > cache_size) {
|
1450 |
| - auto new_cache_size = cache_size * 2 + 2; |
| 1451 | + auto new_cache_size = cache_size * 2; |
1451 | 1452 | auto new_key_cache = at::empty(
|
1452 | 1453 | {new_cache_size, beam_batch, key.size(2), key.size(3)}, key.options());
|
1453 | 1454 | auto new_value_cache = at::empty(
|
1454 | 1455 | {new_cache_size, beam_batch, value.size(2), value.size(3)},
|
1455 | 1456 | value.options());
|
1456 | 1457 | auto new_beam_idx =
|
1457 |
| - at::zeros({new_cache_size, beam_batch}, beam_idx.options()); |
| 1458 | + at::zeros({new_cache_size + 2, beam_batch}, beam_idx.options()); |
1458 | 1459 | new_key_cache.slice(0, 0, cache_size).copy_(key_cache);
|
1459 | 1460 | new_value_cache.slice(0, 0, cache_size).copy_(value_cache);
|
1460 |
| - new_beam_idx.slice(0, 0, cache_size).copy_(beam_idx); |
| 1461 | + new_beam_idx.slice(0, 0, cache_size + 2).copy_(beam_idx); |
1461 | 1462 | auto new_beam_idx_access = new_beam_idx.accessor<long, 2>();
|
1462 | 1463 | auto beam_idx_access = beam_idx.accessor<long, 2>();
|
1463 | 1464 | for (auto i = offset; i < new_cache_size; i++) {
|
1464 | 1465 | for (auto j = 0; j < beam_batch; j++) {
|
1465 | 1466 | new_beam_idx_access[i][j] = beam_idx_access[0][j];
|
1466 | 1467 | }
|
1467 | 1468 | }
|
1468 |
| - new_beam_idx_access[new_cache_size - 2][0] = |
1469 |
| - beam_idx_access[cache_size - 2][0]; |
1470 |
| - new_beam_idx_access[new_cache_size - 1][0] = |
| 1469 | + new_beam_idx_access[new_cache_size][0] = beam_idx_access[cache_size - 2][0]; |
| 1470 | + new_beam_idx_access[new_cache_size + 1][0] = |
1471 | 1471 | beam_idx_access[cache_size - 1][0];
|
1472 | 1472 | key_cache = new_key_cache;
|
1473 | 1473 | value_cache = new_value_cache;
|
|
0 commit comments