-
Notifications
You must be signed in to change notification settings - Fork 506
/
Copy pathtest_pallas.py
1714 lines (1505 loc) · 68.2 KB
/
test_pallas.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import logging
import sys
import unittest
from absl.testing import parameterized
import torch
from torch import nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr
from torch_xla._internal import tpu
import numpy as np
# JAX / Pallas are imported only on TPU hosts. On other devices the names
# `jax`, `jnp` and `pl` are never bound, so everything that uses them
# (with_jax_high_precision, the Pallas tests) relies on the TPU skipIf guards.
# jax_import_guard() is deliberately called before the first `import jax`.
# NOTE(review): presumably it keeps JAX from claiming the TPU ahead of
# torch_xla — confirm in torch_xla.experimental.custom_kernel.
if xr.device_type() == 'TPU':
  from torch_xla.experimental.custom_kernel import jax_import_guard
  jax_import_guard()
  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl
def with_jax_high_precision(func):
  """Decorator: run ``func`` with JAX's default matmul precision at "highest".

  The value of ``jax_default_matmul_precision`` that was active before the call
  is restored afterwards, even if ``func`` raises. This keeps the setting from
  leaking into later tests. The decorated function keeps its name and
  docstring via ``functools.wraps``.
  """
  import functools

  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    # Remember the caller's setting instead of assuming it was "default".
    previous = jax.config.jax_default_matmul_precision
    jax.config.update('jax_default_matmul_precision', "highest")
    try:
      return func(*args, **kwargs)
    finally:
      jax.config.update('jax_default_matmul_precision', previous)

  return wrapper
class PallasTest(parameterized.TestCase):
def _make_attention_mask_from_segment_ids(self, q_segment_ids,
                                          kv_segment_ids):
  """Build a block-diagonal "masked-out" map from segment ids.

  A query position may only attend to key/value positions that carry the same
  segment id. The returned tensor is True where attention must be masked out,
  so the comparison uses ``!=`` rather than ``==``.

  Args:
    q_segment_ids: tensor viewed as (batch, q_len).
    kv_segment_ids: tensor viewed as (batch, kv_len).

  Returns:
    Bool tensor of shape (batch, 1, q_len, kv_len), broadcastable over heads.
  """
  q_batch, q_len = q_segment_ids.shape[0], q_segment_ids.shape[1]
  kv_batch, kv_len = kv_segment_ids.shape[0], kv_segment_ids.shape[1]
  q_ids = q_segment_ids.view(q_batch, 1, q_len, 1)
  kv_ids = kv_segment_ids.view(kv_batch, 1, 1, kv_len)
  return q_ids != kv_ids
def _attention(self, q, k, v, *, attn_mask=None, ab=None):
  """Eager reference attention used to validate the Pallas kernels.

  Computes ``softmax(q @ k^T + ab) @ v`` with no ``1/sqrt(d)`` scaling.
  Entries where ``attn_mask`` is truthy are replaced with the most negative
  finite value of the score dtype before the softmax, which effectively
  removes them.

  Args:
    q, k, v: query/key/value tensors; attention is over the last two dims.
    attn_mask: optional mask, broadcastable to the score matrix; truthy
      entries are masked out.
    ab: optional additive bias, broadcastable to the score matrix.
  """
  scores = torch.matmul(q, k.transpose(-2, -1))
  if ab is not None:
    scores = scores + ab
  if attn_mask is not None:
    lowest = torch.finfo(scores.dtype).min
    scores = scores.masked_fill(attn_mask.bool(), lowest)
  probabilities = nn.functional.softmax(scores, dim=-1)
  return torch.matmul(probabilities, v)
# The following helper functions prefixed with _pagedattention are used for
# the PagedAttention unit tests.
# Reference: https://github.com/google/jax/blob/main/tests/pallas/paged_attention_kernel_test.py
def _pagedattention_generate_qkv(
    self,
    seq_lens,
    page_size,
    max_seq_len,
    num_kv_heads,
    num_heads,
    head_dim,
    dtype=torch.float32,
    query_len=None,
):
  """Create random paged K/V caches, a page table and a query tensor.

  Only ``len(seq_lens)`` (the batch size) is used from ``seq_lens``. Random
  tensors are drawn in a fixed order (k_pages, v_pages, page permutation, q),
  so results are reproducible under ``torch.manual_seed``.

  Returns:
    q: (batch, num_heads, head_dim), or (batch, query_len, num_heads, head_dim)
      when ``query_len`` is truthy.
    k_pages, v_pages: (num_kv_heads, batch * pages_per_seq, page_size, head_dim).
    page_indices: int32 (batch, pages_per_seq), a random permutation of all
      page ids.
  """
  assert max_seq_len % page_size == 0
  batch_size = len(seq_lens)
  pages_per_sequence = max_seq_len // page_size
  total_pages = batch_size * pages_per_sequence
  page_shape = (num_kv_heads, total_pages, page_size, head_dim)
  k_pages = torch.randn(*page_shape, dtype=dtype)
  v_pages = torch.randn(*page_shape, dtype=dtype)
  page_indices = torch.randperm(
      total_pages, dtype=torch.int32).reshape(batch_size, pages_per_sequence)
  if query_len:
    q_shape = (batch_size, query_len, num_heads, head_dim)
  else:
    q_shape = (batch_size, num_heads, head_dim)
  q = torch.randn(*q_shape, dtype=dtype)
  return q, k_pages, v_pages, page_indices
def _ceil_div(self, a, b):
  """Return ``a / b`` rounded up, computed as ``(a + b - 1) // b``."""
  assert b != 0
  quotient, _ = divmod(a + b - 1, b)
  return quotient
