-
Notifications
You must be signed in to change notification settings - Fork 512
/
Copy pathcustom_kernel.py
1800 lines (1554 loc) · 67.8 KB
/
custom_kernel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import functools
import os
import math
import warnings
import torch
from torch.library import impl, custom_op
import torch_xla
from torch_xla.distributed.spmd import Mesh
import torch_xla.distributed.spmd as xs
from torch_xla._internal.jax_workarounds import requires_jax
# Re-expose this API used that is referenced by docs
from torch_xla._internal.jax_workarounds import jax_import_guard # noqa: F401, pylint: disable=unused-import
from typing import Any, List, Callable, Optional, Tuple, Dict
from torch_xla.core.xla_model import XLA_LIB
# True when the legacy XLA_USE_BF16 env var is set to "1";
# convert_torch_dtype_to_jax refuses to run in that mode.
_XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
# Large negative but finite value (-0.7 * float32 max); used as the default
# `mask_value` of the non-kernel ragged paged attention reference.
DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)
def _shard_map(func, mesh, input_specs, output_specs):
"""Map a function over shards of data.
Note:
``shard_map`` is an experimental API, and still subject to change. For an
introduction to sharded data. For a more
in-depth look at using ``shard_map``, refer to
[SPMD multi-device parallelism with shard_map](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
Args:
func: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
takes as input a shard of the mapped-over arguments and produces a shard
of the output.
mesh: a ``Mesh`` representing the array of devices over which
to shard the data and on which to execute instances of ``f``. The names of
the ``Mesh`` can be used in collective communication operations in ``f``.
This is typically created by a utility function like
:func:`jax.experimental.mesh_utils.create_device_mesh`.
in_specs: a tuple of tuples of str. Each is the partition spec of positional input
of func. kwarg is not supported yet
out_specs: a pytree with :class:`~tuple[tuple[str]]`, with the same length
as the number of outputs
Returns:
A callable that applies the input function ``f`` across data sharded according to
the ``mesh`` and ``out_specs``.
Reference:
This function behaves identically Jax's shard_map:
https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html
"""
def _full_shape(a, spec):
# a is local tensor
# spec is the sharding spec
# return logical shape of global tensor
mesh_name_to_size = mesh.shape()
result_shape = []
for axis_size, axis_sharding in zip(a.shape, spec):
if axis_sharding is None:
axis_sharding = ()
mesh_mult = []
if isinstance(axis_sharding, (str, int)):
axis_sharding = [axis_sharding]
for axis in axis_sharding:
size = mesh_name_to_size[axis] or 1
mesh_mult.append(size)
new_size = axis_size * math.prod(mesh_mult)
result_shape.append(new_size)
return tuple(result_shape)
def wrapped(*args):
assert len(args) == len(
input_specs), f'args={len(args)}; input_specs={len(input_specs)}'
new_args = []
for i, (a, spec) in enumerate(zip(args, input_specs)):
if isinstance(a, torch.Tensor):
assert (len(a.shape) == len(spec)
), f'{i}th input has wrong shape: {a.shape} for {spec}'
new_a = xs.enable_manual_sharding(a, spec, mesh=mesh).global_tensor
new_args.append(new_a)
else:
new_args.append(a)
res = func(*new_args)
if isinstance(res, tuple):
res_updated = []
for i, (r, spec) in enumerate(zip(res, output_specs)):
if isinstance(r, torch.Tensor) and spec is not None:
assert str(r.device).startswith('xla'), f'{i}th device is {r.device}'
assert len(r.shape) == len(
spec), f'{i}th shape is {r.shape}, sharding is {output_specs[i]}'
new_r = xs.disable_manual_sharding(
r, spec, _full_shape(r, spec), mesh=mesh).global_tensor
else:
new_r = r
res_updated.append(new_r)
return res_updated
else:
return xs.disable_manual_sharding(
res, output_specs[0], _full_shape(res, output_specs[0]),
mesh=mesh).global_tensor
return wrapped
def safe_empty_like(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
  """Allocate an uninitialized tensor with the same properties as ``tensor``.

  ``None`` is passed straight through, so optional inputs need no special
  casing at the call site.
  """
  if tensor is None:
    return None
  return torch.empty_like(tensor)
def generate_ctx_need_grad(*args):
  """Return one flag per positional argument: is it a tensor requiring grad?

  Non-tensor arguments, including ``None``, map to ``False``.
  """
  return [
      isinstance(candidate, torch.Tensor) and candidate.requires_grad
      for candidate in args
  ]
def _extract_backend_config(
module: "jaxlib.mlir._mlir_libs._mlir.ir.Module") -> Optional[str]:
"""
This algorithm intends to extract the backend config from the compiler IR like the following,
and it is not designed to traverse any generic MLIR module.
module @jit_add_vectors attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<8xi32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = call @add_vectors(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
return %0 : tensor<8xi32>
}
func.func private @add_vectors(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
%0 = call @wrapped(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
return %0 : tensor<8xi32>
}
func.func private @wrapped(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
%0 = call @apply_kernel(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
return %0 : tensor<8xi32>
}
func.func private @apply_kernel(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
%0 = stablehlo.custom_call @tpu_custom_call(%arg0, %arg1) {backend_config = "{\22custom_call_config\22: {\22body\22: \22TUzvUgFNTElSMTkuMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA3lZDQFVBwsPEw8PCw8PMwsLCwtlCwsLCwsPCw8PFw8LFw8PCxcPCxcTCw8LDxcLBQNhBwNZAQ0bBxMPGw8CagMfBRcdKy0DAycpHVMREQsBBRkVMzkVTw8DCxUXGRsfCyELIyUFGwEBBR0NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFHwUhBSMFJQUnEQMBBSkVLw8dDTEXA8IfAR01NwUrFwPWHwEVO0EdPT8FLRcD9h8BHUNFBS8XA3InAQMDSVcFMR1NEQUzHQ1RFwPGHwEFNSN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACNhcml0aC5vdmVyZmxvdzxub25lPgAXVQMhBx0DJwMhBwECAgUHAQEBAQECBASpBQEQAQcDAQUDEQETBwMVJwcBAQEBAQEHAwUHAwMLBgUDBQUBBwcDBQcDAwsGBQMFBQMLCQdLRwMFBQkNBwMJBwMDCwYJAwUFBRENBAkHDwURBQABBgMBBQEAxgg32wsdE2EZ2Q0LEyMhHSknaw0LCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAc3RvcmUAL3dvcmtzcGFjZXMvd29yay9weXRvcmNoL3hsYS90ZXN0L3Rlc3Rfb3BlcmF0aW9ucy5weQBhZGRfdmVjdG9yc19rZXJuZWwAZGltZW5zaW9uX3NlbWFudGljcwBmdW5jdGlvbl90eXBlAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAHN5bV9uYW1lAG1haW4AdmFsdWUAL2dldFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoQ3VzdG9tTm9kZShTbGljZVsoMCwgOCldLCBbXSksKSksICg4LCksICgpKV0sIFtdKSwpKV0AYWRkX3ZlY3RvcnMAdGVzdF90cHVfY3VzdG9tX2NhbGxfcGFsbGFzX2V4dHJhY3RfYWRkX3BheWxvYWQAPG1vZHVsZT4Ab3ZlcmZsb3dGbGFncwAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\22, \22needs_layout_passes\22: true}}", kernel_name = "add_vectors_kernel", operand_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
return %0 : tensor<8xi32>
}
}
Basically, what we are looking for is a two level of operations, and the tpu_custom_call operation in the inner level. It will return None if the payload is not found.
"""
for operation in module.body.operations:
assert len(
operation.body.blocks) == 1, "The passing module is not compatible."
for op in operation.body.blocks[0].operations:
if op.name == "stablehlo.custom_call":
return op.backend_config.value
return None
@requires_jax
def convert_torch_dtype_to_jax(dtype: torch.dtype) -> "jnp.dtype":
  """Translate a ``torch.dtype`` into the matching ``jax.numpy`` dtype.

  Raises:
    RuntimeError: if the ``XLA_USE_BF16`` environment variable is "1".
    ValueError: if ``dtype`` is not one of the supported dtypes below.
  """
  # JAX is imported lazily so jax_import_guard() need not run at module scope,
  # which could cause problems for xmp.spawn.
  import jax.numpy as jnp

  if _XLA_USE_BF16:
    raise RuntimeError(
        "Pallas kernel does not support XLA_USE_BF16, please unset the env var")

  supported = (
      (torch.float32, jnp.float32),
      (torch.float64, jnp.float64),
      (torch.float16, jnp.float16),
      (torch.bfloat16, jnp.bfloat16),
      (torch.int32, jnp.int32),
      (torch.int64, jnp.int64),
      (torch.int16, jnp.int16),
      (torch.int8, jnp.int8),
      (torch.uint8, jnp.uint8),
  )
  for torch_dtype, jax_dtype in supported:
    if dtype == torch_dtype:
      return jax_dtype
  raise ValueError(f"Unsupported dtype: {dtype}")
@requires_jax
def to_jax_shape_dtype_struct(tensor: torch.Tensor) -> "jax.ShapeDtypeStruct":
  """Describe ``tensor`` as a storage-free ``jax.ShapeDtypeStruct``.

  Only the shape and the converted dtype are carried over; no data is copied.
  """
  # Lazy import: avoids needing jax_import_guard() at module scope, which
  # could cause problems for xmp.spawn.
  import jax

  jax_dtype = convert_torch_dtype_to_jax(tensor.dtype)
  return jax.ShapeDtypeStruct(tensor.shape, jax_dtype)
# Process-wide cache used by trace_pallas(use_cache=True). It maps a key built
# from the matmul precision, the kernel, its static args, the traced argument
# placeholders and the kwargs to the extracted backend-config payload (which
# may be None). Entries are never evicted.
trace_pallas_arg_to_payload: Dict[Tuple[Any, ...], Optional[str]] = {}
@requires_jax
def trace_pallas(kernel: Callable,
                 *args,
                 static_argnums=None,
                 static_argnames=None,
                 use_cache=False,
                 **kwargs):
  """Trace a Pallas ``kernel`` with JAX and return its TPU custom-call payload.

  Tensor positional arguments are replaced by ``jax.ShapeDtypeStruct``
  placeholders for tracing, so no tensor data is touched.

  Args:
    kernel: the JAX/Pallas callable to trace.
    *args: positional arguments for ``kernel``. Tensors are traced by shape
      and dtype; everything else is passed to JAX as-is.
    static_argnums: forwarded to ``jax.jit`` (an int or a sequence of ints).
    static_argnames: forwarded to ``jax.jit`` (a str or a sequence of str).
    use_cache: memoize the payload in ``trace_pallas_arg_to_payload``. This
      requires every non-tensor argument to be hashable and every kwarg to
      have a stable ``repr``.
    **kwargs: passed to ``kernel`` at trace time only.

  Returns:
    ``(payload, tensor_args)``: the backend config extracted from the lowered
    module (``None`` if no custom call was found) and the tensor arguments,
    in order, for execution.
  """
  # Import JAX within the function so that jax_import_guard() need not be
  # called at module scope, which could cause problems for xmp.spawn.
  import jax
  # Kept from the original implementation; presumably this import registers
  # the Mosaic lowering of pallas_call (NOTE(review): confirm).
  import jax._src.pallas.mosaic.pallas_call_registration

  def _hashable(value, scalar_type):
    # jax.jit accepts a single value or any sequence (e.g. a list) here.
    # Normalize sequences to tuples so the value can be part of a dict key.
    if value is None or isinstance(value, scalar_type):
      return value
    return tuple(value)

  jax_args = []  # for tracing
  tensor_args = []  # for execution
  for arg in args:
    # TODO: Could the args be a tuple of tensors or a list of tensors? Flatten them?
    if torch.is_tensor(arg):
      # ShapeDtypeStruct has no storage, which makes it ideal for generating
      # the payload.
      jax_args.append(to_jax_shape_dtype_struct(arg))
      tensor_args.append(arg)
    else:
      jax_args.append(arg)

  hash_key = ()
  if use_cache:
    # Implicit assumption: everything in kwargs is hashable and not a tensor
    # (per the original note, true for gmm and tgmm).
    hash_key = (jax.config.jax_default_matmul_precision, kernel,
                _hashable(static_argnums, int),
                _hashable(static_argnames, str), tuple(jax_args),
                repr(sorted(kwargs.items())).encode())
    if hash_key in trace_pallas_arg_to_payload:
      torch_xla._XLAC._xla_increment_counter('trace_pallas_cache_hit', 1)
      return trace_pallas_arg_to_payload[hash_key], tensor_args

  # kwargs are used for tracing only; they are not part of the execution.
  ir = jax.jit(
      kernel, static_argnums=static_argnums,
      static_argnames=static_argnames).lower(*jax_args, **kwargs).compiler_ir()
  payload = _extract_backend_config(ir)

  if use_cache:
    # Reaching this point with use_cache means the lookup above missed.
    trace_pallas_arg_to_payload[hash_key] = payload
  return payload, tensor_args
def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
  """Wrap a Pallas ``kernel`` so it can be called on torch_xla tensors.

  Args:
    kernel: the Pallas/JAX callable to trace.
    output_shape_dtype_fn: called with the same positional arguments as the
      returned callable. It must return a list of ``(shape, dtype)`` pairs,
      one per kernel output.

  Returns:
    A callable ``fn(*args, static_argnums=None, static_argnames=None,
    **kwargs)``. Each call traces ``kernel`` (without payload caching) and
    runs the TPU custom call on the tensor arguments. The result is a single
    tensor for one output, otherwise a tuple.
  """

  # TODO: Maybe we can cache the payload for the same input.
  def _run(kernel: Callable,
           output_shape_dtype_fn: Callable,
           *args,
           static_argnums=None,
           static_argnames=None,
           **kwargs):
    payload, tensor_args = trace_pallas(
        kernel,
        *args,
        static_argnums=static_argnums,
        static_argnames=static_argnames,
        **kwargs)

    shape_dtype_pairs = output_shape_dtype_fn(*args)
    assert isinstance(shape_dtype_pairs,
                      list), "The output_shape_dtype_fn should return a list."
    output_shapes, output_dtypes = [], []
    for shape, dtype in shape_dtype_pairs:
      output_shapes.append(shape)
      output_dtypes.append(dtype)

    outputs = torch_xla._XLAC._xla_tpu_custom_call(tensor_args, payload,
                                                   output_shapes, output_dtypes)
    # A lone output is unwrapped to make the result easier to use.
    return outputs[0] if len(outputs) == 1 else tuple(outputs)

  return functools.partial(_run, kernel, output_shape_dtype_fn)
def _maybe_reshape_input_output_funcs(current_shape, non_batch_dims=3):
batch_dims = len(current_shape) - non_batch_dims
orig_batch_dims = current_shape[:batch_dims]
other_dims = current_shape[batch_dims:]
def reshape_input(tensor):
if tensor is None:
return None
return tensor.reshape(-1, *tensor.shape[batch_dims:])
def reshape_output(tensor):
if tensor is None:
return None
return tensor.reshape(*orig_batch_dims, *tensor.shape[1:])
return reshape_input, reshape_output
def _fa_custom_forward_single_device(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool,
    q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float,
    ab: Optional[torch.Tensor],
    ctx_grad: List[bool]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Run the Pallas flash-attention forward kernel on one device (or shard).

  Args:
    q, k, v: attention inputs (see the shape comments on ``flash_attention``).
      Leading dims beyond the last three are folded into one before the
      kernel runs and restored afterwards.
    causal: whether to apply causal masking.
    q_segment_ids, kv_segment_ids: optional segment ids; give both or neither.
    sm_scale: scale applied to the attention logits.
    ab: optional additive attention bias.
    ctx_grad: requires-grad flags of the user-facing inputs. Any of the first
      three (q, k, v) being set makes the kernel also return l and m.

  Returns:
    ``(o, l, m)``, where ``l`` and ``m`` are ``None`` when no residuals are
    saved.
  """
  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

  reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
  q = reshape_to_4d(q)
  v = reshape_to_4d(v)
  k = reshape_to_4d(k)
  q_segment_ids = reshape_to_4d(q_segment_ids)
  kv_segment_ids = reshape_to_4d(kv_segment_ids)
  ab = reshape_to_4d(ab)

  # Any tensor passed into a custom_op-decorated function reports
  # requires_grad=False, so the caller forwards the original flags in ctx_grad.
  save_residuals = any(ctx_grad[:3])

  block_k_major = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"],
                      k.shape[2])
  block_k = min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2])
  kv_block_size = max(block_k_major, block_k)
  kv_seq_len = k.shape[2]
  k, k_pad_size = _pad_to_block_size(k, kv_block_size, 2)
  if k_pad_size > 0:
    v, _ = _pad_to_block_size(v, kv_block_size, 2)
    # Zero-padded keys would otherwise get a logit of 0 and take part in the
    # softmax. Mask them with a bias whose padded columns hold the most
    # negative finite value; without a caller bias, start from all zeros.
    # NOTE(review): kv_segment_ids are not padded here; confirm the kernel
    # tolerates the length mismatch when segment ids are used with padding.
    if ab is None:
      ab = torch.zeros((q.shape[0], q.shape[1], q.shape[2], kv_seq_len),
                       dtype=torch.float32,
                       device=q.device)
    ab, _ = _pad_to_block_size(ab, kv_block_size, 3, padding_minus_inf=True)

  # Output shapes/dtypes: o, then optionally l and m, which the kernel emits
  # with a trailing dimension of MIN_BLOCK_SIZE.
  shapes = [q.shape]
  dtypes = [q.dtype]
  if save_residuals:
    res_shape = list(q.shape)
    res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE
    for _ in range(2):
      shapes.append(res_shape)
      dtypes.append(torch.float32)

  with torch.no_grad():
    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
        q_segment_ids, kv_segment_ids)

    # flash_attention cannot be used directly: save_residuals must be forced
    # to get l and m for the backward, which also skips its shape checks.
    # TODO: replicate the shape checks on flash_attention.
    # Tracing and execution are separated to support SegmentIds.
    payload, _ = trace_pallas(
        _flash_attention_impl,
        q,
        k,
        v,
        ab,
        segment_ids,
        save_residuals,
        causal,
        sm_scale,
        min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
        min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
        block_k_major,
        block_k,
        False,
        static_argnums=range(5, 13),
        use_cache=True,
    )

    args = [q, k, v]
    if ab is not None:
      args += [ab]
    if segment_ids is not None:
      args += [q_segment_ids_fa, kv_segment_ids_fa]
    custom_call_output = torch_xla._XLAC._xla_tpu_custom_call(
        args, payload, shapes, dtypes)

    assert isinstance(custom_call_output, list)
    if save_residuals:
      o, *aux = custom_call_output
      # Drop the MIN_BLOCK_SIZE trailing dimension added for the kernel.
      l, m = (residual[..., 0] for residual in aux[-2:])
    else:
      o = custom_call_output[0]
      l = m = None

  return undo_reshape(o), undo_reshape(l), undo_reshape(m)
@custom_op("xla::fa_custom_forward", mutates_args=())
def fa_custom_forward(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool,
    q_segment_ids: torch.Tensor, kv_segment_ids: torch.Tensor, sm_scale: float,
    ab: Optional[torch.Tensor], partition_spec: str, mesh: str,
    ctx_grad: List[bool]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor, torch.Tensor]:
  """Custom-op forward for flash attention, optionally under SPMD.

  ``partition_spec`` and ``mesh`` arrive as their ``str()`` forms. When the
  parsed ``partition_spec`` is not ``None``, the single-device forward runs
  through ``_shard_map``:
  - q/k/v/ab are split by ``partition_spec``.
  - Segment ids and l/m use the spec with the head (and, for 4-entry specs,
    head-dim) axes removed.

  Returns:
    ``(o, full_q, full_k, full_v, l, m, full_ab)``. The ``full_*`` entries are
    unsharded copies of the inputs kept for the backward. ``l``/``m`` may be
    ``None`` when no residuals are saved, and ``full_ab`` is ``None`` without
    a bias.
  """
  import ast

  # literal_eval parses the serialized tuple/None without executing code,
  # unlike eval().
  partition_spec = ast.literal_eval(partition_spec)
  # An explicit mesh takes precedence over the global one, matching the order
  # used by fa_custom_backward so both passes shard identically.
  mesh = Mesh.from_str(mesh) or xs.get_global_mesh()

  # SPMD integration: per the original note, mark_sharding is in-place, so
  # keep full copies of q, k, v (and ab) for the backward. Cloning is required.
  full_q = q.clone()
  full_k = k.clone()
  full_v = v.clone()
  full_ab = ab.clone() if ab is not None else None

  if partition_spec is not None:
    if len(partition_spec) == 5:
      segment_id_partition_spec = (partition_spec[0], partition_spec[1],
                                   partition_spec[3])
      lm_partition_spec = partition_spec[:4]
    else:
      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
      lm_partition_spec = partition_spec[:3]
    input_specs = [
        partition_spec,  # q
        partition_spec,  # k
        partition_spec,  # v
        None,  # causal
        segment_id_partition_spec,  # q_segment_ids
        segment_id_partition_spec,  # kv_segment_ids
        None,  # sm_scale
        partition_spec,  # ab
        None,  # ctx_grad
    ]
    output_specs = [
        partition_spec,  # o
        lm_partition_spec,  # l
        lm_partition_spec,  # m
    ]
    fa_forward_callable = _shard_map(
        _fa_custom_forward_single_device,
        mesh,
        input_specs,
        output_specs,
    )
  else:
    fa_forward_callable = _fa_custom_forward_single_device

  o, l, m = fa_forward_callable(q, k, v, causal, q_segment_ids, kv_segment_ids,
                                sm_scale, ab, ctx_grad)
  return (o, full_q, full_k, full_v, l, m, full_ab)
def _pad_to_block_size(
tensor: torch.Tensor,
block_size: int,
dim: int,
padding_minus_inf: bool = False) -> Tuple[torch.Tensor, int]:
size = tensor.shape[dim]
if size % block_size == 0:
return tensor, 0
pad_size = block_size - (size % block_size)
pad_shape = list(tensor.shape)
pad_shape[dim] = pad_size
padding = torch.full(
pad_shape,
torch.finfo(tensor.dtype).min if padding_minus_inf else 0,
dtype=tensor.dtype,
device=tensor.device)
padded = torch.cat([tensor, padding], dim=dim)
return padded, pad_size
def _fa_custom_backward_single_device(
    grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
    v: torch.Tensor, o: torch.Tensor, l: torch.Tensor, m: torch.Tensor,
    q_segment_ids: Optional[torch.Tensor],
    kv_segment_ids: Optional[torch.Tensor], ab: Optional[torch.Tensor],
    causal: bool, sm_scale: float, q_full_shape: List[int],
    kv_full_shape: List[int], ab_full_shape: Optional[List[int]],
    ctx_grad: List[bool]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  """Run the Pallas flash-attention backward kernels on one device (or shard).

  The dq kernel produces grad_q and, when a bias is given, grad_ab. The dkv
  kernel produces grad_k and grad_v. Each kernel runs only if at least one of
  its outputs is needed.

  Args:
    grad_output, q, k, v, o, l, m: forward inputs/outputs and residuals.
    q_segment_ids, kv_segment_ids: optional segment ids (both or neither).
    ab: optional additive attention bias.
    causal, sm_scale: same values as in the forward.
    q_full_shape, kv_full_shape, ab_full_shape: accepted for interface
      compatibility; not used here.
    ctx_grad: requires-grad flags of the FlashAttention.forward inputs; the
      first three are q, k, v and the third from last is ab.

  Returns:
    ``(grad_q, grad_k, grad_v, grad_ab)``. Entries whose input does not
    require a gradient are ``None``.
  """
  from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv

  grad_q = grad_k = grad_v = grad_ab = segment_ids = None

  reshape_to_4d, undo_reshape = _maybe_reshape_input_output_funcs(q.shape, 3)
  grad_output = reshape_to_4d(grad_output)
  q = reshape_to_4d(q)
  k = reshape_to_4d(k)
  v = reshape_to_4d(v)
  o = reshape_to_4d(o)
  l = reshape_to_4d(l)
  m = reshape_to_4d(m)
  q_segment_ids = reshape_to_4d(q_segment_ids)
  kv_segment_ids = reshape_to_4d(kv_segment_ids)
  ab = reshape_to_4d(ab)

  require_grad_q, require_grad_k, require_grad_v = ctx_grad[:3]
  require_grad_ab = ctx_grad[-3]

  grad_i = torch.sum(
      o.to(torch.float32) * grad_output.to(torch.float32),
      axis=-1)  # [batch_size, num_heads, q_seq_len]

  # The kernels expect l, m and grad_i with a trailing MIN_BLOCK_SIZE dimension.
  lanes = FlashAttention.MIN_BLOCK_SIZE
  expanded_l = l.unsqueeze(-1).expand([-1 for _ in l.shape] + [lanes])
  expanded_m = m.unsqueeze(-1).expand([-1 for _ in m.shape] + [lanes])
  expanded_grad_i = grad_i.unsqueeze(-1).expand([-1 for _ in grad_i.shape] +
                                                [lanes])

  if q_segment_ids is not None and kv_segment_ids is not None:
    segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids(
        q_segment_ids, kv_segment_ids)

  # Both backward kernels take the same operands. Build them once so the dkv
  # kernel does not depend on the dq branch having run.
  args = [q, k, v]
  if ab is not None:
    args += [ab]
  if segment_ids is not None:
    args += [q_segment_ids_fa, kv_segment_ids_fa]
  args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

  # The dq kernel also yields the bias gradient, so it must run when ab needs
  # a gradient even if q does not.
  need_ab_grad = require_grad_ab and ab is not None
  if require_grad_q or need_ab_grad:
    payload, _ = trace_pallas(
        _flash_attention_bwd_dq,
        q,
        k,
        v,
        ab,
        segment_ids,
        l,
        m,
        grad_output,
        grad_i,
        block_q_major=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dq"],
                          q.shape[2]),
        block_k_major=min(
            FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dq"], k.shape[2]),
        block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"],
                    k.shape[2]),
        sm_scale=sm_scale,
        causal=causal,
        mask_value=FlashAttention.DEFAULT_MASK_VALUE,
        debug=False,
        static_argnames=[
            "block_q_major", "block_k_major", "block_k", "sm_scale", "causal",
            "mask_value", "debug"
        ],
        use_cache=True,
    )
    outputs = [q]
    if ab is not None:
      outputs += [ab]
    grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload,
                                                 [t.shape for t in outputs],
                                                 [t.dtype for t in outputs])
    if require_grad_q:
      grad_q = grads[0]
    if need_ab_grad:
      grad_ab = grads[1]

  if require_grad_k or require_grad_v:
    payload, _ = trace_pallas(
        _flash_attention_bwd_dkv,
        q,
        k,
        v,
        ab,
        segment_ids,
        l,
        m,
        grad_output,
        grad_i,
        block_q_major=min(
            FlashAttention.DEFAULT_BLOCK_SIZES["block_q_major_dkv"],
            q.shape[2]),
        block_k_major=min(
            FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dkv"],
            k.shape[2]),
        block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dkv"],
                    k.shape[2]),
        block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"],
                    q.shape[2]),
        sm_scale=sm_scale,
        causal=causal,
        mask_value=FlashAttention.DEFAULT_MASK_VALUE,
        debug=False,
        static_argnames=[
            "block_q_major", "block_k_major", "block_k", "block_q", "sm_scale",
            "causal", "mask_value", "debug"
        ],
        use_cache=True)
    grads = torch_xla._XLAC._xla_tpu_custom_call(args, payload,
                                                 [k.shape, v.shape],
                                                 [k.dtype, v.dtype])
    if require_grad_k:
      grad_k = grads[0]
    if require_grad_v:
      grad_v = grads[1]

  return (undo_reshape(grad_q), undo_reshape(grad_k), undo_reshape(grad_v),
          undo_reshape(grad_ab))
@custom_op("xla::fa_custom_backward", mutates_args=())
def fa_custom_backward(
    grad_output: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
    v: torch.Tensor, o: torch.Tensor, l: torch.Tensor, m: torch.Tensor,
    q_segment_ids: Optional[torch.Tensor],
    kv_segment_ids: Optional[torch.Tensor], ab: Optional[torch.Tensor],
    causal: bool, sm_scale: float, partition_spec: str, mesh: str,
    q_full_shape: List[int], kv_full_shape: List[int],
    ab_full_shape: Optional[List[int]], ctx_grad: List[bool]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  """Custom-op backward for flash attention, optionally under SPMD.

  ``partition_spec`` and ``mesh`` arrive as their ``str()`` forms. When the
  parsed ``partition_spec`` is truthy, the single-device backward runs through
  ``_shard_map``, using the same spec layout as the forward.

  Returns:
    ``(grad_q, grad_k, grad_v, grad_ab)`` as a tuple; entries not needed are
    ``None``.
  """
  import ast

  # literal_eval parses the serialized tuple/None without executing code.
  partition_spec = ast.literal_eval(partition_spec)
  mesh = Mesh.from_str(mesh) or xs.get_global_mesh()
  q_full_shape = torch.Size(q_full_shape)
  kv_full_shape = torch.Size(kv_full_shape)
  ab_full_shape = torch.Size(
      ab_full_shape) if ab_full_shape is not None else None

  if partition_spec:
    if len(partition_spec) == 5:
      segment_id_partition_spec = (partition_spec[0], partition_spec[1],
                                   partition_spec[3])
      lm_partition_spec = partition_spec[:4]
    else:
      segment_id_partition_spec = (partition_spec[0], partition_spec[2])
      lm_partition_spec = partition_spec[:3]
    input_specs = [
        partition_spec,  # grad_output
        partition_spec,  # q
        partition_spec,  # k
        partition_spec,  # v
        partition_spec,  # o
        lm_partition_spec,  # l
        lm_partition_spec,  # m
        segment_id_partition_spec,  # q_segment_ids
        segment_id_partition_spec,  # kv_segment_ids
        partition_spec,  # ab
        None,  # causal
        None,  # sm_scale
        None,  # q_full_shape
        None,  # kv_full_shape
        None,  # ab_full_shape
        None,  # ctx_grad
    ]
    output_specs = [
        partition_spec,  # grad_q
        partition_spec,  # grad_k
        partition_spec,  # grad_v
        partition_spec,  # grad_ab
    ]
    fa_backward_callable = _shard_map(_fa_custom_backward_single_device, mesh,
                                      input_specs, output_specs)
  else:
    fa_backward_callable = _fa_custom_backward_single_device

  res = fa_backward_callable(grad_output, q, k, v, o, l, m, q_segment_ids,
                             kv_segment_ids, ab, causal, sm_scale, q_full_shape,
                             kv_full_shape, ab_full_shape, ctx_grad)
  # The sharded path may hand back a list; the op is declared to return a tuple.
  return tuple(res)
@fa_custom_forward.register_fake
def fa_custom_forward_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           causal: bool, q_segment_ids: Optional[torch.Tensor],
                           kv_segment_ids: Optional[torch.Tensor],
                           sm_scale: float, ab: Optional[torch.Tensor],
                           partition_spec: Optional[str], mesh: Optional[str],
                           ctx_grad: List[bool]):
  """Shape/dtype-only implementation of ``fa_custom_forward`` for tracing.

  Mirrors the real outputs ``(o, full_q, full_k, full_v, l, m, full_ab)``:
  - ``o`` and ``full_q`` are shaped like ``q``.
  - ``full_k``/``full_v``/``full_ab`` match their inputs.
  - The residuals ``l``/``m`` are float32 with shape ``q.shape[:-1]``.

  NOTE(review): the real op returns ``None`` for l/m when no residuals are
  saved; this fake always returns tensors, as before.
  """
  # q may have a different sequence length than k/v (cross-attention); only
  # k and v must agree.
  assert k.shape == v.shape
  o = torch.empty_like(q)
  full_q = torch.empty_like(q)
  full_k = torch.empty_like(k)
  full_v = torch.empty_like(v)
  full_ab = safe_empty_like(ab)
  l = q.new_empty(q.shape[:-1], dtype=torch.float32)
  m = q.new_empty(q.shape[:-1], dtype=torch.float32)
  return (o, full_q, full_k, full_v, l, m, full_ab)
@fa_custom_backward.register_fake
def fa_custom_backward_fake(grad_output, q, k, v, o, l, m, q_segment_ids,
                            kv_segment_ids, ab, causal, sm_scale,
                            partition_spec, mesh, q_full_shape, kv_full_shape,
                            ab_full_shape, ctx_grad):
  """Shape/dtype-only implementation of ``fa_custom_backward`` for tracing.

  Returns uninitialized gradients shaped like ``q``, ``k``, ``v`` and ``ab``.
  The last slot is ``None`` when there is no bias.
  """
  return (safe_empty_like(q), safe_empty_like(k), safe_empty_like(v),
          safe_empty_like(ab))
class FlashAttention(torch.autograd.Function):
  """Autograd wrapper around the Pallas TPU flash-attention kernels.

  This is a simplified wrapper on top of
  https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
  It takes only q, k, v and causal, plus optional segment ids, sm_scale, bias
  and SPMD sharding info, and sets the block sizes for the user.
  """

  MIN_BLOCK_SIZE = 128
  DEFAULT_MASK_VALUE = -0.7 * float(torch.finfo(torch.float32).max)

  # The block_sizes configuration is copied from
  # https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
  # It yields much better performance than the default block_sizes.
  DEFAULT_BLOCK_SIZES = {
      "block_q": 512,
      "block_k_major": 512,
      "block_k": 512,
      "block_b": 2,
      "block_q_major_dkv": 512,
      "block_k_major_dkv": 512,
      "block_q_dkv": 512,
      "block_k_dkv": 512,
      "block_q_dq": 1024,
      "block_k_dq": 256,
      "block_k_major_dq": 512,
  }
  NUM_LANES = 128
  NUM_SUBLANES = 8

  @staticmethod
  def prepare_segment_ids(
      q_segment_ids,
      kv_segment_ids) -> Tuple["SegmentIds", torch.Tensor, torch.Tensor]:
    """Build the JAX ``SegmentIds`` for tracing plus kernel-layout id tensors.

    Returns ``(None, None, None)`` when neither id tensor is given. Otherwise
    returns:
    - a ``SegmentIds`` of shape/dtype placeholders;
    - ``q_segment_ids`` broadcast along a new trailing dim of ``NUM_LANES``;
    - ``kv_segment_ids`` (indexed as ``[batch, kv_seq_len]``) broadcast to
      ``[batch, NUM_SUBLANES, kv_seq_len]``.

    Raises:
      AssertionError: if only one of the two id tensors is provided.
    """
    from jax.experimental.pallas.ops.tpu.flash_attention import SegmentIds
    if q_segment_ids is None and kv_segment_ids is None:
      return None, None, None

    assert q_segment_ids is not None and kv_segment_ids is not None, "Both q_segment_ids and kv_segment_ids should be provided."
    segment_ids = SegmentIds(
        to_jax_shape_dtype_struct(q_segment_ids),
        to_jax_shape_dtype_struct(kv_segment_ids))
    q_segment_ids = q_segment_ids.unsqueeze(-1).expand(
        [-1 for _ in q_segment_ids.shape] + [FlashAttention.NUM_LANES])
    kv_segment_ids = kv_segment_ids.unsqueeze(1).expand([
        kv_segment_ids.shape[0], FlashAttention.NUM_SUBLANES,
        kv_segment_ids.shape[1]
    ])
    return segment_ids, q_segment_ids, kv_segment_ids

  @staticmethod
  @requires_jax
  def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
              partition_spec, mesh):
    ctx.q_shape = q.shape
    ctx.k_shape = k.shape
    ctx.causal = causal
    ctx.sm_scale = sm_scale
    ctx.partition_spec = partition_spec
    ctx.mesh = mesh
    ctx.q_full_shape = q.shape
    ctx.kv_full_shape = k.shape
    ctx.ab_full_shape = ab.shape if ab is not None else None

    # AOT-compatible functions only accept the argument types listed in
    # https://github.com/pytorch/pytorch/blob/82859f61857ef39898b34a5cdf0ae56ec25704d9/torch/_functorch/_aot_autograd/utils.py#L23-L34,
    # so partition_spec and mesh are serialized to strings.
    custom_op_arg = [
        q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
        str(partition_spec),
        str(mesh)
    ]
    ctx_grads = generate_ctx_need_grad(*custom_op_arg)
    o, full_q, full_k, full_v, l, m, full_ab = fa_custom_forward(
        *custom_op_arg, ctx_grads)

    # q_segment_ids and kv_segment_ids are sharded here if partition_spec is
    # provided, which is OK because the backward uses the same partition_spec.
    ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids,
                          kv_segment_ids, full_ab)
    return o

  @staticmethod
  @requires_jax
  def backward(ctx, grad_output):
    q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors
    causal = ctx.causal
    sm_scale = ctx.sm_scale
    partition_spec = ctx.partition_spec
    mesh = ctx.mesh
    q_full_shape = ctx.q_full_shape
    kv_full_shape = ctx.kv_full_shape
    ab_full_shape = ctx.ab_full_shape

    grad_output, q, k, v, o, l, m = [
        t.contiguous() for t in (grad_output, q, k, v, o, l, m)
    ]

    # These segment_ids only reflect the local shape of segment_ids.
    custom_op_arg = [
        grad_output, q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab,
        causal, sm_scale,
        str(partition_spec),
        str(mesh), q_full_shape, kv_full_shape, ab_full_shape
    ]
    ctx_grads = ctx.needs_input_grad
    grad_q, grad_k, grad_v, grad_ab = fa_custom_backward(
        *custom_op_arg, ctx_grads)
    # One gradient per forward input; non-tensor / non-differentiable inputs
    # get None.
    return grad_q, grad_k, grad_v, None, None, None, None, grad_ab, None, None
def flash_attention(
    q,  # [batch_size, num_heads, q_seq_len, d_model]
    k,  # [batch_size, num_heads, kv_seq_len, d_model]
    v,  # [batch_size, num_heads, kv_seq_len, d_model]
    causal=False,
    q_segment_ids=None,  # [batch_size, q_seq_len]
    kv_segment_ids=None,  # [batch_size, kv_seq_len]
    sm_scale=1.0,
    *,
    ab=None,  # [batch_size, num_heads, q_seq_len, kv_seq_len]
    partition_spec=None,
    mesh=None,
):
  """Differentiable flash attention; a thin wrapper over FlashAttention.apply.

  Args:
    q, k, v: query / key / value tensors (shapes in the signature comments).
    causal: causal flag, default False.
    q_segment_ids, kv_segment_ids: optional segment id tensors.
    sm_scale: softmax scale, default 1.0.
    ab: optional extra tensor (likely an attention bias), keyword-only.
    partition_spec, mesh: optional sharding configuration, keyword-only.

  Returns:
    Whatever ``FlashAttention.apply`` returns for these arguments.
  """
  # TODO: support SPMD and Dynamo with segment_ids.
  apply_args = (q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab,
                partition_spec, mesh)
  return FlashAttention.apply(*apply_args)
def _ragged_paged_attention_nonkernel(
    queries,  # [max_num_batched_tokens, num_q_heads, head_dim]
    kv_pages,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
    kv_lens,  # i32[max_num_seqs]
    page_indices,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens,  # i32[max_num_seqs + 1]
    num_seqs,  # i32
    *,
    sm_scale=1.0,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value=DEFAULT_MASK_VALUE,
):
  """Pure-PyTorch reference implementation of ragged paged attention.

  Each of the first ``num_seqs`` sequences is processed independently in a
  Python loop, so this is intended as a correctness reference, not for speed.

  Args:
    queries: packed query tokens; sequence ``i`` owns rows
      ``cu_q_lens[i]:cu_q_lens[i + 1]``.
    kv_pages: paged KV cache. Along dim 2, even entries are K heads and odd
      entries are V heads.
    kv_lens: number of valid KV tokens per sequence.
    page_indices: per sequence, the indices into dim 0 of ``kv_pages`` that
      hold its KV tokens, in order.
    cu_q_lens: cumulative query lengths.
    num_seqs: number of leading sequences to process (used with ``range``).
    sm_scale: multiplier applied to the raw q.k scores.
    sliding_window: if set, a query cannot attend to keys that are
      ``sliding_window`` or more positions behind it.
    soft_cap: if set, scores become ``soft_cap * tanh(s / soft_cap)`` before
      masking.
    mask_value: value written into masked score positions before softmax.

  Returns:
    Tensor of shape ``[cu_q_lens[num_seqs] - cu_q_lens[0], num_q_heads,
    head_dim]`` with the dtype of ``kv_pages``. Rows of ``queries`` beyond
    ``cu_q_lens[num_seqs]`` are not included. When ``num_seqs`` is 0 an empty
    ``[0, num_q_heads, head_dim]`` tensor is returned.
  """
  _, _, num_combined_kv_heads, head_dim = kv_pages.shape
  assert num_combined_kv_heads % 2 == 0
  num_kv_heads = num_combined_kv_heads // 2
  num_q_heads = queries.shape[1]
  assert num_q_heads % num_kv_heads == 0
  # Grouped-query attention: each KV head is shared by this many query heads.
  num_query_per_kv = num_q_heads // num_kv_heads

  if num_seqs == 0:
    # torch.cat() rejects an empty list; return an empty result with the same
    # dtype/trailing dims the non-empty path produces.
    return queries.new_empty((0, num_q_heads, head_dim), dtype=kv_pages.dtype)

  outputs = []
  for i in range(num_seqs):
    q_start = cu_q_lens[i]
    q_end = cu_q_lens[i + 1]
    q_len = q_end - q_start
    kv_len = kv_lens[i]
    indices = page_indices[i]
    q = queries[q_start:q_end]
    # Gather this sequence's pages, keep the K (even) / V (odd) heads, flatten
    # the pages into one token axis and drop the unused tail past kv_len.
    k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads,
                                              head_dim)[:kv_len]
    v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads,
                                              head_dim)[:kv_len]
    k = torch.repeat_interleave(k, num_query_per_kv, dim=1)
    v = torch.repeat_interleave(v, num_query_per_kv, dim=1)

    attn = torch.einsum("qhd,khd->hqk", q, k)  # [num_q_heads, q_len, kv_len]
    attn *= sm_scale

    ones = torch.ones(q_len, kv_len, device=attn.device)
    # Causal mask: query row r is placed at absolute position
    # kv_len - q_len + r, so every key after that position is hidden.
    mask = torch.triu(ones, diagonal=kv_len - q_len + 1).bool()
    if sliding_window is not None:
      # Additionally hide keys sliding_window or more positions behind the
      # query's absolute position.
      sliding_window_mask = torch.triu(
          ones,
          diagonal=kv_len - (q_len + sliding_window) + 1).bool().logical_not()
      mask |= sliding_window_mask
    if soft_cap is not None:
      attn = soft_cap * torch.tanh(attn / soft_cap)
    attn.masked_fill_(mask, mask_value)
    attn = torch.softmax(
        attn, dim=-1).to(v.dtype)  # [num_q_heads, q_len, kv_len]
    out = torch.einsum("hqk,khd->qhd", attn,
                       v)  # [q_len, num_q_heads, head_dim]
    outputs.append(out)
  return torch.cat(outputs, dim=0)
def _get_default_ragged_paged_attention_block_size(token_num):
  """Chooses default kernel block sizes for ragged paged attention.

  Args:
    token_num: number of query tokens (callers pass ``q.shape[0]``).

  Returns:
    A ``(num_kv_pages_per_block, num_queries_per_block)`` tuple.

  Raises:
    NotImplementedError: if ``torch_xla.tpu.version()`` is below 4.
  """
  tpu_version = torch_xla.tpu.version()
  if tpu_version < 4:
    raise NotImplementedError("TPU version must be 4 or higher.")
  if tpu_version == 4:
    # Untuned v4 defaults, chosen only so the blocks do not OOM in VMEM.
    return 16, 128

  # Heuristic from the initial kernel micro-benchmarks: a small token count
  # means no long prefill request is present, so a smaller query block is
  # used; larger token counts get a bigger query block.
  num_kv_pages_per_block = 128
  num_queries_per_block = 32 if token_num <= 128 else 96
  return num_kv_pages_per_block, num_queries_per_block
@requires_jax
def ragged_paged_attention(
q, # [max_num_batched_tokens, num_q_heads, head_dim]
kv_pages, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
kv_lens, # i32[max_num_seqs]
page_indices, # i32[max_num_seqs, pages_per_seq]
cu_q_lens, # i32[max_num_seqs + 1]
num_seqs, # i32[1]
*,
sm_scale=1.0,
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value=None,
use_kernel=True,
# kernel tuning parameters
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
):
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
if not use_kernel:
return _ragged_paged_attention_nonkernel(
q,
kv_pages,
kv_lens,
page_indices,
cu_q_lens,
num_seqs.item(),
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
mask_value=mask_value,
)
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_attention
if num_kv_pages_per_block is None:
assert num_queries_per_block is None
token_num = q.shape[0]
num_kv_pages_per_block, num_queries_per_block = _get_default_ragged_paged_attention_block_size(
token_num)
if vmem_limit_bytes is None:
vmem_limit_bytes = 64 * 1024 * 1024
payload, _ = trace_pallas(
ragged_attention,
q,
kv_pages,
kv_lens,
page_indices,
cu_q_lens,
num_seqs,
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,