
Commit e74d7a9

Fix the WOQ-INT4 crash issue when the pre-allocated buffer is not enough (#3079)
1 parent: fbaa4bc

File tree: 2 files changed (+16, −9 lines)


csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp

Lines changed: 8 additions & 8 deletions
@@ -4,8 +4,8 @@
 #include <torch/all.h>
 #include <torch/csrc/autograd/function.h>
 #include <limits>
-#include "vec/vec.h"
 #include "../../utils/isa_utils.h"
+#include "vec/vec.h"
 
 namespace torch_ipex {
 namespace cpu {
@@ -1346,7 +1346,8 @@ first_token_masked_mha(
   auto attn_outputs = at::Tensor();
   auto attn_weights = at::Tensor();
   if ((key.scalar_type() == at::kFloat || key.scalar_type() == at::kBFloat16 ||
-       (key.scalar_type() == at::kHalf && utils::isa_has_avx512_fp16_support())) &&
+       (key.scalar_type() == at::kHalf &&
+        utils::isa_has_avx512_fp16_support())) &&
       attention_mask.stride(-1) == 1) {
     query = query.transpose(1, 2);
     key = key.transpose(1, 2);
@@ -1447,27 +1448,26 @@ masked_multihead_self_attention_kernel_impl(
         query.size(0); // record the promt bs info
 
   } else if (offset > 0 && offset + cur_len > cache_size) {
-    auto new_cache_size = cache_size * 2 + 2;
+    auto new_cache_size = cache_size * 2;
     auto new_key_cache = at::empty(
         {new_cache_size, beam_batch, key.size(2), key.size(3)}, key.options());
     auto new_value_cache = at::empty(
         {new_cache_size, beam_batch, value.size(2), value.size(3)},
         value.options());
     auto new_beam_idx =
-        at::zeros({new_cache_size, beam_batch}, beam_idx.options());
+        at::zeros({new_cache_size + 2, beam_batch}, beam_idx.options());
     new_key_cache.slice(0, 0, cache_size).copy_(key_cache);
     new_value_cache.slice(0, 0, cache_size).copy_(value_cache);
-    new_beam_idx.slice(0, 0, cache_size).copy_(beam_idx);
+    new_beam_idx.slice(0, 0, cache_size + 2).copy_(beam_idx);
     auto new_beam_idx_access = new_beam_idx.accessor<long, 2>();
     auto beam_idx_access = beam_idx.accessor<long, 2>();
     for (auto i = offset; i < new_cache_size; i++) {
       for (auto j = 0; j < beam_batch; j++) {
         new_beam_idx_access[i][j] = beam_idx_access[0][j];
       }
     }
-    new_beam_idx_access[new_cache_size - 2][0] =
-        beam_idx_access[cache_size - 2][0];
-    new_beam_idx_access[new_cache_size - 1][0] =
+    new_beam_idx_access[new_cache_size][0] = beam_idx_access[cache_size - 2][0];
+    new_beam_idx_access[new_cache_size + 1][0] =
         beam_idx_access[cache_size - 1][0];
     key_cache = new_key_cache;
     value_cache = new_value_cache;
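Reading the last hunk: before this change, the reallocated KV caches and the beam_idx table shared one size, cache_size * 2 + 2, with the two bookkeeping entries stored in the last two rows of that shared buffer. After it, the caches grow to cache_size * 2 while beam_idx alone gets two extra rows, so the bookkeeping entries land at new_cache_size and new_cache_size + 1, past the end of the cache. Below is a minimal sketch of that resize pattern in PyTorch rather than the ATen C++ above; the function name is mine, and it assumes the incoming beam_idx already carries its two extra trailing rows:

import torch

def grow_kv_cache(key_cache, value_cache, beam_idx, offset):
    """Double the KV caches; beam_idx keeps two bookkeeping rows past the end."""
    cache_size = key_cache.size(0)
    new_cache_size = cache_size * 2
    new_key_cache = torch.empty((new_cache_size, *key_cache.shape[1:]),
                                dtype=key_cache.dtype)
    new_value_cache = torch.empty((new_cache_size, *value_cache.shape[1:]),
                                  dtype=value_cache.dtype)
    # beam_idx is allocated two rows longer than the caches, as in the fix
    new_beam_idx = torch.zeros((new_cache_size + 2, beam_idx.size(1)),
                               dtype=beam_idx.dtype)
    new_key_cache[:cache_size] = key_cache
    new_value_cache[:cache_size] = value_cache
    new_beam_idx[:cache_size + 2] = beam_idx[:cache_size + 2]
    # steps not generated yet default to beam 0, matching the kernel's loop
    new_beam_idx[offset:new_cache_size] = beam_idx[0]
    # mirror the kernel's bookkeeping copy into the two trailing slots
    new_beam_idx[new_cache_size][0] = beam_idx[cache_size - 2][0]
    new_beam_idx[new_cache_size + 1][0] = beam_idx[cache_size - 1][0]
    return new_key_cache, new_value_cache, new_beam_idx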

examples/cpu/inference/python/llm/single_instance/run_quantization.py

Lines changed: 8 additions & 1 deletion
@@ -405,7 +405,13 @@ def load_image(image_file):
 
     num_beams = 1 if args.greedy else 4
     if not hasattr(config, "text_max_length") and args.prompt is None:
-        config.text_max_length = int(args.input_tokens) + int(args.max_new_tokens)
+        if not args.benchmark:
+            if hasattr(config, "max_position_embeddings"):
+                config.text_max_length = config.max_position_embeddings
+            else:
+                config.text_max_length = 2048
+        else:
+            config.text_max_length = int(args.input_tokens) + int(args.max_new_tokens)
     if model.name == "mpt" and not hasattr(config, "max_seq_len") and args.prompt is None:
         config.max_seq_len = int(args.input_tokens) + int(args.max_new_tokens)
     if model.name in ["git", "llava"]:
@@ -416,6 +422,7 @@ def load_image(image_file):
     if args.lm_head_generation and not hasattr(config, "lm_head_generation"):
         config.lm_head_generation = True
 
+
     user_model = model.get_user_model(config, args.benchmark)
 
     tokenizer = model.get_tokenizer()
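On the Python side, the script now sizes text_max_length, which per the commit title drives the buffer pre-allocation, to match how the model is actually run: a benchmark run knows its exact sequence length up front, while an interactive run is open-ended and therefore sizes for the largest sequence the model supports. A standalone sketch of the same decision follows; the helper name is mine, while args and config mirror the script:

def resolve_text_max_length(config, args):
    """Pick the buffer pre-allocation length the way run_quantization.py now does."""
    if args.benchmark:
        # benchmark runs know the exact sequence length up front
        return int(args.input_tokens) + int(args.max_new_tokens)
    # interactive runs size for the longest sequence the model supports,
    # falling back to 2048 when the config lacks max_position_embeddings
    return getattr(config, "max_position_embeddings", 2048)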