def _ragged_pagedattention_generate_qkv(
    self,
    seq_lens,
    num_heads,
    head_dim,
    page_size,
    num_pages,
    dtype,
    *,
    max_num_batched_tokens=None,
    max_num_seqs=16,
):
  """Create random inputs for the ragged paged-attention tests.

  Args:
    seq_lens: iterable of ``(q_len, kv_len)`` pairs with ``q_len <= kv_len``.
    num_heads: ``(num_q_heads, num_kv_heads)``.
    head_dim, page_size, num_pages: cache geometry.
    dtype: dtype of ``q`` and ``kv_pages``.
    max_num_batched_tokens: rows of ``q``; raised to the total query length
      when smaller, or equal to it when None.
    max_num_seqs: sequence slots; raised to ``len(seq_lens)`` when smaller,
      or equal to it when None.

  Returns:
    ``(q, kv_pages, kv_lens, page_indices, cu_q_lens)``. ``cu_q_lens``
    (max_num_seqs + 1 entries) and ``kv_lens`` (max_num_seqs entries) are
    int32 and zero-padded. ``kv_pages`` stores K and V interleaved on dim 2
    (num_kv_heads * 2).
  """
  running_total = 0
  cumulative_q = [running_total]
  kv_len_values = []
  for q_len, kv_len in seq_lens:
    assert q_len <= kv_len
    running_total = running_total + q_len
    cumulative_q.append(running_total)
    kv_len_values.append(kv_len)

  total_q_tokens = cumulative_q[-1]
  if max_num_batched_tokens is not None:
    max_num_batched_tokens = max(total_q_tokens, max_num_batched_tokens)
  else:
    max_num_batched_tokens = total_q_tokens
  if max_num_seqs is not None:
    max_num_seqs = max(len(seq_lens), max_num_seqs)
  else:
    max_num_seqs = len(seq_lens)

  pages_per_seq = self._ceil_div(max(kv_len_values), page_size)
  num_q_heads, num_kv_heads = num_heads

  cu_q_lens = torch.tensor(cumulative_q, dtype=torch.int32)
  kv_lens = torch.tensor(kv_len_values, dtype=torch.int32)
  # Zero-pad to fixed sizes so inputs do not depend on the real sequence count.
  cu_q_pad = max_num_seqs + 1 - cu_q_lens.shape[0]
  kv_pad = max_num_seqs - kv_lens.shape[0]
  cu_q_lens = torch.nn.functional.pad(cu_q_lens, (0, cu_q_pad), "constant", 0)
  kv_lens = torch.nn.functional.pad(kv_lens, (0, kv_pad), "constant", 0)

  # Random draws stay in the order q, kv_pages, page_indices.
  q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim),
                  dtype=dtype)
  kv_pages = torch.randn((num_pages, page_size, num_kv_heads * 2, head_dim),
                         dtype=dtype)
  page_indices = torch.randint(
      0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32)
  return q, kv_pages, kv_lens, page_indices, cu_q_lens
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_add(self):
  # Pre-serialized Mosaic payload produced from this Pallas kernel:
  #   def add_vectors_kernel(x_ref, y_ref, o_ref):
  #     x, y = x_ref[...], y_ref[...]
  #     o_ref[...] = x + y
  payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA2lNDQFLBw8LEw8PDwsPMwsLCwtlCwsLCwsPCw8PEwsTDwsTDwsPDxMLDwUDYQENGwcTDxsPAsICHx0rLQUXAwMnKRURNx1HSRELAQUZHTM1AwsVFxkbHw0hDSMlBRsBAQUdDQlhZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABR8FIQUjBSUFJxEDAQUpFS8JHQ8xFwUTAQUrFwUdAR05OwUtFwUlAR0/QQUvFUMJHQ9FFwUVAQUxFREJI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF0sDIQcdAycDIQcBAgIFBwEBAQEBAgQEpwUBEAEHAwEFAxEBEwcDFScHAQEBAQEBBwMDBwMDCwYDAwUFAQcHAwMHAwMLBgMDBQUDCwkGPQMFBQkNBwMLBwMDCwYLAwUFBRENBAsHDwURBQABBgMBBQEAdgcz2wsTGdkNCxMjIR0pJ0MNCwsTDw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAdmVjdG9yAG1vZHVsZQByZXR1cm4AY29uc3RhbnQAYWRkaQBsb2FkAHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AGFkZF92ZWN0b3JzX2tlcm5lbABkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgB2YWx1ZQAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQBhZGRfdmVjdG9ycwA8bW9kdWxlPgAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\", \"needs_layout_passes\": true}}"

  lhs = torch.arange(8, dtype=torch.int).to("xla")
  rhs = torch.arange(8, dtype=torch.int).to("xla")
  reference = lhs + rhs
  results = torch_xla._XLAC._xla_tpu_custom_call([lhs, rhs], payload,
                                                  [lhs.shape], [lhs.dtype])
  self.assertTrue(torch.allclose(results[0].cpu(), reference.cpu()))
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_add_one(self):
  # Pre-serialized Mosaic payload produced from this Pallas kernel:
  #   def add_vectors_kernel(x_ref, o_ref):
  #     o_ref[...] = x_ref[...] + 1
  payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==\", \"needs_layout_passes\": true}}"

  values = torch.arange(8, dtype=torch.int).to("xla")
  reference = values + 1
  results = torch_xla._XLAC._xla_tpu_custom_call([values], payload,
                                                  [values.shape],
                                                  [values.dtype])
  self.assertTrue(torch.allclose(results[0].cpu(), reference.cpu()))
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_raise(self):
  # This payload is generated by the following Pallas code:
  # def add_vectors_kernel(x_ref, o_ref):
  #   o_ref[...] = x_ref[...] + 1
  payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==\", \"needs_layout_passes\": true}}"

  # _xla_tpu_custom_call requires at least one input. The result is bound and
  # materialized inside assertRaises, so a failure surfaced lazily at
  # execution time is also caught. Previously `output` was never defined: a
  # missing error showed up as a confusing NameError instead of a test failure.
  with self.assertRaises(RuntimeError):
    output = torch_xla._XLAC._xla_tpu_custom_call([], payload, [(8, 1)],
                                                  [torch.int32])
    output[0].cpu()
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_flash_attention(self):
  # This payload is generated by the following Pallas code:
  # https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
  # Note: call `jax.config.update('jax_default_matmul_precision', 'highest')`
  # before generating the payload.
  payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMjAuMC4wZ2l0AAEvCwEDBQcJAQMLAxkNDxETFRcZGx0fISMD0gJaAhsB9QcTCwsPExMLDw8TCwsLC5MLCw8TDwsLCwsLDwsLCw8PCwsPCw8PExMXExMTCw9DCxsLxQuTCwsLCxsbCxsLGwsbCxsbGxsPDw8PFw8LFw8PCxcPDwsXDw8LFw8PCxcPCxcPDxcTHwsPDxcjDxMfCw8XGw8PCw8XCw8LBQmNeZFhBwNdCQNZASsXHwsTFx8PCxMTFxMTFxcfCxMXIwsHA0kBGw8HKx8bBxcPIwsbLy8C0g0fAwMPjwUlBScVl50DAw+NHVICVQUpHSORHSPDHSMuAgUrBS0FLwUxIw8JQQEAAAAAAAAAAQAAAAAAAACAAAAAAAAAAAQAAAAAAAAADRkFMxETAAMD7/8RDwEFNQU3BTkFOwU9Hc3PBT8FQQVDHd0/Fd8JBUUFRwED5QVJHelLFesJHQoCTxUOAgkdHgIiAh1CAlUVRgIJAwNZWwVLEQ8JAw9fYRdjZ2lrKW0pGW9xcwVNAQn19fX5DRdhZmZpbmVfbWFwPChkMCwgZDEsIGQyLCBkMykgLT4gKGQwLCBkMSwgZDIsIGQzKT4ABU8jDwlBAwAAAAAAAAACAAAAAAAAAAEAAAAAAAAAAQAAAAAAAAAFUQVTBVUFVwEJdXl9gQMFG3cdHwkrAwUbex0fCS0DBRt/HR8JLwMFG4MdHwkxAwUXIRkrAwUXIRktAwUXIRkvAwUXIRkxEQEBEQMBFZMJHQeVFwUGCAEdmZsFWRcFSgUBFZ+lHaGjBVsXBYYLARWnrR2pqwVdFwViAwEVr7UdsbMFXxcFGgMBFbe9Hbm7BWEXM14DAR2/wQVjFzM2EAEVxQkdB8cXBQoIAQMDD8slBwkAAAAABWUV0QkdB9MXBQ4IAQMHN/c5JTvXERMBAwMP2yUNCQAAgP8FZx0H4RcFoggBAwVB/UNFEQ8FHUc/BWkdB+0XBaYIAQVrHfNLBW0jdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8cGFyYWxsZWw+ACN0cHUuY29udHJhY3RfcHJlY2lzaW9uPGZwMzI+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxhcmJpdHJhcnk+ACN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACN2ZWN0b3Iua2luZDxtYXhpbXVtZj4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+AAMDDwYCJQ0JAAAAAAVvHQcSAhcFqggBAwVBVgJDRR1HTwVxFSYCCR0HKgIXBa4IARUyAgkdBzYCFwXKCAEDAw8+AiUJCQAAAAAFcx0HSgIXBc4IAQMHN/c5JTslBXUjdmVjdG9yLmtpbmQ8YWRkPgABAgIDF/sJBQUCBBELZScFAgQCBAsnBQIEEQsLJwMCBAsBAgQnCQUFAgQRCwEJJwUCBAULBREBAQEBBQUFBQEFCQEBAQEJAQEBAQSuBwUBEQFXBwMBFQcRAV0HA2GrEQEBAQEBAQEBBQEFAQUBBQEDAxEDAwMDAxEDAwMDAxEDAwMDAxEDAwMLBhEDEQsJERMVFwUGEQMJAxkDAxMDAwMDAxMDAwMDAxMDAwMDAxMDAwMLBhMDEQsLHR8hIwUGEwMJAyUDAzXJAwcNBzXVAwcHGycpAwM92QMNDwc94wMNBSstBQbnAxUDLxEGSQMHAzETB0knAwcFKzMVB/EnAwcDNQMDTQICAw0PB00WAgMNBTc5BQYaAgMVAzsRBlEDBwM9FwdRJwMHBTc/AwMVAwMDAwMVAwMDAwMVAwMDAwMVAwMDCwYVAxELDUNFR0kFBhUDCQNLAwNTOgIDCQ0HU04CAwkHQU1PAwMNAwMDAwMNAwMDAwMNAwMDAwMNAwMDCwYNAxELD1NVV1kFBg0DCQNbBQYNAxEDURkEDQ1fD1NVV1kJAAEHEQGFBwMNDwkBAQEBAQEBAQMDAQsDAQMDAQsDAQkEAQkBAwUJBxEBhwcDDQ8JAQEBAQEBAQEDAwELAwEDAwELAwEJBAEJAQMHCQcRAYkHAw0PCQEBAQEBAQEBAwMBCwMBAwMBCwMBCQQBCQEDBwkHEQGLBwMNDwkBAQEBAQEBAQMDAQsDAQMDAQsDAQkEAQkBAwUJBgMBBQEAMhp37gImAgsvCxMLLyYCE2MhIy0xHQsjISMpLXkfCx0dFVkZGRkZ6gIdJRMdDWPvGxcTFyMvFxkZFSUfDw0PCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQB2ZWN0b3IAYXJpdGgAbW9kdWxlAGFyaXRoLmNvbnN0YW50AHZlY3Rvci5zaGFwZV9jYXN0AGZ1bmMuZnVuYwBmdW5jLnJldHVybgB2ZWN0b3IubG9hZAB0cHUubWF0bXVsAHZlY3Rvci5tdWx0aV9yZWR1Y3Rpb24AdmVjdG9yLmJyb2FkY2FzdABhcml0aC5zdWJmAG1hdGguZXhwAGFyaXRoLmRpdmYAdmVjdG9yLnN0b3JlAC9ob21lL2JiYWhsL21pbmljb25kYTMvZW52cy90b3JjaHNlcDEwL2xpYi9weXRob24zLjEwL3NpdGUtcGFja2FnZXMvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9mbGFzaF9hdHRlbnRpb24ucHkAX2ZsYXNoX2F0dGVudGlvbl9rZXJuZWxfc2luZ2xlX2JhdGNoX3NpbmdsZV9zdGVwAHZhbHVlAGZ1bmN0aW9uX3R5cGUAc3ltX25hbWUAdHJhbnNmb3JtX2luZGljZXMAd2luZG93X2JvdW5kcwAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKCgqLCAqLCBDdXN0b21Ob2RlKFNsaWNlWygwLCAxMjgsIDEpXSwgW05vbmUsIE5vbmVdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCwgMSldLCBbTm9uZSwgTm9uZV0pKSksICgxLCAxLCAxMjgsIDQpLCAoKSldLCBbKiwgKl0pLCkpXQB0cmFuc2Zvcm1fMAB0cmFuc2Zvcm1fMQB0cmFuc2Zvcm1fMgB0cmFuc2Zvcm1fMwAvaG9tZS9iYmFobC9weXRvcmNoL3hsYS90ZXN0L3Rlc3RfcGFsbGFzLnB5AHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvYnJvYWRjYXN0X2luX2RpbVtzaGFwZT0oMTI4LCAxKSBicm9hZGNhc3RfZGltZW5zaW9ucz0oMCwpXQBzdGFibGVfbW9zYWljLnZlcnNpb24AZGltZW5zaW9uX3NlbWFudGljcwBpdGVyYXRpb25fYm91bmRzAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAG1haW4Ad2luZG93X3BhcmFtcwBfZmxhc2hfYXR0ZW50aW9uX2tlcm5lbABfZmxhc2hfYXR0ZW50aW9uX2ltcGwAX2ZsYXNoX2F0dGVudGlvbgBmbGFzaF9hdHRlbnRpb24AdGVzdF90cHVfY3VzdG9tX2NhbGxfcGFsbGFzX3dyYXBfZmxhc2hfYXR0ZW50aW9uADxtb2R1bGU+AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgxLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPShQcmVjaXNpb24uSElHSEVTVCwgUHJlY2lzaW9uLkhJR0hFU1QpIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3JlZHVjZV9tYXhbYXhlcz0oMSwpXQAvc3ViAGZhc3RtYXRoAC9leHAAL3JlZHVjZV9zdW1bYXhlcz0oMSwpXQAvZGl2AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgwLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPShQcmVjaXNpb24uSElHSEVTVCwgUHJlY2lzaW9uLkhJR0hFU1QpIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKCosICosIEN1c3RvbU5vZGUoU2xpY2VbKDAsIDEyOCwgMSldLCBbTm9uZSwgTm9uZV0pLCBDdXN0b21Ob2RlKFNsaWNlWygwLCA0LCAxKV0sIFtOb25lLCBOb25lXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"serialization_format\": 1, \"needs_layout_passes\": true}, \"implicit_sharding\": {\"type\": \"MANUAL\"}}"
  # The division is to cause potential precision issue on TPU.
  q_mini = torch.arange(128 * 4, dtype=torch.float32).reshape(128, 4) / 13
  k_mini = torch.arange(
      1000, 1000 + 128 * 4, dtype=torch.float32).reshape(128, 4) / 13
  # The same (128, 4) block is broadcast across 3 x 2 leading dims.
  q = q_mini.broadcast_to(3, 2, 128, 4).to("xla")
  k = k_mini.broadcast_to(3, 2, 128, 4).to("xla")
  v = torch.ones(3, 2, 128, 4).to("xla")
  # Eager reference computed with the test's own _attention helper.
  expected_o = self._attention(q, k, v)
  # The custom call takes a list of inputs plus output shapes/dtypes and
  # returns a list of outputs; only the first (o) is checked.
  o = torch_xla._XLAC._xla_tpu_custom_call([q, k, v], payload, [q.shape],
                                           [q.dtype])
  self.assertTrue(torch.allclose(o[0].cpu(), expected_o.cpu()))
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_extract_add_payload(self):
  # Imported only for its side effects; the module name itself is unused.
  # NOTE(review): presumably registers the Pallas Mosaic lowering — confirm.
  import jax._src.pallas.mosaic.pallas_call_registration

  def add_vectors_kernel(x_ref, y_ref, o_ref):
    lhs, rhs = x_ref[...], y_ref[...]
    o_ref[...] = lhs + rhs

  @jax.jit
  def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return pl.pallas_call(add_vectors_kernel, out_shape=out_spec)(x, y)

  import torch_xla.experimental.custom_kernel as custom_kernel
  lowered = jax.jit(add_vectors).lower(jnp.arange(8), jnp.arange(8))
  payload = custom_kernel._extract_backend_config(lowered.compiler_ir())
  # The exact payload may differ between runs; only check that the key field
  # is present.
  self.assertIn("custom_call_config", payload)
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_wrap_add_payload(self):

  def add_vectors_kernel(x_ref, y_ref, o_ref):
    x, y = x_ref[...], y_ref[...]
    o_ref[...] = x + y

  @jax.jit
  def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return pl.pallas_call(add_vectors_kernel, out_shape=out_spec)(x, y)

  from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
  pt_kernel = make_kernel_from_pallas(add_vectors,
                                      lambda x, y: [(x.shape, x.dtype)])

  # Add doesn't support torch.float64, torch.bfloat16, torch.float16.
  # torch.float is an alias of torch.float32, so the iterations differ in shape.
  float_dtypes = [torch.float32, torch.float]
  for size, dtype in enumerate(float_dtypes, start=1):
    x = torch.randn((size, size), dtype=dtype).to("xla")
    y = torch.randn((size, size), dtype=dtype).to("xla")
    expected_output = x + y
    output = pt_kernel(x, y)
    self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

  # Add doesn't support torch.int64, torch.int16, torch.int8, torch.uint8.
  int_dtypes = [torch.int32, torch.int]
  for size, dtype in enumerate(int_dtypes, start=1):
    x = torch.arange(size, dtype=dtype).to("xla")
    y = torch.arange(size, dtype=dtype).to("xla")
    expected_output = x + y
    output = pt_kernel(x, y)
    self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
