-
Notifications
You must be signed in to change notification settings - Fork 505
/
Copy pathtest_multi_queries_paged_attention_kernel.py
377 lines (338 loc) · 12.1 KB
/
test_multi_queries_paged_attention_kernel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from torch_xla.experimental.pallas_kernels.multi_queries_paged_attention_kernel import paged_attention
import jax.numpy as jnp
import numpy as np
# Module-level JAX test setup: hook JAX's config into absl flag parsing
# (presumably so JAX options can be set from the absltest command line).
jax.config.parse_flags_with_absl()
# Set up paged_attention inputs.
def _generate_qkv(
    batch_size,
    page_size,
    max_kv_len,
    query_len,
    num_kv_heads,
    num_q_heads,
    head_dim,
    prng_key,
    dtype,
):
  """Build random inputs for paged_attention.

  Args:
    batch_size: number of sequences.
    page_size: tokens stored per KV page.
    max_kv_len: KV capacity per sequence; must be a multiple of page_size.
    query_len: query tokens per sequence.
    num_kv_heads: number of key/value heads.
    num_q_heads: number of query heads.
    head_dim: per-head feature size.
    prng_key: JAX PRNG key, split internally into four subkeys.
    dtype: dtype of q, k_pages and v_pages.

  Returns:
    A tuple (q, k_pages, v_pages, page_indices) where
      q:            [batch_size, query_len, num_q_heads, head_dim]
      k_pages:      [num_kv_heads, batch_size * pages_per_seq, page_size, head_dim]
      v_pages:      same shape as k_pages
      page_indices: int32 [batch_size, pages_per_seq], a random permutation
                    of all page ids, with pages_per_seq = max_kv_len // page_size.
  """
  assert max_kv_len % page_size == 0
  num_pages_per_seq = max_kv_len // page_size
  num_pages = batch_size * num_pages_per_seq
  k_key, v_key, perm_key, q_key = jax.random.split(prng_key, 4)

  kv_shape = (num_kv_heads, num_pages, page_size, head_dim)
  k_pages = jax.random.normal(k_key, kv_shape, dtype=dtype)
  v_pages = jax.random.normal(v_key, kv_shape, dtype=dtype)

  # Shuffle the physical page ids so each sequence owns a distinct,
  # non-contiguous set of pages.
  all_page_ids = jnp.arange(num_pages, dtype=jnp.int32)
  shuffled_ids = jax.random.permutation(perm_key, all_page_ids, independent=True)
  page_indices = shuffled_ids.reshape(batch_size, num_pages_per_seq)

  q_shape = (batch_size, query_len, num_q_heads, head_dim)
  q = jax.random.normal(q_key, q_shape, dtype=dtype)
  return q, k_pages, v_pages, page_indices
