@@ -145,6 +145,28 @@ void reduce_head_half(
 }
 #endif
 
+template <typename T>
+void reduce_head(
+    const T* q_ptr_start,
+    int64_t kv_head_group_size,
+    const T* k_ptr_start,
+    float* attn_w_pos,
+    int attn_w_stride,
+    int64_t head_size,
+    bool store_key,
+    T* k_cache_start) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    attn_w_pos[i * attn_w_stride] = 0;
+    reduce_head<T>(
+        q_ptr_start + i * head_size,
+        k_ptr_start,
+        attn_w_pos + i * attn_w_stride,
+        head_size,
+        store_key,
+        k_cache_start);
+  }
+}
+
 /*
  *reduce the attention_weights with the value embedding by the dimension of
  *head_size for every head
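The new overload drives the existing per-head reduce_head once per query head in a KV-head group, so grouped query heads share a single key read while the i-th head's logit lands at attn_w_pos[i * attn_w_stride]. A minimal standalone sketch of that indexing, illustrative only and not from the patch, with a hypothetical scalar stand-in (reduce_head_ref) for the real kernel:

#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical scalar stand-in for the per-head reduce_head<T> the
// wrapper dispatches to: accumulate the dot product into out[0].
template <typename T>
void reduce_head_ref(const T* q, const T* k, float* out, int64_t head_size) {
  for (int64_t d = 0; d < head_size; d++)
    out[0] += static_cast<float>(q[d]) * static_cast<float>(k[d]);
}

int main() {
  const int64_t head_size = 4, group_size = 2; // 2 query heads per KV head
  std::vector<float> q = {1, 0, 0, 0, /* head 0 */ 0, 1, 0, 0 /* head 1 */};
  std::vector<float> k = {0.5f, 0.25f, 0, 0}; // one shared key head
  const int attn_w_stride = 8; // plays the role of cur_len * seq_len
  std::vector<float> attn_w(attn_w_stride * group_size, -1.f);

  // Mirrors the grouped overload: one shared key, one logit per query
  // head in the group, written attn_w_stride floats apart.
  for (int64_t i = 0; i < group_size; i++) {
    attn_w[i * attn_w_stride] = 0;
    reduce_head_ref(q.data() + i * head_size, k.data(),
                    attn_w.data() + i * attn_w_stride, head_size);
  }
  printf("head0=%.2f head1=%.2f\n", attn_w[0], attn_w[attn_w_stride]);
  // prints: head0=0.50 head1=0.25
}

Writing the group's logits through one base pointer plus a head stride is what lets the call sites below pass attn_w_pos and attn_w_stride2 once instead of looping over query heads themselves.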
@@ -170,6 +192,32 @@ void mul_attenion_weights_and_value_of_head(
   }
 }
 
+template <typename T, typename T1>
+void mul_attenion_weights_and_value_of_head(
+    float* attn_w,
+    int attn_w_stride,
+    const T* v_ptr_start,
+    T1* attn_out_start,
+    int attn_out_strideH,
+    int kv_head_group_size,
+    int64_t head_size,
+    bool store_value,
+    T* v_cache_start,
+    uint8_t* flag_access) {
+  for (auto i = 0; i < kv_head_group_size; i++) {
+    mul_attenion_weights_and_value_of_head<T, T1>(
+        attn_w[i * attn_w_stride],
+        v_ptr_start,
+        attn_out_start + i * attn_out_strideH,
+        head_size,
+        store_value,
+        v_cache_start,
+        flag_access[i]);
+    if (flag_access[i] == 0)
+      flag_access[i] = 1;
+  }
+}
+
 #if defined(CPU_CAPABILITY_AVX512)
 template <>
 void mul_attenion_weights_and_value_of_head(
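Each inner call consumes one weight per head (attn_w[i * attn_w_stride]) and writes one output row (attn_out_start + i * attn_out_strideH); the trailing flag flip records that head's thread-private accumulator as initialized. A rough reference model of the flag protocol, under the assumption (inferred from the call sites, not stated in the patch) that the per-head kernel overwrites on flag == 0 and accumulates otherwise; acc_head and acc_group are hypothetical names:

#include <cstdint>

// Hypothetical per-head kernel: overwrite the private row on first touch
// (flag == 0), accumulate on later touches. Assumption, not the patch's code.
void acc_head(
    float w, const float* v, float* out, int64_t head_size, uint8_t flag) {
  for (int64_t d = 0; d < head_size; d++)
    out[d] = flag ? out[d] + w * v[d] : w * v[d];
}

// Reference model of the grouped wrapper: one shared value head feeds
// group_size query heads; each head's flag flips to 1 after its first write.
void acc_group(
    const float* attn_w, int attn_w_stride, const float* v, float* out,
    int out_strideH, int group_size, int64_t head_size, uint8_t* flags) {
  for (int i = 0; i < group_size; i++) {
    acc_head(
        attn_w[i * attn_w_stride], v, out + i * out_strideH, head_size,
        flags[i]);
    if (flags[i] == 0)
      flags[i] = 1;
  }
}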
@@ -594,17 +642,21 @@ scale_dot_product_for_indirect_access_kv_cache(
 #pragma omp parallel for collapse(3)
   for (auto block_id = 0; block_id < kv_block_count; block_id++) {
     for (auto bi = 0; bi < bs; bi++) {
-      for (auto hi = 0; hi < head_num; hi++) {
+      for (auto head_group_start = 0; head_group_start < head_num;
+           head_group_start += group_size) {
         auto k_start = block_id * kv_block_size;
         auto block_size = std::min(kv_block_size, seq_len - k_start);
         auto query_ti = 0;
         for (auto ti = k_start; ti < k_start + block_size; ti++) {
-          auto kv_hi = hi / group_size; // mapping the query head to
-          // key/value head to support MGA/MQA
+          auto kv_hi = head_group_start /
+              group_size; // mapping the query head to
+                          // key/value head to support MGA/MQA
           auto q_ptr_start = q_ptr +
               (bi * cur_len + query_ti) * head_num * head_size +
-              hi * head_size;
-          auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
+              head_group_start * head_size;
+          auto attn_w_stride2 = cur_len * seq_len;
+          auto attn_w_stride =
+              (bi * head_num + head_group_start) * attn_w_stride2;
           auto attn_w_pos =
               attn_w_ptr + attn_w_stride + query_ti * seq_len + ti;
           attn_w_pos[0] = 0.0f;
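attn_w_stride2 is the per-head pitch of the weights buffer: adjacent query heads of one group sit exactly cur_len * seq_len floats apart, which is why the grouped reduce_head can reach every head's slot from the single base pointer attn_w_pos. A worked offset check, assuming the [bs, head_num, cur_len, seq_len] layout implied by the stride arithmetic above:

// For query head h = head_group_start + i (0 <= i < group_size), the slot
// attn_w[bi][h][query_ti][ti] of a contiguous [bs, head_num, cur_len,
// seq_len] buffer lives at
//   (bi * head_num + head_group_start + i) * (cur_len * seq_len)
//       + query_ti * seq_len + ti
//   = attn_w_stride + i * attn_w_stride2 + query_ti * seq_len + ti,
// i.e. attn_w_pos + i * attn_w_stride2 -- exactly the head stride the
// grouped reduce_head receives at the call sites below.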
@@ -632,8 +684,10 @@ scale_dot_product_for_indirect_access_kv_cache(
                 kv_hi * head_size;
             reduce_head<QT>(
                 q_ptr_start,
+                group_size,
                 k_ptr_start,
                 attn_w_pos,
+                attn_w_stride2,
                 head_size,
                 true,
                 kc_head_start);
@@ -644,8 +698,10 @@ scale_dot_product_for_indirect_access_kv_cache(
                 kv_hi * head_size;
             reduce_head<QT>(
                 q_ptr_start,
+                group_size,
                 k_ptr_start,
                 attn_w_pos,
+                attn_w_stride2,
                 head_size,
                 false,
                 nullptr);
@@ -662,8 +718,10 @@ scale_dot_product_for_indirect_access_kv_cache(
                 k_cache_ptr + kc_t_beam_start + kv_hi * head_size;
             reduce_head<QT>(
                 q_ptr_start,
+                group_size,
                 kc_head_start,
                 attn_w_pos,
+                attn_w_stride2,
                 head_size,
                 false,
                 nullptr);
@@ -737,6 +795,7 @@ scale_dot_product_for_indirect_access_kv_cache(
   auto private_attn_out_flag =
       at::zeros({thread_numbers, bs, head_num}, at::kByte);
   auto flag_access = private_attn_out_flag.accessor<uint8_t, 3>();
+  uint8_t* flag_access_ptr = flag_access.data();
   auto private_attn_out_ptr = private_attn_outs.data_ptr<float>();
   // private_attn_outs.numel());
   auto attn_outs_stride_priv = bs * head_num * cur_len * head_size;
@@ -747,7 +806,7 @@ scale_dot_product_for_indirect_access_kv_cache(
 #pragma omp parallel for collapse(3)
   for (auto block_id = 0; block_id < kv_block_count; block_id++) {
     for (auto bi = 0; bi < bs; bi++) {
-      for (auto hi = 0; hi < head_num; hi++) {
+      for (auto hi = 0; hi < head_num; hi += group_size) {
         auto thread_id = 0;
         if (kv_block_size < seq_len)
           thread_id = omp_get_thread_num();
@@ -757,15 +816,19 @@ scale_dot_product_for_indirect_access_kv_cache(
         for (auto vi = v_start; vi < v_start + block_size; vi++) {
           auto kv_hi = hi / group_size; // mapping the query head to
           // key/value head to support MGA/MQA
-          auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
+          auto attn_w_stride2 = cur_len * seq_len;
+          auto attn_w_stride = (bi * head_num + hi) * attn_w_stride2;
           auto attn_w_query_start =
-              attn_w_ptr + attn_w_stride + query_ti * seq_len;
+              attn_w_ptr + attn_w_stride + query_ti * seq_len + vi;
           // calculate weighted value and store the result to attn_outs[bs,
           // head_num, cur_len, head_size]
+          auto attn_out_head_stride2 = cur_len * head_size;
           auto attn_out_head_stride = thread_id * attn_outs_stride_priv +
-              (bi * head_num + hi) * cur_len * head_size;
+              (bi * head_num + hi) * attn_out_head_stride2;
           auto attn_out_start = private_attn_out_ptr + attn_out_head_stride +
               query_ti * head_size;
+          auto flag_access_start = flag_access_ptr +
+              head_num * bs * thread_id + head_num * bi + hi;
 
           auto vc_token_start = vi * kc_token_stride;
           auto beam = need_update_beam_idx ? new_beam_idx[bi][vi] : 0;
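flag_access_start turns the [thread_numbers, bs, head_num] flag tensor into a raw pointer positioned at the first head of the current group, so the grouped value kernel can step through the group's flags as flag_access_start[i]. The arithmetic, spelled out (illustrative, assuming the tensor is contiguous as allocated by at::zeros above):

// For a contiguous uint8_t tensor of shape [thread_numbers, bs, head_num]:
//   &flag_access[thread_id][bi][hi]
//       == flag_access_ptr + (thread_id * bs + bi) * head_num + hi
//       == flag_access_ptr + head_num * bs * thread_id + head_num * bi + hi
// so flag_access_start[i] aliases flag_access[thread_id][bi][hi + i],
// the flag of the i-th query head in the group.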
@@ -787,13 +850,16 @@ scale_dot_product_for_indirect_access_kv_cache(
                 (bi * cur_len + vi - offset) * kv_head * head_size +
                 kv_hi * head_size;
             mul_attenion_weights_and_value_of_head<VT, float>(
-                attn_w_query_start[vi],
+                attn_w_query_start,
+                attn_w_stride2,
                 v_ptr_start,
                 attn_out_start,
+                attn_out_head_stride2,
+                group_size,
                 head_size,
                 true,
                 v_cache_head_start,
-                flag_access[thread_id][bi][hi]);
+                flag_access_start);
           } else if (vi < query_ti + offset) { // calculate attention
                                                // values for the past
                                                // token
@@ -802,13 +868,16 @@ scale_dot_product_for_indirect_access_kv_cache(
                 (bi * cur_len + vi - offset) * kv_head * head_size +
                 kv_hi * head_size;
             mul_attenion_weights_and_value_of_head<VT, float>(
-                attn_w_query_start[vi],
+                attn_w_query_start,
+                attn_w_stride2,
                 v_ptr_start,
                 attn_out_start,
+                attn_out_head_stride2,
+                group_size,
                 head_size,
                 false,
                 nullptr,
-                flag_access[thread_id][bi][hi]);
+                flag_access_start);
           } else {
             auto vc_t_beam_start =
                 vc_token_start + beam * kv_head * head_size;
@@ -822,13 +891,16 @@ scale_dot_product_for_indirect_access_kv_cache(
             auto v_cache_head_start =
                 v_cache_ptr + vc_t_beam_start + kv_hi * head_size;
             mul_attenion_weights_and_value_of_head<VT, float>(
-                attn_w_query_start[vi],
+                attn_w_query_start,
+                attn_w_stride2,
                 v_cache_head_start,
                 attn_out_start,
+                attn_out_head_stride2,
+                group_size,
                 head_size,
                 false,
                 nullptr,
-                flag_access[thread_id][bi][hi]);
+                flag_access_start);
           }
         }
         if (flag_access[thread_id][bi][hi] == 0)