def test_tpu_custom_call_pallas_wrap_flash_attention(self):
  from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
  from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

  # The wrapped kernel returns a single output shaped and typed like q.
  def output_spec(q, k, v):
    return [(q.shape, q.dtype)]

  flash_attention_kernel = make_kernel_from_pallas(flash_attention, output_spec)

  # Deterministic bf16 inputs; the division by 13 produces non-integral values.
  q_block = torch.arange(128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
  k_block = torch.arange(
      1000, 1000 + 128 * 4, dtype=torch.bfloat16).reshape(128, 4) / 13
  q = q_block.broadcast_to(3, 2, 128, 4).to("xla")
  k = k_block.broadcast_to(3, 2, 128, 4).to("xla")
  v = torch.ones(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")

  actual = flash_attention_kernel(q, k, v)
  reference = self._attention(q, k, v)
  torch.testing.assert_close(actual.cpu(), reference.cpu())
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper(self):
  from torch_xla.experimental.custom_kernel import flash_attention

  # q, k, v are drawn in that order from the global RNG.
  q, k, v = (torch.randn(3, 2, 128, 4).to("xla") for _ in range(3))
  actual = flash_attention(q, k, v)
  reference = self._attention(q, k, v)
  self.assertTrue(torch.allclose(actual.cpu(), reference.cpu(), atol=1e-05))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_kv_and_ab_padding(self):
  from torch_xla.experimental.custom_kernel import flash_attention

  # NOTE(review): 513 is presumably chosen so the wrapper must pad kv and ab
  # (per the test name) — confirm against the kernel's block sizes.
  seq_len = 513
  q, k, v = (torch.randn(1, 2, seq_len, 4).to("xla") for _ in range(3))
  ab = torch.randn(1, 2, seq_len, seq_len).to("xla")
  actual = flash_attention(q, k, v, ab=ab)
  reference = self._attention(q, k, v, ab=ab)
  self.assertTrue(torch.allclose(actual.cpu(), reference.cpu(), atol=1e-05))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_with_dynamo(self):
  # The imported name is unused; the test calls torch.ops.xla.flash_attention.
  # NOTE(review): presumably this import registers that op — confirm in
  # torch_xla.experimental.custom_kernel.
  from torch_xla.experimental.custom_kernel import flash_attention

  def flash_attention_wrapper(q, k, v, causal=False):
    return torch.ops.xla.flash_attention(q, k, v, causal)

  q = torch.randn(3, 2, 128, 4).to("xla")
  k = torch.randn(3, 2, 128, 4).to("xla")
  v = torch.randn(3, 2, 128, 4).to("xla")
  compiled_flash_attention = torch.compile(
      flash_attention_wrapper, backend="openxla")
  o_no_causal = compiled_flash_attention(q, k, v)
  o_with_causal = compiled_flash_attention(q, k, v, causal=True)
  expected_o = self._attention(q, k, v)
  self.assertTrue(torch.allclose(o_no_causal.cpu(), expected_o.cpu()))
  # The wrapper above defaults to causal=False. Passing causal=True masks out
  # the top right triangle of the attention matrix, which speeds up the
  # compute but also changes the output, so it must not match the unmasked
  # reference.
  self.assertFalse(
      torch.allclose(o_with_causal.cpu(), expected_o.cpu(), atol=1e-05))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_wrapper_causal(self):
  from torch_xla.experimental.custom_kernel import flash_attention
  q = torch.randn(3, 2, 128, 4).to("xla")
  k = torch.randn(3, 2, 128, 4).to("xla")
  v = torch.randn(3, 2, 128, 4).to("xla")
  # causal=True is passed explicitly. It masks out the top right triangle of
  # the attention matrix (query i attends only to keys 0..i), which speeds up
  # the compute but also changes the output.
  o = flash_attention(q, k, v, causal=True)
  expected_o = self._attention(q, k, v)
  # The causal result must differ from full (unmasked) attention...
  self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
  # ...and must match the reference with the strictly-upper triangle masked.
  # _attention masks out positions where the mask is True.
  seq_len = q.shape[-2]
  causal_mask = torch.triu(
      torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1).to("xla")
  expected_causal_o = self._attention(q, k, v, attn_mask=causal_mask)
  self.assertTrue(
      torch.allclose(o.cpu(), expected_causal_o.cpu(), atol=1e-05))
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_multiple_returns(self):
  # Imported only for its side effects; the module name itself is unused.
  import jax._src.pallas.mosaic.pallas_call_registration

  def add_minus_vectors_kernel(x_ref, y_ref, o1_ref, o2_ref):
    x, y = x_ref[...], y_ref[...]
    o1_ref[...] = x + y
    o2_ref[...] = x - y

  @jax.jit
  def add_minus_vectors(x: jax.Array, y: jax.Array) -> list[jax.Array]:
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return pl.pallas_call(
        add_minus_vectors_kernel, out_shape=[out_shape, out_shape])(x, y)

  from torch_xla.experimental.custom_kernel import make_kernel_from_pallas

  # Two outputs, both shaped and typed like x.
  def output_specs(x, y):
    return [(x.shape, x.dtype), (x.shape, x.dtype)]

  pt_kernel = make_kernel_from_pallas(add_minus_vectors, output_specs)
  x = torch.arange(8, device="xla", dtype=torch.float)
  outputs = pt_kernel(x, x)
  self.assertEqual(len(outputs), 2)
  expected_outputs = (x + x, x - x)
  for actual, expected in zip(outputs, expected_outputs):
    self.assertTrue(torch.allclose(actual.cpu(), expected.cpu()))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
def test__flash_attention_impl(self):
  # Wraps JAX's private _flash_attention_impl with make_kernel_from_pallas and
  # checks the shapes/dtypes of its three outputs (o, l, m).
  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl
  from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
  MIN_BLOCK_SIZE = 128

  # Output spec: o matches q; l and m use q's shape with the last dim replaced
  # by MIN_BLOCK_SIZE, in float32.
  def shape_dtype(q, *arg):
    res_shape = list(q.shape)
    res_shape[-1] = MIN_BLOCK_SIZE
    return [(q.shape, q.dtype), (res_shape, torch.float32),
            (res_shape, torch.float32)]

  flash_attention_kernel = make_kernel_from_pallas(_flash_attention_impl,
                                                   shape_dtype)
  q = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
  k = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
  v = torch.randn(3, 2, 128, 4, dtype=torch.bfloat16).to("xla")
  # After q, k, v: two None args, then eight values at positions 5-12 that
  # static_argnums marks as static.
  # NOTE(review): their meaning follows _flash_attention_impl's positional
  # signature in the installed JAX (presumably save_residuals, causal,
  # sm_scale, block sizes, debug) — confirm when upgrading JAX.
  o, l, m = flash_attention_kernel(
      q,
      k,
      v,
      None,
      None,
      True,
      False,
      1.0,
      2,
      128,
      128,
      128,
      False,
      static_argnums=range(5, 13))
  xm.mark_step()

  # TODO: I don't really know how to test the value. Let's do the shape check for now.
  self.assertEqual(l.shape, (3, 2, 128, MIN_BLOCK_SIZE))
  self.assertEqual(l.dtype, torch.float32)
  self.assertEqual(m.shape, (3, 2, 128, MIN_BLOCK_SIZE))
  self.assertEqual(m.dtype, torch.float32)
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
def test__flash_attention_bwd_dkv(self):
  # Traces JAX's private _flash_attention_bwd_dkv kernel into a payload, runs
  # it through the TPU custom call, and checks the shapes of dk and dv.
  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dkv
  from torch_xla.experimental.custom_kernel import trace_pallas
  MIN_BLOCK_SIZE = 128
  DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)

  q = torch.randn(3, 2, 128, 4).to("xla")
  k = torch.randn(3, 2, 128, 4).to("xla")
  v = torch.randn(3, 2, 128, 4).to("xla")
  l = torch.randn(3, 2, 128).to("xla")
  m = torch.randn(3, 2, 128).to("xla")
  grad_i = torch.randn(3, 2, 128, dtype=torch.float32).to("xla")
  grad_o = torch.randn(3, 2, 128, 4).to("xla")

  # Keyword arguments listed in static_argnames are treated as static. The two
  # None positionals are passed in place of tensor arguments that this test
  # does not use. NOTE(review): presumably ab and segment_ids — confirm.
  payload, _ = trace_pallas(
      _flash_attention_bwd_dkv,
      q,
      k,
      v,
      None,
      None,
      l,
      m,
      grad_o,
      grad_i,
      block_q_major=128,
      block_k_major=128,
      block_k=128,
      block_q=128,
      sm_scale=1.0,
      causal=False,
      mask_value=DEFAULT_MASK_VALUE,
      debug=False,
      static_argnames=[
          "block_q_major", "block_k_major", "block_k", "block_q", "sm_scale",
          "causal", "mask_value", "debug"
      ])

  # TODO: Because of the following reshapes, we can't use make_kernel_from_pallas directly.
  # l, m and grad_i are broadcast along a new trailing dim of MIN_BLOCK_SIZE
  # before being passed to the custom call.
  l = l.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
  m = m.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
  grad_i = grad_i.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
  # Two outputs (dk, dv), shaped and typed like k and v.
  output = torch_xla._XLAC._xla_tpu_custom_call(
      [q, k, v, l, m, grad_o, grad_i], payload, [k.shape, v.shape],
      [k.dtype, v.dtype])

  xm.mark_step()

  # TODO: I don't really know how to test the value. Let's do the shape check for now.
  self.assertEqual(output[0].shape, (3, 2, 128, 4))
  self.assertEqual(output[1].shape, (3, 2, 128, 4))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
def test__flash_attention_bwd_dq(self):
  # Traces JAX's private _flash_attention_bwd_dq kernel into a payload, runs it
  # through the TPU custom call, and checks the shape of dq.
  # Named after the kernel it exercises so it doesn't reuse (and shadow) the
  # name of the dkv backward test in this class.
  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq
  from torch_xla.experimental.custom_kernel import trace_pallas
  MIN_BLOCK_SIZE = 128
  DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)

  q = torch.randn(3, 2, 128, 4).to("xla")
  k = torch.randn(3, 2, 128, 4).to("xla")
  v = torch.randn(3, 2, 128, 4).to("xla")
  l = torch.randn(3, 2, 128).to("xla")
  m = torch.randn(3, 2, 128).to("xla")
  grad_i = torch.randn(3, 2, 128, dtype=torch.float32).to("xla")
  grad_o = torch.randn(3, 2, 128, 4).to("xla")

  # Keyword arguments listed in static_argnames are treated as static.
  payload, _ = trace_pallas(
      _flash_attention_bwd_dq,
      q,
      k,
      v,
      None,
      None,
      l,
      m,
      grad_o,
      grad_i,
      block_q_major=128,
      block_k_major=128,
      block_k=128,
      sm_scale=1.0,
      causal=False,
      mask_value=DEFAULT_MASK_VALUE,
      debug=False,
      static_argnames=[
          "block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
          "mask_value", "debug"
      ])

  # TODO: Because of the following reshapes, we can't use make_kernel_from_pallas directly.
  # l, m and grad_i are broadcast along a new trailing dim of MIN_BLOCK_SIZE.
  l = l.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
  m = m.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
  grad_i = grad_i.unsqueeze(-1).expand(3, 2, 128, MIN_BLOCK_SIZE)
  # One output (dq), shaped and typed like q.
  output = torch_xla._XLAC._xla_tpu_custom_call(
      [q, k, v, l, m, grad_o, grad_i], payload, [q.shape], [q.dtype])

  xm.mark_step()

  # TODO: I don't really know how to test the value. Let's do the shape check for now.
  self.assertEqual(output[0].shape, (3, 2, 128, 4))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