def _ref_jax_extended_paged_attention(
q, # [batch_size, query_len, num_query_heads, head_size]
k_pages, # [num_kv_heads, total_num_pages, page_size, head_size]
v_pages, # [num_kv_heads, total_num_pages, page_size, head_size]
lengths, # [batch_size], the effective kv_length.
page_indices, # [batch_size, pages_per_sequence]
effective_q_lens, # [batch_size] the effective q_length
attn_logits_soft_cap: float | None = None,
):
batch_size, query_len, num_query_heads, head_size = q.shape
num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
num_query_per_kv = num_query_heads // num_kv_heads
outputs = []
for i in range(batch_size):
kv_len = lengths[i]
num_pages = (kv_len + page_size - 1) // page_size
indices = page_indices[i, :num_pages]
k = k_pages[:, indices]
k = jnp.permute_dims(k, (1, 2, 0, 3))
k = jnp.reshape(k, (num_pages * page_size, num_kv_heads, head_size))
k = k[:kv_len]
v = v_pages[:, indices]
v = jnp.permute_dims(v, (1, 2, 0, 3))
v = jnp.reshape(v, (num_pages * page_size, num_kv_heads, head_size))
v = v[:kv_len]
if num_query_per_kv != 1:
k = jnp.repeat(k, num_query_per_kv, axis=1)
v = jnp.repeat(v, num_query_per_kv, axis=1)
attn = jnp.einsum("qhd,khd->hqk", q[i], k)
if attn_logits_soft_cap is not None:
capped_attn = jnp.tanh(attn / attn_logits_soft_cap)
attn = capped_attn * attn_logits_soft_cap
attn = attn.astype('float32')
effective_q_len = effective_q_lens[i]
q_span = (kv_len - effective_q_len) + jax.lax.broadcasted_iota(
jnp.int32, (query_len, kv_len), 0)
kv_span = jax.lax.broadcasted_iota(jnp.int32, (query_len, kv_len), 1)
mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
with jax.numpy_rank_promotion("allow"):
attn = attn + mask
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
out = jnp.einsum("hqk,khd->qhd", attn, v)
outputs.append(out)
output = jnp.stack(outputs, axis=0)
return output
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class PagedAttentionKernelTest(jtu.JaxTestCase):
  """Compares the multi-query paged_attention kernel with a JAX reference.

  Every test builds inputs with _generate_qkv, runs paged_attention, blocks
  on its result, and checks it against _ref_jax_extended_paged_attention.
  """

  def _output_tolerances(self, dtype, float32_atol):
    """Return (atol, rtol) for comparing kernel and reference outputs.

    Args:
      dtype: dtype of the attention inputs. Only float32 and bfloat16 are
        supported; any other dtype fails the test.
      float32_atol: absolute tolerance to use for float32 inputs.
    """
    if dtype == jnp.float32:
      return float32_atol, 1e-2
    if dtype == jnp.bfloat16:
      return 6e-1, 1e-1
    self.fail(f'Unsupported dtype: {dtype}')

  def _assert_kernel_matches_reference(self, expected, actual, *, atol, rtol):
    """Fail unless jnp.allclose(expected, actual) holds.

    Uses the same predicate as assertTrue(jnp.allclose(...)), but the failure
    message reports the largest absolute difference instead of just
    'False is not true'.
    """
    if jnp.allclose(expected, actual, atol=atol, rtol=rtol):
      return
    # Computed only on failure: compared slices may be empty (an
    # effective_q_len of 0), and a max over an empty array raises.
    max_abs_diff = jnp.max(
        jnp.abs(expected.astype(jnp.float32) - actual.astype(jnp.float32)))
    self.fail(f'Kernel output differs from the reference: max abs diff '
              f'{max_abs_diff} (atol={atol}, rtol={rtol}).')

  @parameterized.product(
      dtype=(jnp.float32, jnp.bfloat16),
      page_size=(16, 32, 64),
      num_kv_heads=(1, 8),
      q_kv_head_ratio=(1, 4, 8),
      head_dim=(128, 256),
      num_queries_per_compute_block=(16, 32),
      block_kv_size=(128, 256),
      attn_logits_soft_cap=(1.0, None),
  )
  def test_paged_attention_without_query_padding(
      self,
      dtype,
      page_size,
      num_kv_heads,
      q_kv_head_ratio,
      head_dim,
      num_queries_per_compute_block,
      block_kv_size,
      attn_logits_soft_cap,
  ):
    max_kv_len = 2048
    query_len = 64
    batch_size = 3
    # Use independent subkeys: drawing several random quantities from the
    # same key makes them correlated.
    kv_len_key, qkv_key = jax.random.split(jax.random.key(0))
    kv_seq_lens = jax.random.randint(kv_len_key, (batch_size,), query_len,
                                     max_kv_len)
    assert query_len <= max_kv_len
    for cur_kv_seq in kv_seq_lens:
      assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
    q, k_pages, v_pages, page_indices = _generate_qkv(
        batch_size,
        page_size,
        max_kv_len,
        query_len,
        num_kv_heads,
        num_kv_heads * q_kv_head_ratio,
        head_dim,
        qkv_key,
        dtype,
    )
    print(f'Running paged_attention with {query_len=}')
    num_kv_pages_per_compute_block = block_kv_size // page_size
    # No query padding: every query token of every sequence is real.
    effective_q_lens = jnp.full_like(kv_seq_lens, query_len)
    actual_output = paged_attention(
        q,
        k_pages,
        v_pages,
        kv_seq_lens,
        page_indices,
        effective_q_lens,
        num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
        num_queries_per_compute_block=num_queries_per_compute_block,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )
    # Kernel execution is async. Without blocking, an error raised inside the
    # kernel may surface at an unrelated, confusing place. See
    # https://github.com/pytorch/xla/pull/8356#issuecomment-2486861631
    actual_output = jax.block_until_ready(actual_output)

    expected_output = _ref_jax_extended_paged_attention(
        q,
        k_pages,
        v_pages,
        kv_seq_lens,
        page_indices,
        effective_q_lens,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )

    self.assertEqual(actual_output.shape, expected_output.shape)
    atol, rtol = self._output_tolerances(dtype, float32_atol=1e-1)
    self._assert_kernel_matches_reference(
        expected_output, actual_output, atol=atol, rtol=rtol)

  # In practice, vLLM pads the query, so the (padded) query seq len can be
  # longer than the kv seq len; the kv seq len itself is never padded.
  # When that happens, paged_attention needs the extra `effective_q_lens`
  # argument to place the causal mask correctly.
  @parameterized.product(
      dtype=(jnp.float32, jnp.bfloat16),
      page_size=(16, 32, 64),
      num_kv_heads=(1, 8),
      q_kv_head_ratio=(1, 4, 8),
      head_dim=(128, 256),
      num_queries_per_compute_block=(16, 32),
      block_kv_size=(128, 256),
  )
  def test_paged_attention_with_query_padding(
      self,
      dtype,
      page_size,
      num_kv_heads,
      q_kv_head_ratio,
      head_dim,
      num_queries_per_compute_block,
      block_kv_size,
  ):
    max_kv_len = 512
    # Make query_len exceed every kv_seq_len.
    query_len = max_kv_len
    batch_size = 3
    # Independent subkeys so kv lengths, effective query lengths and the
    # q/k/v data are not correlated with each other.
    kv_len_key, q_len_key, qkv_key = jax.random.split(jax.random.key(0), 3)
    kv_seq_lens = jax.random.randint(kv_len_key, (batch_size,), 0, max_kv_len)
    effective_q_lens = jax.random.randint(q_len_key, (batch_size,), 0,
                                          kv_seq_lens)
    for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
      assert cur_effec_q_len <= cur_kv_seq_len, f'The effective query len {cur_effec_q_len} should be less than or equal to the kv_len {cur_kv_seq_len} in the current sequence.'
    q, k_pages, v_pages, page_indices = _generate_qkv(
        batch_size,
        page_size,
        max_kv_len,
        query_len,
        num_kv_heads,
        num_kv_heads * q_kv_head_ratio,
        head_dim,
        qkv_key,
        dtype,
    )
    print(
        f'Running paged_attention with {query_len=}, {kv_seq_lens=}, {effective_q_lens=}'
    )
    num_kv_pages_per_compute_block = block_kv_size // page_size
    actual_output = paged_attention(
        q,
        k_pages,
        v_pages,
        kv_seq_lens,
        page_indices,
        effective_q_lens,
        num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
        num_queries_per_compute_block=num_queries_per_compute_block,
    )
    actual_output = jax.block_until_ready(actual_output)

    expected_output = _ref_jax_extended_paged_attention(
        q,
        k_pages,
        v_pages,
        kv_seq_lens,
        page_indices,
        effective_q_lens,
    )

    self.assertEqual(actual_output.shape, expected_output.shape)
    atol, rtol = self._output_tolerances(dtype, float32_atol=2e-2)
    for b in range(batch_size):
      # N.B. Along the query_len dim of the output
      # ([batch_size, query_len, num_q_heads, head_dim]), rows past
      # effective_q_len come from query padding and are discarded by callers.
      # Past that point, the kernel and the reference may legitimately differ
      # because of the causal mask, so only the real rows are compared.
      effective_q_len = effective_q_lens[b]
      self._assert_kernel_matches_reference(
          expected_output[b, :effective_q_len],
          actual_output[b, :effective_q_len],
          atol=atol,
          rtol=rtol)

  def test_paged_attention_store_to_output_correctly(self):
    # Make sure the internal flash attention stores to the output correctly.
    dtype = jnp.float32
    page_size = 16
    num_kv_heads = 8
    q_kv_head_ratio = 4
    head_dim = 256
    num_queries_per_compute_block = 32
    block_kv_size = 256
    max_kv_len = 512
    query_len = max_kv_len
    batch_size = 3
    # kv lengths one below one KV compute block, one above it, and exactly two
    # blocks exercise the edge cases of storing to the output.
    kv_seq_lens = jnp.array(
        [block_kv_size - 1, block_kv_size + 1, 2 * block_kv_size])
    assert len(kv_seq_lens) == batch_size
    q_len_key, qkv_key = jax.random.split(jax.random.key(0))
    effective_q_lens = jax.random.randint(q_len_key, (batch_size,), 0,
                                          kv_seq_lens)
    for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
      assert cur_effec_q_len <= cur_kv_seq_len, f'The effective query len {cur_effec_q_len} should be less than or equal to the kv_len {cur_kv_seq_len} in the current sequence.'
    q, k_pages, v_pages, page_indices = _generate_qkv(
        batch_size,
        page_size,
        max_kv_len,
        query_len,
        num_kv_heads,
        num_kv_heads * q_kv_head_ratio,
        head_dim,
        qkv_key,
        dtype,
    )
    num_kv_pages_per_compute_block = block_kv_size // page_size
    actual_output = paged_attention(
        q,
        k_pages,
        v_pages,
        kv_seq_lens,
        page_indices,
        effective_q_lens,
        num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
        num_queries_per_compute_block=num_queries_per_compute_block,
    )
    actual_output = jax.block_until_ready(actual_output)

    expected_output = _ref_jax_extended_paged_attention(
        q,
        k_pages,
        v_pages,
        kv_seq_lens,
        page_indices,
        effective_q_lens,
    )

    self.assertEqual(actual_output.shape, expected_output.shape)
    atol, rtol = self._output_tolerances(dtype, float32_atol=2e-2)
    for b in range(batch_size):
      effective_q_len = effective_q_lens[b]
      self._assert_kernel_matches_reference(
          expected_output[b, :effective_q_len],
          actual_output[b, :effective_q_len],
          atol=atol,
          rtol=rtol)
if __name__ == "__main__":
  # Run the test cases through absltest, using JAX's test loader.
  absltest.main(testLoader=jtu.JaxTestLoader())