
Commit 8572e1f

Optimize GQA for IAKV (#3185)
1 parent 5bfee4b commit 8572e1f
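
In grouped-query attention (GQA/MQA), group_size query heads share one key/value head. This commit restructures the indirect-access KV cache (IAKV) kernels so the head loops step one KV-head group at a time and the shared key/value row is reused for every query head in the group, instead of being re-read per query head. A minimal standalone sketch of that grouping idea (hypothetical names; not the kernel's actual code):

    // Sketch: group_size query heads share one key/value head, so the key
    // row is loaded once per group instead of once per query head.
    #include <cstdint>

    // Hypothetical helper: dot product of one query head with one key row.
    static float dot(const float* q, const float* k, int64_t head_size) {
      float acc = 0.f;
      for (int64_t d = 0; d < head_size; d++) acc += q[d] * k[d];
      return acc;
    }

    void attention_scores_grouped(
        const float* q,   // [head_num, head_size]
        const float* k,   // [kv_head_num, head_size] for one token
        float* attn_w,    // [head_num]
        int64_t head_num,
        int64_t kv_head_num,
        int64_t head_size) {
      const int64_t group_size = head_num / kv_head_num;
      for (int64_t hg = 0; hg < head_num; hg += group_size) {
        const int64_t kv_hi = hg / group_size;       // shared key/value head
        const float* k_row = k + kv_hi * head_size;  // read once per group
        for (int64_t i = 0; i < group_size; i++) {
          attn_w[hg + i] = dot(q + (hg + i) * head_size, k_row, head_size);
        }
      }
    }

The new reduce_head and mul_attenion_weights_and_value_of_head overloads in the diff below apply the same pattern inside the existing per-token loops.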

File tree

1 file changed (+87 lines, -15 lines)


csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp

Lines changed: 87 additions & 15 deletions
@@ -145,6 +145,28 @@ void reduce_head_half(
 }
 #endif
 
+template <typename T>
+void reduce_head(
+    const T* q_ptr_start,
+    int64_t kv_head_group_size,
+    const T* k_ptr_start,
+    float* attn_w_pos,
+    int attn_w_stride,
+    int64_t head_size,
+    bool store_key,
+    T* k_cache_start) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    attn_w_pos[i * attn_w_stride] = 0;
+    reduce_head<T>(
+        q_ptr_start + i * head_size,
+        k_ptr_start,
+        attn_w_pos + i * attn_w_stride,
+        head_size,
+        store_key,
+        k_cache_start);
+  }
+}
+
 /*
  *reduce the attention_weights with the value embedding by the dimension of
  *head_size for every head
@@ -170,6 +192,32 @@ void mul_attenion_weights_and_value_of_head(
   }
 }
 
+template <typename T, typename T1>
+void mul_attenion_weights_and_value_of_head(
+    float* attn_w,
+    int attn_w_stride,
+    const T* v_ptr_start,
+    T1* attn_out_start,
+    int attn_out_strideH,
+    int kv_head_group_size,
+    int64_t head_size,
+    bool store_value,
+    T* v_cache_start,
+    uint8_t* flag_access) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    mul_attenion_weights_and_value_of_head<T, T1>(
+        attn_w[i * attn_w_stride],
+        v_ptr_start,
+        attn_out_start + i * attn_out_strideH,
+        head_size,
+        store_value,
+        v_cache_start,
+        flag_access[i]);
+    if (flag_access[i] == 0)
+      flag_access[i] = 1;
+  }
+}
+
 #if defined(CPU_CAPABILITY_AVX512)
 template <>
 void mul_attenion_weights_and_value_of_head(
@@ -594,17 +642,21 @@ scale_dot_product_for_indirect_access_kv_cache(
 #pragma omp parallel for collapse(3)
   for (auto block_id = 0; block_id < kv_block_count; block_id++) {
     for (auto bi = 0; bi < bs; bi++) {
-      for (auto hi = 0; hi < head_num; hi++) {
+      for (auto head_group_start = 0; head_group_start < head_num;
+           head_group_start += group_size) {
         auto k_start = block_id * kv_block_size;
         auto block_size = std::min(kv_block_size, seq_len - k_start);
         auto query_ti = 0;
         for (auto ti = k_start; ti < k_start + block_size; ti++) {
-          auto kv_hi = hi / group_size; // maping the query head to
-                                        // key/value head to support MGA/MQA
+          auto kv_hi = head_group_start /
+              group_size; // maping the query head to
+                          // key/value head to support MGA/MQA
           auto q_ptr_start = q_ptr +
               (bi * cur_len + query_ti) * head_num * head_size +
-              hi * head_size;
-          auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
+              head_group_start * head_size;
+          auto attn_w_stride2 = cur_len * seq_len;
+          auto attn_w_stride =
+              (bi * head_num + head_group_start) * attn_w_stride2;
           auto attn_w_pos =
               attn_w_ptr + attn_w_stride + query_ti * seq_len + ti;
           attn_w_pos[0] = 0.0f;
@@ -632,8 +684,10 @@ scale_dot_product_for_indirect_access_kv_cache(
                 kv_hi * head_size;
             reduce_head<QT>(
                 q_ptr_start,
+                group_size,
                 k_ptr_start,
                 attn_w_pos,
+                attn_w_stride2,
                 head_size,
                 true,
                 kc_head_start);
@@ -644,8 +698,10 @@ scale_dot_product_for_indirect_access_kv_cache(
                 kv_hi * head_size;
             reduce_head<QT>(
                 q_ptr_start,
+                group_size,
                 k_ptr_start,
                 attn_w_pos,
+                attn_w_stride2,
                 head_size,
                 false,
                 nullptr);
@@ -662,8 +718,10 @@ scale_dot_product_for_indirect_access_kv_cache(
                 k_cache_ptr + kc_t_beam_start + kv_hi * head_size;
             reduce_head<QT>(
                 q_ptr_start,
+                group_size,
                 kc_head_start,
                 attn_w_pos,
+                attn_w_stride2,
                 head_size,
                 false,
                 nullptr);
@@ -737,6 +795,7 @@ scale_dot_product_for_indirect_access_kv_cache(
   auto private_attn_out_flag =
       at::zeros({thread_numbers, bs, head_num}, at::kByte);
   auto flag_access = private_attn_out_flag.accessor<uint8_t, 3>();
+  uint8_t* flag_access_ptr = flag_access.data();
   auto private_attn_out_ptr = private_attn_outs.data_ptr<float>();
   // private_attn_outs.numel());
   auto attn_outs_stride_priv = bs * head_num * cur_len * head_size;
@@ -747,7 +806,7 @@ scale_dot_product_for_indirect_access_kv_cache(
 #pragma omp parallel for collapse(3)
   for (auto block_id = 0; block_id < kv_block_count; block_id++) {
     for (auto bi = 0; bi < bs; bi++) {
-      for (auto hi = 0; hi < head_num; hi++) {
+      for (auto hi = 0; hi < head_num; hi += group_size) {
         auto thread_id = 0;
         if (kv_block_size < seq_len)
           thread_id = omp_get_thread_num();
@@ -757,15 +816,19 @@ scale_dot_product_for_indirect_access_kv_cache(
         for (auto vi = v_start; vi < v_start + block_size; vi++) {
           auto kv_hi = hi / group_size; // maping the query head to
                                         // key/value head to support MGA/MQA
-          auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
+          auto attn_w_stride2 = cur_len * seq_len;
+          auto attn_w_stride = (bi * head_num + hi) * attn_w_stride2;
           auto attn_w_query_start =
-              attn_w_ptr + attn_w_stride + query_ti * seq_len;
+              attn_w_ptr + attn_w_stride + query_ti * seq_len + vi;
           // calculate weighted value and store the result to attn_outs[bs,
           // head_num, cur_len, head_size]
+          auto attn_out_head_stride2 = cur_len * head_size;
           auto attn_out_head_stride = thread_id * attn_outs_stride_priv +
-              (bi * head_num + hi) * cur_len * head_size;
+              (bi * head_num + hi) * attn_out_head_stride2;
           auto attn_out_start = private_attn_out_ptr + attn_out_head_stride +
               query_ti * head_size;
+          auto flag_access_start = flag_access_ptr +
+              head_num * bs * thread_id + head_num * bi + hi;
 
           auto vc_token_start = vi * kc_token_stride;
           auto beam = need_update_beam_idx ? new_beam_idx[bi][vi] : 0;
@@ -787,13 +850,16 @@ scale_dot_product_for_indirect_access_kv_cache(
                 (bi * cur_len + vi - offset) * kv_head * head_size +
                 kv_hi * head_size;
             mul_attenion_weights_and_value_of_head<VT, float>(
-                attn_w_query_start[vi],
+                attn_w_query_start,
+                attn_w_stride2,
                 v_ptr_start,
                 attn_out_start,
+                attn_out_head_stride2,
+                group_size,
                 head_size,
                 true,
                 v_cache_head_start,
-                flag_access[thread_id][bi][hi]);
+                flag_access_start);
           } else if (vi < query_ti + offset) { // caculate attention
                                                // values for the past
                                                // token
@@ -802,13 +868,16 @@ scale_dot_product_for_indirect_access_kv_cache(
                 (bi * cur_len + vi - offset) * kv_head * head_size +
                 kv_hi * head_size;
             mul_attenion_weights_and_value_of_head<VT, float>(
-                attn_w_query_start[vi],
+                attn_w_query_start,
+                attn_w_stride2,
                 v_ptr_start,
                 attn_out_start,
+                attn_out_head_stride2,
+                group_size,
                 head_size,
                 false,
                 nullptr,
-                flag_access[thread_id][bi][hi]);
+                flag_access_start);
           } else {
             auto vc_t_beam_start =
                 vc_token_start + beam * kv_head * head_size;
@@ -822,13 +891,16 @@ scale_dot_product_for_indirect_access_kv_cache(
             auto v_cache_head_start =
                 v_cache_ptr + vc_t_beam_start + kv_hi * head_size;
             mul_attenion_weights_and_value_of_head<VT, float>(
-                attn_w_query_start[vi],
+                attn_w_query_start,
+                attn_w_stride2,
                 v_cache_head_start,
                 attn_out_start,
+                attn_out_head_stride2,
+                group_size,
                 head_size,
                 false,
                 nullptr,
-                flag_access[thread_id][bi][hi]);
+                flag_access_start);
           }
         }
         if (flag_access[thread_id][bi][hi] == 0)
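
The value pass also swaps the flag_access[thread_id][bi][hi] accessor for a raw flag_access_ptr plus a hand-computed offset, so the grouped overload can mark each query head's first-write flag through flag_access[i]. A quick standalone check (assumed sizes, not part of the commit) that the flattened offset matches the row-major layout of the {thread_numbers, bs, head_num} flag tensor:

    // Verify: head_num * bs * thread_id + head_num * bi + hi is the
    // row-major offset of element [thread_id][bi][hi].
    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t thread_numbers = 4, bs = 2, head_num = 32;
      for (int64_t thread_id = 0; thread_id < thread_numbers; thread_id++)
        for (int64_t bi = 0; bi < bs; bi++)
          for (int64_t hi = 0; hi < head_num; hi++) {
            // Offset as computed for flag_access_start in the kernel.
            int64_t flat = head_num * bs * thread_id + head_num * bi + hi;
            // Canonical row-major offset for a {thread_numbers, bs, head_num} tensor.
            int64_t expected = (thread_id * bs + bi) * head_num + hi;
            assert(flat == expected);
          }
      return 0;
    }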