@with_jax_high_precision
def test_flash_attention_backward(self):
  from torch_xla.experimental.custom_kernel import flash_attention

  def seeded_qkv():
    # Re-seed so both passes see identical q, k, v. The .to("xla") copies are
    # non-leaf tensors, so retain_grad() is needed to read their .grad.
    torch.manual_seed(42)
    tensors = [
        torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
        for _ in range(3)
    ]
    for tensor in tensors:
      tensor.retain_grad()
    return tensors

  # Gradients through the Pallas flash-attention kernel.
  q, k, v = seeded_qkv()
  flash_attention(q, k, v).sum().backward()
  xm.mark_step()
  kernel_grads = (q.grad, k.grad, v.grad)

  # Gradients through the eager reference implementation.
  q, k, v = seeded_qkv()
  self._attention(q, k, v).sum().backward()
  xm.mark_step()

  for reference, kernel_grad in zip((q, k, v), kernel_grads):
    self.assertTrue(
        torch.allclose(reference.grad.cpu(), kernel_grad.cpu(), atol=1e-05))
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                 "This test only works on TPUv4+.")
def test_paged_attention_wrapper(self):
  from torch_xla.experimental.custom_kernel import paged_attention
  from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as jax_paged_attention

  max_kv_len = 2048
  block_size = 512
  page_size = 64
  num_kv_heads = 8
  q_kv_head_ratio = 8
  head_dim = 256
  pages_per_compute_block = block_size // page_size

  # Includes an empty sequence (0) and lengths that are not page-aligned.
  seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32)
  q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
      seq_lens,
      page_size,
      max_kv_len,
      num_kv_heads,
      num_kv_heads * q_kv_head_ratio,
      head_dim,
  )

  output = paged_attention(
      q.to("xla"),
      k_pages.to("xla"),
      v_pages.to("xla"),
      seq_lens.to("xla"),
      page_indices.to("xla"),
      pages_per_compute_block=pages_per_compute_block,
  )

  def to_jax(tensor, jax_dtype):
    return jnp.array(tensor.numpy(), dtype=jax_dtype)

  jax_output = jax_paged_attention(
      to_jax(q, jnp.float32),
      to_jax(k_pages, jnp.float32),
      to_jax(v_pages, jnp.float32),
      to_jax(seq_lens, jnp.int32),
      to_jax(page_indices, jnp.int32),
      pages_per_compute_block=pages_per_compute_block,
  )
  expected_output = torch.from_numpy(np.array(jax_output))

  # Rows for empty sequences are excluded from the comparison.
  non_empty = seq_lens > 0
  self.assertTrue(
      torch.allclose(
          output.cpu()[non_empty],
          expected_output.cpu()[non_empty],
          atol=1e-5,
          rtol=1e-5))
def _test_ragged_paged_attention(
    self,
    seq_lens,
    num_heads,
    head_dim,
    page_size,
    num_pages,
    dtype,
    *,
    sm_scale=1.0,
    sliding_window=None,
    soft_cap=None,
    num_kv_pages_per_block=16,
    num_queries_per_block=128,
    pad_tokens_and_seqs=False,
    use_dynamo=True,
):
    """Check ragged paged attention on XLA against two references.

    Runs `ragged_paged_attention` on XLA twice, once with `use_kernel=True`
    and once with `use_kernel=False`. When `use_dynamo` is True the op is
    called through `torch.compile(backend="openxla")`; otherwise it is called
    directly from `torch_xla.experimental.custom_kernel`. Then:
      * the kernel and non-kernel outputs must match in shape and dtype and
        agree within `tol`;
      * the kernel output must agree within `tol` with the JAX
        `ragged_paged_attention_v2` kernel on the same inputs.

    Args:
      seq_lens: sequence description forwarded to
        `_ragged_pagedattention_generate_qkv`; only its length (the number of
        sequences) is used directly here.
      num_heads, head_dim, page_size, num_pages: forwarded to the input
        generator.
      dtype: `torch.float32` or `torch.bfloat16`; also selects the tolerance
        (0.15 for float32, 0.3 for bfloat16).
      sm_scale, sliding_window, soft_cap: attention options passed to every
        implementation.
      num_kv_pages_per_block, num_queries_per_block: kernel block sizes. If
        `num_kv_pages_per_block` is None, `num_queries_per_block` must be None
        too; the JAX reference then uses
        `get_ragged_attention_tuned_block_size`.
      pad_tokens_and_seqs: if True, ask the generator for padded inputs
        (1024 batched tokens, 16 sequences).
      use_dynamo: route the XLA calls through `torch.compile`.
    """
    # Validate arguments before any device work so a bad call fails fast.
    assert dtype == torch.float32 or dtype == torch.bfloat16, dtype
    if num_kv_pages_per_block is None:
        # Tuned sizes are looked up for both dimensions together (see below).
        assert num_queries_per_block is None

    num_seqs = len(seq_lens)
    max_num_batched_tokens = None
    max_num_seqs = None
    if pad_tokens_and_seqs:
        max_num_batched_tokens = 1024
        max_num_seqs = 16
    q, kv_pages, kv_lens, page_indices, cu_q_lens = (
        self._ragged_pagedattention_generate_qkv(
            seq_lens,
            num_heads,
            head_dim,
            page_size,
            num_pages,
            dtype,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs))
    # End index used to trim the kernel outputs; presumably the number of
    # real (non-padding) query tokens.
    num_tokens = cu_q_lens[num_seqs]

    xla_args = (
        q.to("xla"),
        kv_pages.to("xla"),
        kv_lens.to("xla"),
        page_indices.to("xla"),
        cu_q_lens.to("xla"),
        torch.tensor([num_seqs], dtype=torch.int32).to("xla"),
    )
    attn_options = dict(
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )

    if use_dynamo:

        def ragged_paged_attention_wrapper(
            q,
            kv_pages,
            kv_lens,
            page_indices,
            cu_q_lens,
            num_seqs,
            sm_scale=sm_scale,
            sliding_window=sliding_window,
            soft_cap=soft_cap,
            use_kernel=True,
            num_kv_pages_per_block=num_kv_pages_per_block,
            num_queries_per_block=num_queries_per_block,
        ):
            return torch.ops.xla.ragged_paged_attention(
                q,
                kv_pages,
                kv_lens,
                page_indices,
                cu_q_lens,
                num_seqs,
                sm_scale=sm_scale,
                sliding_window=sliding_window,
                soft_cap=soft_cap,
                use_kernel=use_kernel,
                num_kv_pages_per_block=num_kv_pages_per_block,
                num_queries_per_block=num_queries_per_block,
            )

        attn = torch.compile(ragged_paged_attention_wrapper, backend="openxla")
    else:
        from torch_xla.experimental.custom_kernel import ragged_paged_attention
        attn = ragged_paged_attention

    kernel_output = attn(
        *xla_args,
        **attn_options,
        use_kernel=True,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
    )[:num_tokens]
    # Block sizes are not passed to the non-kernel path.
    nonkernel_output = attn(*xla_args, **attn_options, use_kernel=False)

    kernel_output_cpu = kernel_output.cpu()
    nonkernel_output_cpu = nonkernel_output.cpu()
    self.assertEqual(kernel_output_cpu.shape, nonkernel_output_cpu.shape)
    self.assertEqual(kernel_output_cpu.dtype, nonkernel_output_cpu.dtype)

    if dtype == torch.bfloat16:
        jnp_dtype, tol = jnp.bfloat16, 0.3
    else:
        jnp_dtype, tol = jnp.float32, 0.15
    # NumPy does not support bfloat16 directly, so convert to float32 first.
    q_jax = jnp.array(q.to(torch.float32).numpy(), dtype=jnp_dtype)
    kv_pages_jax = jnp.array(
        kv_pages.to(torch.float32).numpy(), dtype=jnp_dtype)
    kv_lens_jax = jnp.array(kv_lens.numpy(), dtype=jnp.int32)
    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
    cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)
    num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32)

    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
    from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
    if num_kv_pages_per_block is None:
        # Resolve concrete block sizes for the JAX reference. This is
        # presumably the same lookup the XLA path does when given None --
        # confirm.
        token_num, q_head_num, _ = q.shape
        # Use a separate name so the `page_size` argument is not rebound.
        _, kv_page_size, num_combined_kv_heads, _ = kv_pages.shape
        _, pages_per_seq = page_indices.shape
        num_kv_heads = num_combined_kv_heads // 2
        max_model_len = pages_per_seq * kv_page_size
        num_kv_pages_per_block, num_queries_per_block = (
            get_ragged_attention_tuned_block_size(q_head_num, num_kv_heads,
                                                  token_num, max_model_len))

    jax_kernel_output = torch.from_numpy(
        np.array(
            jax_ragged_paged_attention(
                q_jax,
                kv_pages_jax,
                kv_lens_jax,
                page_indices_jax,
                cu_q_lens_jax,
                num_seqs=num_seqs_jax,
                num_kv_pages_per_block=num_kv_pages_per_block,
                num_queries_per_block=num_queries_per_block,
                **attn_options,
            )[:num_tokens].astype(jnp.float32))).to(dtype)
    jax_kernel_output_cpu = jax_kernel_output.cpu()

    torch.testing.assert_close(
        kernel_output_cpu, nonkernel_output_cpu, atol=tol, rtol=tol)
    torch.testing.assert_close(
        kernel_output_cpu, jax_kernel_output_cpu, atol=tol, rtol=tol)
@parameterized.product(
    seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
    num_heads=[(32, 8), (8, 1)],
    dtype=[torch.float32, torch.bfloat16],
    sm_scale=[1.0, 0.5],
    sliding_window=[None, 128],
    soft_cap=[None, 10.0],
    pad_tokens_and_seqs=[False, True],
    block_sizes=[(16, 128), (None, None)])
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                 "This test only works on TPUv4+.")
def test_ragged_paged_attention_wrapper_with_dynamo(
    self,
    seq_lens,
    num_heads,
    dtype,
    sm_scale,
    sliding_window,
    soft_cap,
    pad_tokens_and_seqs,
    block_sizes,
):
    """Ragged paged attention via torch.compile(backend="openxla").

    Fixes head_dim=128, page_size=16 and num_pages=1000, and forwards every
    parameterized option to `_test_ragged_paged_attention`. `block_sizes` is
    a (num_kv_pages_per_block, num_queries_per_block) pair.
    """
    kv_pages_per_block, queries_per_block = block_sizes
    self._test_ragged_paged_attention(
        seq_lens=seq_lens,
        num_heads=num_heads,
        head_dim=128,
        page_size=16,
        num_pages=1000,
        dtype=dtype,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        pad_tokens_and_seqs=pad_tokens_and_seqs,
        use_dynamo=True,
        num_kv_pages_per_block=kv_pages_per_block,
        num_queries_per_block=queries_per_block,
    )
@parameterized.product(
    seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
    num_heads=[(32, 8), (8, 1)],
    dtype=[torch.float32, torch.bfloat16],
    sm_scale=[1.0, 0.5],
    sliding_window=[None, 128],
    soft_cap=[None, 10.0],
    pad_tokens_and_seqs=[False, True],
    block_sizes=[(16, 128), (None, None)],
)
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                 "This test only works on TPUv4+.")
def test_ragged_paged_attention_wrapper_without_dynamo(
    self,
    seq_lens,
    num_heads,
    dtype,
    sm_scale,
    sliding_window,
    soft_cap,
    pad_tokens_and_seqs,
    block_sizes,
):
    """Ragged paged attention called directly, without torch.compile.

    Fixes head_dim=128, page_size=16 and num_pages=1000, and forwards every
    parameterized option to `_test_ragged_paged_attention`. `block_sizes` is
    a (num_kv_pages_per_block, num_queries_per_block) pair.
    """
    kv_pages_per_block, queries_per_block = block_sizes
    self._test_ragged_paged_attention(
        seq_lens=seq_lens,
        num_heads=num_heads,
        head_dim=128,
        page_size=16,
        num_pages=1000,
        dtype=dtype,
        sm_scale=sm_scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        pad_tokens_and_seqs=pad_tokens_and_seqs,
        use_dynamo=False,
        num_kv_pages_per_block=kv_pages_per_block,
        num_queries_per_block=queries_per_block,
    )
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                 "This test only works on TPUv4+.")
def test_paged_attention_multi_queries_wrapper(self):
    """Compare multi-query paged attention on XLA with the JAX Pallas kernel.

    Every sequence has `query_len` queries and a random KV length. Checks:
      * the XLA kernel with `attn_logits_soft_cap=1.0` matches JAX with the
        same cap, and does NOT match JAX without a cap (so the cap has an
        effect);
      * the XLA kernel without a cap matches JAX without a cap;
      * the XLA kernel without a cap matches the `use_kernel=False` path
        (looser 1e-2 tolerance).
    """
    from torch_xla.experimental.custom_kernel import multi_queries_paged_attention
    from torch_xla.experimental.pallas_kernels.multi_queries_paged_attention_kernel import paged_attention as jax_multi_queries_paged_attention

    dtype = torch.float32
    page_size = 16
    num_kv_heads = 8
    q_kv_head_ratio = 4
    head_dim = 256
    num_queries_per_compute_block = 32
    block_kv_size = 256
    max_kv_len = 2048
    query_len = 64
    batch_size = 3

    # torch.randint's upper bound is exclusive, so KV lengths fall in
    # [query_len, max_kv_len - 1]. kv_seq_lens has exactly batch_size entries.
    kv_seq_lens = torch.randint(
        query_len, max_kv_len, (batch_size,), dtype=torch.int32)
    effective_q_lens = torch.full((batch_size,), query_len, dtype=torch.int32)
    assert query_len <= max_kv_len
    for cur_kv_seq in kv_seq_lens:
        assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
    pages_per_sequence = max_kv_len // page_size
    total_num_pages = batch_size * pages_per_sequence
    assert max_kv_len <= total_num_pages * page_size

    q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv(
        kv_seq_lens,
        page_size,
        max_kv_len,
        num_kv_heads,
        num_kv_heads * q_kv_head_ratio,
        head_dim,
        dtype=dtype,
        query_len=query_len,
    )

    # Every implementation below uses the same compute-block sizes.
    block_kwargs = dict(
        num_kv_pages_per_compute_block=block_kv_size // page_size,
        num_queries_per_compute_block=num_queries_per_compute_block,
    )

    q_xla = q.to("xla")
    k_pages_xla = k_pages.to("xla")
    v_pages_xla = v_pages.to("xla")
    kv_seq_lens_xla = kv_seq_lens.to("xla")
    page_indices_xla = page_indices.to("xla")
    effective_q_lens_xla = effective_q_lens.to("xla")
    xla_args = (q_xla, k_pages_xla, v_pages_xla, kv_seq_lens_xla,
                page_indices_xla, effective_q_lens_xla)

    output_no_cap = multi_queries_paged_attention(*xla_args, **block_kwargs)
    output = multi_queries_paged_attention(
        *xla_args, **block_kwargs, attn_logits_soft_cap=1.0)
    nonkernel_output = multi_queries_paged_attention(
        *xla_args, **block_kwargs, use_kernel=False)

    q_jax = jnp.array(q.numpy(), dtype=jnp.float32)
    k_pages_jax = jnp.array(k_pages.numpy(), dtype=jnp.float32)
    v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
    kv_seq_lens_jax = jnp.array(kv_seq_lens.numpy(), dtype=jnp.int32)
    page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
    effective_q_lens_jax = jnp.array(effective_q_lens.numpy(), dtype=jnp.int32)
    jax_args = (q_jax, k_pages_jax, v_pages_jax, kv_seq_lens_jax,
                page_indices_jax, effective_q_lens_jax)

    expected_output = torch.from_numpy(
        np.array(
            jax_multi_queries_paged_attention(
                *jax_args, **block_kwargs, attn_logits_soft_cap=1.0)))
    expected_output_no_cap = torch.from_numpy(
        np.array(jax_multi_queries_paged_attention(*jax_args, **block_kwargs)))

    # assert_close uses the same tolerance criterion as allclose but reports
    # the mismatching elements on failure.
    torch.testing.assert_close(
        output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5)
    # The soft cap must actually change the result.
    self.assertFalse(
        torch.allclose(
            output.cpu(), expected_output_no_cap.cpu(), atol=1e-5, rtol=1e-5))
    torch.testing.assert_close(
        output_no_cap.cpu(), expected_output_no_cap.cpu(), atol=1e-5, rtol=1e-5)
    torch.testing.assert_close(
        output_no_cap.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2)