FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
Abstract
Scaling Transformers to longer sequence lengths has been a major problem in the last several years,
promising to improve performance in language modeling and high-resolution image understanding, as
well as to unlock new applications in code, audio, and video generation. The attention layer is the
main bottleneck in scaling to longer sequences, as its runtime and memory increase quadratically in
the sequence length. FlashAttention [5] exploits the asymmetric GPU memory hierarchy to bring
significant memory saving (linear instead of quadratic) and runtime speedup (2-4× compared to optimized
baselines), with no approximation. However, FlashAttention is still not nearly as fast as optimized
matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s. We
observe that the inefficiency is due to suboptimal work partitioning between different thread blocks and
warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes. We propose
FlashAttention-2, with better work partitioning to address these issues. In particular, we (1) tweak
the algorithm to reduce the number of non-matmul FLOPs, (2) parallelize the attention computation, even
for a single head, across different thread blocks to increase occupancy, and (3) within each thread block,
distribute the work between warps to reduce communication through shared memory. These yield around
2× speedup compared to FlashAttention, reaching 50-73% of the theoretical maximum FLOPs/s on
A100 and getting close to the efficiency of GEMM operations. We empirically validate that when used
end-to-end to train GPT-style models, FlashAttention-2 reaches training speed of up to 225 TFLOPs/s
per A100 GPU (72% model FLOPs utilization).1
1 Introduction
Scaling up the context length of Transformers [18] is a challenge, since the attention layer at their heart
has runtime and memory requirements quadratic in the input sequence length. Ideally, we would like to go
beyond the standard 2k sequence length limit to train models to understand books, high resolution images,
and long-form videos. Just within the last year, there have been several language models with much longer
context than before: GPT-4 [12] with context length 32k, MosaicML’s MPT with context length 65k, and
Anthropic’s Claude with context length 100k. Emerging use cases such as long document querying and story
writing have demonstrated a need for models with such long context.
To reduce the computational requirement of attention on such long context, there have been numerous
methods proposed to approximate attention [2, 3, 4, 8, 9, 14, 19, 20]. Though these methods have seen
some use cases, as far as we know, most large-scale training runs still use standard attention. Motivated by
this, Dao et al. [5] proposed to reorder the attention computation and leverage classical techniques (tiling,
recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence
length. This yields 2-4× wall-clock time speedup over optimized baselines, up to 10-20× memory saving,
with no approximation, and as a result FlashAttention has seen wide adoption in large-scale training and
inference of Transformers.
1 FlashAttention-2 is available at https://github.com/Dao-AILab/flash-attention
However, as the context length increases even further, FlashAttention is still not nearly as efficient as other
primitives such as matrix-multiply (GEMM). In particular, while FlashAttention is already 2-4× faster
than a standard attention implementation, the forward pass only reaches 30-50% of the theoretical maximum
FLOPs/s of the device (Fig. 5), while the backward pass is even more challenging, reaching only 25-35%
of maximum throughput on A100 GPU (Fig. 6). In contrast, optimized GEMM can reach up to 80-90% of
the theoretical maximum device throughput. Through careful profiling, we observe that FlashAttention
still has suboptimal work partitioning between different thread blocks and warps on the GPU, causing either
low-occupancy or unnecessary shared memory reads/writes.
Building on FlashAttention, we propose FlashAttention-2 with better parallelism and work
partitioning to address these challenges.
1. In Section 3.1, we tweak the algorithm to reduce the number of non-matmul FLOPs while not changing
the output. While the non-matmul FLOPs only account for a small fraction of the total FLOPs, they
take longer to perform as GPUs have specialized units for matrix multiply, and as a result the matmul
throughput can be up to 16× higher than non-matmul throughput. It is thus important to reduce
non-matmul FLOPs and spend as much time as possible doing matmul FLOPs.
2. We propose to parallelize both the forward pass and backward pass along the sequence length dimension,
in addition to the batch and number of heads dimension. This increases occupancy (utilization of GPU
resources) in the case where the sequences are long (and hence batch size is often small).
3. Even within one block of attention computation, we partition the work between different warps of a
thread block to reduce communication and shared memory reads/writes.
2 Background
We provide some background on the performance characteristics and execution model of GPUs. We also
describe the standard implementation of attention, as well as FlashAttention.
2.2 Standard Attention Implementation
Given input sequences $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}$, where $N$ is the sequence length and $d$ is the head dimension, we want
to compute the attention output $\mathbf{O} \in \mathbb{R}^{N \times d}$:
$$\mathbf{S} = \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N \times N}, \quad \mathbf{P} = \mathrm{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O} = \mathbf{P}\mathbf{V} \in \mathbb{R}^{N \times d},$$
where softmax is applied row-wise.2 For multi-head attention (MHA), this same computation is performed in
parallel across many heads, and in parallel over the batch dimension (number of input sequences in a batch).
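For concreteness, here is a minimal NumPy sketch of this standard (materialized) attention computation for a single head; the names and sizes are illustrative, and the usual $1/\sqrt{d}$ scaling is omitted as in the equations above.

```python
import numpy as np

def softmax_rows(x):
    # Numerically stable row-wise softmax.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def standard_attention(Q, K, V):
    """Materializes S and P in full, as in the standard implementation."""
    S = Q @ K.T                # (N, N) scores
    P = softmax_rows(S)        # (N, N) attention probabilities
    O = P @ V                  # (N, d) output
    return O

# Example: N = 4 queries/keys, head dimension d = 8.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((4, 8)) for _ in range(3))
O = standard_attention(Q, K, V)
print(O.shape)  # (4, 8)
```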
The backward pass of attention proceeds as follows. Let $\mathbf{dO} \in \mathbb{R}^{N \times d}$ be the gradient of $\mathbf{O}$ with respect to
some loss function. Then by the chain rule (aka backpropagation):
$$\begin{aligned}
\mathbf{dV} &= \mathbf{P}^\top \mathbf{dO} \in \mathbb{R}^{N \times d} \\
\mathbf{dP} &= \mathbf{dO}\,\mathbf{V}^\top \in \mathbb{R}^{N \times N} \\
\mathbf{dS} &= \mathrm{dsoftmax}(\mathbf{dP}) \in \mathbb{R}^{N \times N} \\
\mathbf{dQ} &= \mathbf{dS}\,\mathbf{K} \in \mathbb{R}^{N \times d} \\
\mathbf{dK} &= \mathbf{dS}^\top \mathbf{Q} \in \mathbb{R}^{N \times d},
\end{aligned}$$
where dsoftmax is the gradient (backward pass) of softmax applied row-wise. One can work out that if $p =
\mathrm{softmax}(s)$ for some vector $s$ and $p$, then with output gradient $dp$, the input gradient is $ds = (\mathrm{diag}(p) - pp^\top)\,dp$.
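As a sanity check on these equations, the following PyTorch sketch (illustrative, not the FlashAttention implementation) computes the manual gradients and compares them against autograd; note that for a row, $(\mathrm{diag}(p) - pp^\top)\,dp = p \odot (dp - \langle p, dp\rangle)$.

```python
import torch

torch.manual_seed(0)
N, d = 5, 4
Q, K, V = (torch.randn(N, d, dtype=torch.float64, requires_grad=True) for _ in range(3))

S = Q @ K.T
P = torch.softmax(S, dim=-1)
O = P @ V
dO = torch.randn_like(O)            # upstream gradient of some loss w.r.t. O
O.backward(dO)

# Manual gradients from the equations above.
dV = P.T @ dO
dP = dO @ V.T
# Row-wise dsoftmax: ds = (diag(p) - p p^T) dp = p * (dp - sum(p * dp))
dS = P * (dP - (P * dP).sum(dim=-1, keepdim=True))
dQ = dS @ K
dK = dS.T @ Q

for manual, auto in [(dQ, Q.grad), (dK, K.grad), (dV, V.grad)]:
    assert torch.allclose(manual, auto, atol=1e-10)
```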
Standard attention implementations materialize the matrices S and P to HBM, which takes $O(N^2)$
memory. Often $N \gg d$ (typically $N$ is on the order of 1k–8k and $d$ is around 64–128). The standard attention
implementation (1) calls the matrix-multiply (GEMM) subroutine to compute $\mathbf{S} = \mathbf{Q}\mathbf{K}^\top$ and writes the result to
HBM, then (2) loads S from HBM to compute softmax and writes the result P to HBM, and finally (3) calls
GEMM to get $\mathbf{O} = \mathbf{P}\mathbf{V}$. As most of the operations are bounded by memory bandwidth, the large number of
memory accesses translates to slow wall-clock time. Moreover, the required memory is $O(N^2)$ due to having
to materialize S and P, and one has to save $\mathbf{P} \in \mathbb{R}^{N \times N}$ for the backward pass to compute the gradients.
2.3 FlashAttention
To speed up attention on hardware accelerators such as GPUs, [5] proposes an algorithm to reduce the memory
reads/writes while maintaining the same output (without approximation).
2 For clarity of exposition, the equations above omit the usual scaling of $\mathbf{Q}\mathbf{K}^\top$ (typically by $1/\sqrt{d}$) and any elementwise masking or dropout applied to P.
2.3.1 Forward pass
FlashAttention applies tiling together with the online softmax technique [11, 13]. For simplicity of exposition, consider one row block of the
attention matrix, split into two column blocks $\mathbf{S}^{(1)}, \mathbf{S}^{(2)} \in \mathbb{R}^{B_r \times B_c}$ with corresponding value blocks $\mathbf{V}^{(1)}, \mathbf{V}^{(2)} \in \mathbb{R}^{B_c \times d}$. Standard softmax would compute:
$$\begin{aligned}
m &= \max(\mathrm{rowmax}(\mathbf{S}^{(1)}), \mathrm{rowmax}(\mathbf{S}^{(2)})) \in \mathbb{R}^{B_r} \\
\ell &= \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m}) + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m}) \in \mathbb{R}^{B_r} \\
\mathbf{P} &= \begin{bmatrix} \mathbf{P}^{(1)} & \mathbf{P}^{(2)} \end{bmatrix} = \mathrm{diag}(\ell)^{-1} \begin{bmatrix} e^{\mathbf{S}^{(1)} - m} & e^{\mathbf{S}^{(2)} - m} \end{bmatrix} \in \mathbb{R}^{B_r \times 2B_c} \\
\mathbf{O} &= \begin{bmatrix} \mathbf{P}^{(1)} & \mathbf{P}^{(2)} \end{bmatrix} \begin{bmatrix} \mathbf{V}^{(1)} \\ \mathbf{V}^{(2)} \end{bmatrix} = \mathrm{diag}(\ell)^{-1} \left( e^{\mathbf{S}^{(1)} - m} \mathbf{V}^{(1)} + e^{\mathbf{S}^{(2)} - m} \mathbf{V}^{(2)} \right) \in \mathbb{R}^{B_r \times d}.
\end{aligned}$$
Online softmax instead computes a “local” softmax with respect to each block and rescales to get the right
output at the end:
$$\begin{aligned}
m^{(1)} &= \mathrm{rowmax}(\mathbf{S}^{(1)}) \in \mathbb{R}^{B_r} \\
\ell^{(1)} &= \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m^{(1)}}) \in \mathbb{R}^{B_r} \\
\tilde{\mathbf{P}}^{(1)} &= \mathrm{diag}(\ell^{(1)})^{-1} e^{\mathbf{S}^{(1)} - m^{(1)}} \in \mathbb{R}^{B_r \times B_c} \\
\mathbf{O}^{(1)} &= \tilde{\mathbf{P}}^{(1)} \mathbf{V}^{(1)} = \mathrm{diag}(\ell^{(1)})^{-1} e^{\mathbf{S}^{(1)} - m^{(1)}} \mathbf{V}^{(1)} \in \mathbb{R}^{B_r \times d} \\
m^{(2)} &= \max(m^{(1)}, \mathrm{rowmax}(\mathbf{S}^{(2)})) = m \\
\ell^{(2)} &= e^{m^{(1)} - m^{(2)}} \ell^{(1)} + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m^{(2)}}) = \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m}) + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m}) = \ell \\
\tilde{\mathbf{P}}^{(2)} &= \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(2)} - m^{(2)}} \\
\mathbf{O}^{(2)} &= \mathrm{diag}(\ell^{(1)}/\ell^{(2)})^{-1} \mathbf{O}^{(1)} + \tilde{\mathbf{P}}^{(2)} \mathbf{V}^{(2)} = \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(1)} - m} \mathbf{V}^{(1)} + \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(2)} - m} \mathbf{V}^{(2)} = \mathbf{O}.
\end{aligned}$$
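Numerically, the rescaling of $\mathbf{O}^{(1)}$ can be written with an explicit $e^{m^{(1)} - m^{(2)}}$ factor that accounts for the change of running max. The NumPy sketch below (block sizes illustrative) checks that this two-block online update matches softmax over the concatenated blocks.

```python
import numpy as np

rng = np.random.default_rng(0)
Br, Bc, d = 3, 4, 8
S1, S2 = rng.standard_normal((Br, Bc)), rng.standard_normal((Br, Bc))
V1, V2 = rng.standard_normal((Bc, d)), rng.standard_normal((Bc, d))

# Reference: softmax over the concatenated row block [S1 S2].
S = np.concatenate([S1, S2], axis=1)
m = S.max(axis=1, keepdims=True)
P = np.exp(S - m) / np.exp(S - m).sum(axis=1, keepdims=True)
O_ref = P @ np.concatenate([V1, V2], axis=0)

# Online softmax: process S1 first, then rescale when S2 arrives.
m1 = S1.max(axis=1, keepdims=True)
l1 = np.exp(S1 - m1).sum(axis=1, keepdims=True)
O1 = (np.exp(S1 - m1) / l1) @ V1

m2 = np.maximum(m1, S2.max(axis=1, keepdims=True))                      # = m
l2 = np.exp(m1 - m2) * l1 + np.exp(S2 - m2).sum(axis=1, keepdims=True)  # = l
# Rescale the old block's output (explicit e^(m1-m2) factor) and add the new one.
O2 = (l1 / l2) * np.exp(m1 - m2) * O1 + (np.exp(S2 - m2) / l2) @ V2

assert np.allclose(O2, O_ref)
```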
We show how FlashAttention uses online softmax to enable tiling (Fig. 1) to reduce memory reads/writes.
Figure 1: Diagram of how FlashAttention forward pass is performed, when the key K is partitioned into
two blocks and the value V is also partitioned into two blocks. By computing attention with respect to
each block and rescaling the output, we get the right answer at the end, while avoiding expensive memory
reads/writes of the intermediate matrices S and P. The diagram is simplified: it omits the step in softmax
that subtracts the row-wise max from each element.
2.3.2 Backward pass
In the backward pass, by re-computing the values of the attention matrices S and P once blocks of inputs
Q, K, V are already loaded to SRAM, FlashAttention avoids having to store large intermediate values. By
not having to save the large matrices S and P of size $N \times N$, FlashAttention yields 10-20× memory saving
depending on sequence length (the memory required is linear in the sequence length $N$ instead of quadratic). The
backward pass also achieves 2-4× wall-clock speedup due to reduced memory reads/writes.
The backward pass applies tiling to the equations in Section 2.2. Though the backward pass is conceptually simpler
than the forward pass (there is no softmax rescaling), the implementation is significantly more
involved. This is because more values have to be kept in SRAM to perform 5 matrix multiplies in the
backward pass, compared to just 2 matrix multiplies in the forward pass.
3 FlashAttention-2: Algorithm, Parallelism, and Work Partitioning
3.1 Algorithm
We tweak the algorithm from FlashAttention to reduce the number of non-matmul FLOPs. This is
because modern GPUs have specialized compute units (e.g., Tensor Cores on Nvidia GPUs) that make
matmul much faster. As an example, the A100 GPU has a max theoretical throughput of 312 TFLOPs/s of
FP16/BF16 matmul, but only 19.5 TFLOPs/s of non-matmul FP32. Another way to think about this is that
each non-matmul FLOP is 16× more expensive than a matmul FLOP. To maintain high throughput (e.g.,
more than 50% of the maximum theoretical TFLOPs/s), we want to spend as much time on matmul FLOPs
as possible.
1. We do not need to rescale both terms of the output update by $\mathrm{diag}(\ell^{(2)})^{-1}$. We can instead maintain an “un-scaled” version of $\mathbf{O}^{(2)}$ and keep around the statistics $\ell^{(2)}$:
$$\tilde{\mathbf{O}}^{(2)} = \mathrm{diag}(\ell^{(1)})^{-1} \mathbf{O}^{(1)} + e^{\mathbf{S}^{(2)} - m^{(2)}} \mathbf{V}^{(2)}.$$
Only at the very end of the loop do we scale the final $\tilde{\mathbf{O}}^{(\mathrm{last})}$ by $\mathrm{diag}(\ell^{(\mathrm{last})})^{-1}$ to get the right output.
2. We do not have to save both the max $m^{(j)}$ and the sum of exponentials $\ell^{(j)}$ for the backward pass. We
only need to store the logsumexp $L^{(j)} = m^{(j)} + \log(\ell^{(j)})$.
In the simple case of 2 blocks in Section 2.3, the online softmax trick now becomes:
$$\begin{aligned}
m^{(1)} &= \mathrm{rowmax}(\mathbf{S}^{(1)}) \in \mathbb{R}^{B_r} \\
\ell^{(1)} &= \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m^{(1)}}) \in \mathbb{R}^{B_r} \\
\tilde{\mathbf{O}}^{(1)} &= e^{\mathbf{S}^{(1)} - m^{(1)}} \mathbf{V}^{(1)} \in \mathbb{R}^{B_r \times d} \\
m^{(2)} &= \max(m^{(1)}, \mathrm{rowmax}(\mathbf{S}^{(2)})) = m \\
\ell^{(2)} &= e^{m^{(1)} - m^{(2)}} \ell^{(1)} + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m^{(2)}}) = \mathrm{rowsum}(e^{\mathbf{S}^{(1)} - m}) + \mathrm{rowsum}(e^{\mathbf{S}^{(2)} - m}) = \ell \\
\tilde{\mathbf{P}}^{(2)} &= \mathrm{diag}(\ell^{(2)})^{-1} e^{\mathbf{S}^{(2)} - m^{(2)}} \\
\tilde{\mathbf{O}}^{(2)} &= \mathrm{diag}(e^{m^{(1)} - m^{(2)}})\, \tilde{\mathbf{O}}^{(1)} + e^{\mathbf{S}^{(2)} - m^{(2)}} \mathbf{V}^{(2)} = e^{\mathbf{S}^{(1)} - m} \mathbf{V}^{(1)} + e^{\mathbf{S}^{(2)} - m} \mathbf{V}^{(2)} \\
\mathbf{O}^{(2)} &= \mathrm{diag}(\ell^{(2)})^{-1} \tilde{\mathbf{O}}^{(2)} = \mathbf{O}.
\end{aligned}$$
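A minimal NumPy sketch of this forward loop is given below: it keeps the un-scaled output and the running statistics, rescales only once at the end, and saves the logsumexp $L$ for the backward pass. The block sizes and function names are illustrative; the actual kernel operates on tiles in SRAM rather than NumPy arrays.

```python
import numpy as np

def flash_fwd_rowblock(Qi, K, V, Bc):
    """One row block: loop over column blocks, keep an un-scaled output O_tilde,
    a running row-max m and row-sum l; rescale once at the end."""
    Br, d = Qi.shape
    m = np.full((Br, 1), -np.inf)
    l = np.zeros((Br, 1))
    O_tilde = np.zeros((Br, d))
    for j in range(0, K.shape[0], Bc):
        Sij = Qi @ K[j:j + Bc].T                      # (Br, Bc) scores for this block
        m_new = np.maximum(m, Sij.max(axis=1, keepdims=True))
        scale = np.exp(m - m_new)                     # rescale old statistics/output
        Pij = np.exp(Sij - m_new)                     # un-normalized probabilities
        l = scale * l + Pij.sum(axis=1, keepdims=True)
        O_tilde = scale * O_tilde + Pij @ V[j:j + Bc]
        m = m_new
    O = O_tilde / l                                   # single rescaling at the very end
    L = m + np.log(l)                                 # logsumexp, saved for the backward pass
    return O, L

# Check against a direct softmax(QK^T)V on random inputs.
rng = np.random.default_rng(0)
N, d, Br, Bc = 16, 8, 4, 4
Q, K, V = rng.standard_normal((N, d)), rng.standard_normal((N, d)), rng.standard_normal((N, d))
O = np.concatenate([flash_fwd_rowblock(Q[i:i + Br], K, V, Bc)[0] for i in range(0, N, Br)])
S = Q @ K.T
P = np.exp(S - S.max(1, keepdims=True)); P /= P.sum(1, keepdims=True)
assert np.allclose(O, P @ V)
```

Dividing by $\ell$ once per row block, instead of rescaling by $\mathrm{diag}(\ell)^{-1}$ at every column block, is exactly the non-matmul saving described in item 1 above.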
Causal masking. One common use case of attention is in auto-regressive language modeling, where we
need to apply a causal mask to the attention matrix S (i.e., any entry $\mathbf{S}_{ij}$ with $j > i$ is set to $-\infty$).
1. As FlashAttention and FlashAttention-2 already operate by blocks, for any blocks where all
the column indices are more than the row indices (approximately half of the blocks for large sequence
length), we can skip the computation of that block. This leads to around 1.7-1.8× speedup compared
to attention without the causal mask.
2. We do not need to apply the causal mask to blocks whose row indices are guaranteed to be strictly more
than the column indices. This means that for each row of blocks, we only need to apply the causal mask to 1 block
(assuming square blocks); see the sketch below.
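To make the block-level cases concrete, here is a small Python sketch (illustrative; it assumes square B × B blocks indexed by block row i and block column j) that classifies each block as skipped, computed without the mask, or computed with the mask:

```python
def block_kind(i, j, B):
    """Classify block (i, j) of the causal attention matrix, with square B x B blocks.
    Rows cover [i*B, (i+1)*B), columns cover [j*B, (j+1)*B)."""
    if j * B > i * B + B - 1:      # every column index exceeds every row index
        return "skip"              # fully masked: no computation needed
    if (j + 1) * B - 1 <= i * B:   # every column index is <= every row index
        return "no_mask"           # fully visible: compute without applying the mask
    return "mask"                  # straddles the diagonal: apply the causal mask

# For a 4x4 grid of blocks, only the diagonal blocks need masking.
for i in range(4):
    print([block_kind(i, j, 64) for j in range(4)])
```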
Correctness, runtime, and memory requirement. As with FlashAttention, Algorithm 1 returns
the correct output $\mathbf{O} = \mathrm{softmax}(\mathbf{Q}\mathbf{K}^\top)\mathbf{V}$ (with no approximation), using $O(N^2 d)$ FLOPs and $O(N)$
additional memory beyond inputs and output (to store the logsumexp $L$). The proof is almost the same as
the proof of Dao et al. [5, Theorem 1], so we omit it here.
Multi-query attention and grouped-query attention. Multi-query attention (MQA) [15] and grouped-
query attention (GQA) [1] are variants of attention where multiple heads of query attend to the same head of
key and value, in order to reduce the size of KV cache during inference. Instead of having to duplicate the
key and value heads for the computation, we implicitly manipulate the indices into the head to perform the
same computation. In the backward pass, we need to sum the gradients dK and dV across different heads
that were implicitly duplicated.
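A small PyTorch sketch of this idea (illustrative shapes and names; not the fused kernel) is shown below: the forward pass indexes the shared KV head instead of duplicating it, and the backward pass sums the gradients of the query heads that share it, which autograd does automatically here but a fused kernel must do explicitly.

```python
import torch

# GQA: n_q query heads share n_kv key/value heads (n_q % n_kv == 0).
n_q, n_kv, N, d = 8, 2, 16, 32
group = n_q // n_kv
q = torch.randn(n_q, N, d)
k = torch.randn(n_kv, N, d, requires_grad=True)
v = torch.randn(n_kv, N, d, requires_grad=True)

# Forward: instead of duplicating K/V, index the shared head for each query head.
outs = []
for h in range(n_q):
    kv_head = h // group                      # implicit head-index manipulation
    s = q[h] @ k[kv_head].transpose(0, 1)
    p = torch.softmax(s, dim=-1)
    outs.append(p @ v[kv_head])
out = torch.stack(outs)

# Backward: gradients from the query heads sharing a KV head are summed into it.
out.sum().backward()
print(k.grad.shape, v.grad.shape)             # (n_kv, N, d) each
```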
3.2 Parallelism
The first version of FlashAttention parallelizes over batch size and number of heads. We use 1 thread
block to process one attention head, and there are overall batch size · number of heads thread blocks. Each
thread block is scheduled to run on a streaming multiprocessor (SM), and there are 108 of these SMs on
an A100 GPU for example. This scheduling is efficient when this number is large (say ≥ 80), since we can
effectively use almost all of the compute resources on the GPU.
In the case of long sequences (which usually means small batch sizes or small number of heads), to make
better use of the multiprocessors on the GPU, we now additionally parallelize over the sequence length
dimension. This results in significant speedup for this regime.
Forward pass. We see that the outer loop (over the sequence length) is embarrassingly parallel, and we
schedule its iterations on different thread blocks that do not need to communicate with each other. We also parallelize
over the batch dimension and number of heads dimension, as done in FlashAttention. The increased
parallelism over sequence length helps improve occupancy (fraction of GPU resources being used) when the
batch size and number of heads are small, leading to speedup in this case.
These ideas of swapping the order of the loop (outer loop over row blocks and inner loop over column
blocks, instead of the other way round in the original FlashAttention paper), as well as parallelizing
over the sequence length dimension were first suggested and implemented by Phil Tillet in the Triton [17]
implementation.3
Backward pass. Notice that the only shared computation between different column blocks is the update of dQ
in Algorithm 2, where we need to load $\mathbf{dQ}_i$ from HBM to SRAM, then, on chip, update $\mathbf{dQ}_i \leftarrow \mathbf{dQ}_i + \mathbf{dS}_i^{(j)} \mathbf{K}_j$,
and write back to HBM. We thus parallelize over the sequence length dimension as well, and schedule 1
thread block for each column block of the backward pass. We use atomic adds to communicate between
different thread blocks to update dQ.
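Algorithm 2 itself is not reproduced here; the NumPy sketch below is a simplified serial rendering of the same work split (no causal mask, recomputing P from the saved logsumexp $L$): each column block $j$ plays the role of one worker that owns $\mathbf{dK}_j$ and $\mathbf{dV}_j$, while all workers accumulate into the shared dQ, which is what the atomic adds implement on the GPU.

```python
import numpy as np

def flash_bwd(Q, K, V, O, dO, L, Bc, Br):
    """Backward pass structured as in Fig. 2 (right): one 'worker' per column block.
    Each worker owns dK_j, dV_j and accumulates into the shared dQ (atomic add on GPU)."""
    N, d = Q.shape
    dQ = np.zeros_like(Q); dK = np.zeros_like(K); dV = np.zeros_like(V)
    D = (dO * O).sum(axis=1, keepdims=True)                 # rowsum(dO * O), one scalar per row
    for j in range(0, N, Bc):                               # each j is an independent worker
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        for i in range(0, N, Br):
            Pij = np.exp(Q[i:i + Br] @ Kj.T - L[i:i + Br])  # recompute P from the logsumexp
            dV[j:j + Bc] += Pij.T @ dO[i:i + Br]
            dPij = dO[i:i + Br] @ Vj.T
            dSij = Pij * (dPij - D[i:i + Br])
            dQ[i:i + Br] += dSij @ Kj                       # shared across workers -> atomic add
            dK[j:j + Bc] += dSij.T @ Q[i:i + Br]
    return dQ, dK, dV

# Quick check against the dense backward equations from Section 2.2.
rng = np.random.default_rng(0)
N, d = 16, 8
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))
S = Q @ K.T
m = S.max(1, keepdims=True); l = np.exp(S - m).sum(1, keepdims=True)
P = np.exp(S - m) / l; O = P @ V; L = m + np.log(l)
dQ, dK, dV = flash_bwd(Q, K, V, O, dO, L, Bc=4, Br=4)
dP = dO @ V.T; dS = P * (dP - (P * dP).sum(1, keepdims=True))
assert np.allclose(dQ, dS @ K) and np.allclose(dK, dS.T @ Q) and np.allclose(dV, P.T @ dO)
```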
We describe the parallelization scheme in Fig. 2.
Figure 2: In the forward pass (left), we parallelize across workers (thread blocks), with each worker taking care
of a block of rows of the attention matrix. In the backward pass (right), each worker takes care of a block of
columns of the attention matrix.
3 https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
3.3 Work Partitioning Between Warps
While Section 3.2 describes how we schedule thread blocks, even within each thread block we still have to decide
how to partition the work between different warps. We typically use 4 or 8 warps per thread block, and the
partitioning is described in Fig. 3.
Forward pass. For each block, FlashAttention splits K and V across 4 warps while keeping Q accessible
by all warps. Each warp multiplies to get a slice of QK⊤, then they need to multiply with a slice of V and
communicate to add up the result. This is referred to as the “split-K” scheme. However, this is inefficient
since all warps need to write their intermediate results out to shared memory, synchronize, then add up the
intermediate results. These shared memory reads/writes slow down the forward pass in FlashAttention.
In FlashAttention-2, we instead split Q across 4 warps while keeping K and V accessible by all warps.
After each warp performs a matrix multiply to get a slice of QK⊤, they just need to multiply with their shared
slice of V to get their corresponding slice of the output. There is no need for communication between warps.
The reduction in shared memory reads/writes yields speedup (Section 4).
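To make the difference concrete, the sketch below contrasts the two partitionings on a single tile. The softmax is omitted so that only the data-movement pattern is visible, and the "warps" are just Python loop iterations: with split-K the partial outputs of the warps must be summed (a cross-warp reduction through shared memory in the kernel), whereas with split-Q each warp's rows of the output are already complete on their own.

```python
import numpy as np

rng = np.random.default_rng(0)
Br, Bc, d, n_warps = 8, 8, 16, 4
Q, K, V = rng.standard_normal((Br, d)), rng.standard_normal((Bc, d)), rng.standard_normal((Bc, d))
ref = (Q @ K.T) @ V                        # softmax omitted to isolate the partitioning pattern

# "Split-K": each warp holds a slice of K and V; partial outputs must be summed
# across warps (a shared-memory reduction in the FlashAttention kernel).
partials = []
for w in range(n_warps):
    cols = slice(w * Bc // n_warps, (w + 1) * Bc // n_warps)
    partials.append((Q @ K[cols].T) @ V[cols])       # (Br, d) partial result per warp
out_split_k = sum(partials)                          # cross-warp communication

# "Split-Q" (FlashAttention-2): each warp holds a slice of Q; its rows of the
# output are complete on their own, so no cross-warp reduction is needed.
out_split_q = np.concatenate(
    [(Q[w * Br // n_warps:(w + 1) * Br // n_warps] @ K.T) @ V for w in range(n_warps)])

assert np.allclose(out_split_k, ref) and np.allclose(out_split_q, ref)
```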
Backward pass. Similarly for the backward pass, we choose to partition the warps to avoid the “split-K”
scheme. However, it still requires some synchronization due to the more complicated dependency between all
the different inputs and gradients Q, K, V, O, dO, dQ, dK, dV. Nevertheless, avoiding “split-K” reduces shared
memory reads/writes and again yields speedup (Section 4).
Tuning block sizes. Increasing block sizes generally reduces shared memory loads/stores, but increases
the number of registers required and the total amount of shared memory. Past a certain block size, register
spilling causes significant slowdown, or the amount of shared memory required is larger than what the GPU
has available, and the kernel cannot run at all. Typically we choose blocks of size {64, 128} × {64, 128},
depending on the head dimension 𝑑 and the device shared memory size.
We manually tune for each head dimension since there are essentially only 4 choices for block sizes, but
this could benefit from auto-tuning to avoid this manual labor. We leave this to future work.
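As a rough illustration of the constraint, the sketch below estimates the shared memory needed just for the Q, K, V tiles in fp16/bf16 (2 bytes per element) for the candidate block sizes. The budget is an assumed round number on the order of what an A100 thread block can use; the real kernel keeps additional buffers, so treat the estimate as a lower bound.

```python
# Rough shared-memory estimate for the forward-pass tiles Q_i (Br x d),
# K_j and V_j (Bc x d) in fp16/bf16. The budget below is illustrative
# (on the order of the ~160 KB usable per thread block on an A100).
SMEM_BUDGET = 160 * 1024

def smem_bytes(Br, Bc, d, bytes_per_elt=2):
    return (Br * d + 2 * Bc * d) * bytes_per_elt

for d in (64, 128):
    for Br in (64, 128):
        for Bc in (64, 128):
            b = smem_bytes(Br, Bc, d)
            print(f"d={d:3d} Br={Br:3d} Bc={Bc:3d} -> {b // 1024:4d} KB, fits={b <= SMEM_BUDGET}")
```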
4 Empirical Validation
We evaluate the impact of using FlashAttention-2 to train Transformer models.
• Benchmarking attention. We measure the runtime of FlashAttention-2 across different sequence
lengths and compare it to a standard implementation in PyTorch, FlashAttention, and FlashAttention
in Triton. We confirm that FlashAttention-2 is 1.7-3.0× faster than FlashAttention, 1.3-2.5×
faster than FlashAttention in Triton, and 3-10× faster than a standard attention implementation.
FlashAttention-2 reaches up to 230 TFLOPs/s, 73% of the theoretical maximum TFLOPs/s on A100
GPUs.
• End-to-end training speed. When used end-to-end to train GPT-style models of size 1.3B and 2.7B with
sequence length 2k or 8k, FlashAttention-2 yields up to 1.3× speedup compared to FlashAttention
and 2.8× speedup compared to a baseline without FlashAttention. FlashAttention-2
reaches up to 225 TFLOPs/s (72% model FLOPs utilization) per A100 GPU.
We count the FLOPs of the attention forward pass as 4 · seqlen² · head dimension · number of heads (two
matmuls, each contributing 2 · seqlen² · head dimension per head). With causal mask, we divide this number by 2
to account for the fact that only approximately half of the entries are calculated. To get the FLOPs of the
backward pass, we multiply the forward pass FLOPs by 2.5 (since there are 2 matmuls in the forward pass and
5 matmuls in the backward pass, due to recomputation).
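This accounting can be written as a small helper used to convert measured runtimes into achieved TFLOPs/s; the batch size, runtime, and function name below are made up for the example.

```python
def attention_flops(batch, nheads, seqlen, headdim, causal=False, mode="fwd"):
    """FLOPs of one attention layer: 2 matmuls of 2*N^2*d each in the forward pass,
    halved with a causal mask, and 2.5x that for the backward pass (recomputation)."""
    f = 4 * batch * nheads * seqlen**2 * headdim
    if causal:
        f /= 2
    return {"fwd": f, "bwd": 2.5 * f, "fwd_bwd": 3.5 * f}[mode]

# Example: convert a hypothetical measured runtime into achieved TFLOPs/s.
flops = attention_flops(batch=8, nheads=16, seqlen=4096, headdim=64, causal=True, mode="fwd_bwd")
runtime_s = 0.010                      # hypothetical measured time for fwd + bwd
print(flops / runtime_s / 1e12, "TFLOPs/s")
```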
[Figure: Attention forward + backward speed (A100 80GB SXM4), in TFLOPs/s, across sequence lengths 512–16k, comparing PyTorch, FlashAttention, xformers, FlashAttention Triton, and FlashAttention-2. Panels: (a) without causal mask, head dimension 64; (b) without causal mask, head dimension 128; (c) with causal mask, head dimension 64; (d) with causal mask, head dimension 128.]
[Figure 5: Attention forward pass speed (A100 80GB SXM4), in TFLOPs/s, across sequence lengths 512–16k, same methods and panel layout as above.]
Just running the same implementation on H100 GPUs (using no special instructions to make use of new
features such as TMA and 4th-gen Tensor Cores), we obtain up to 335 TFLOPs/s (Fig. 7). We expect that by
using new instructions, we can obtain another 1.5x-2x speedup on H100 GPUs. We leave that to future work.
[Figure 6: Attention backward pass speed (A100 80GB SXM4), in TFLOPs/s, across sequence lengths 512–16k, same methods and panel layout as above.]
5 Discussion and Future Directions
We are excited about how longer context can be used to understand long books and reports, high resolution images,
audio and video. FlashAttention-2 will also speed up training, finetuning, and inference of existing models.
In the near future, we plan to collaborate with researchers and engineers to make FlashAttention widely
applicable in different kinds of devices (e.g., H100 GPUs, AMD GPUs), as well as new data types such as
FP8. As an immediate next step, we plan to optimize FlashAttention-2 for H100 GPUs to use new hardware
features (TMA, 4th-gen Tensor Cores, fp8). Combining the low-level optimizations in FlashAttention-2 with
high-level algorithmic changes (e.g., local, dilated, block-sparse attention) could allow us to train AI models
with much longer context. We are also excited to work with compiler researchers to make these optimization
techniques easily programmable.
[Figure 7: Attention forward + backward speed (H100 80GB SXM5), in TFLOPs/s, across sequence lengths 512–16k, comparing PyTorch, FlashAttention, and FlashAttention-2. Panels: (a) without causal mask, head dimension 64; (b) without causal mask, head dimension 128; (c) with causal mask, head dimension 64; (d) with causal mask, head dimension 128.]
Acknowledgments
We thank Phil Tillet and Daniel Haziza, who have implemented versions of FlashAttention in Triton [17] and
the xformers library [10]. FlashAttention-2 was motivated by an exchange of ideas between different ways that
attention could be implemented. We are grateful to the Nvidia CUTLASS team (especially Vijay Thakkar, Cris
Cecka, Haicheng Wu, and Andrew Kerr) for their CUTLASS library, in particular the CUTLASS 3.x release,
which provides clean abstractions and powerful building blocks for the implementation of FlashAttention-2.
We thank Driss Guessous for integrating FlashAttention into PyTorch. FlashAttention-2 has benefited
from helpful discussions with Phil Wang, Markus Rabe, James Bradbury, Young-Jun Ko, Julien Launay,
Daniel Hesslow, Michaël Benesty, Horace He, Ashish Vaswani, and Erich Elsen. Thanks to Stanford CRFM
and Stanford NLP for the compute support. We thank Dan Fu and Christopher Ré for their collaboration,
constructive feedback, and constant encouragement on this line of work of designing hardware-efficient
algorithms. We thank Albert Gu and Beidi Chen for their helpful suggestions on early drafts of this technical
report.
References
[1] Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit
Sanghai. GQA: Training generalized multi-query transformer models from multi-head checkpoints. arXiv
preprint arXiv:2305.13245, 2023.
[2] Iz Beltagy, Matthew E Peters, and Arman Cohan. Longformer: The long-document transformer. arXiv
preprint arXiv:2004.05150, 2020.
[3] Beidi Chen, Tri Dao, Eric Winsor, Zhao Song, Atri Rudra, and Christopher Ré. Scatterbrain: Unifying
sparse and low-rank attention. In Advances in Neural Information Processing Systems (NeurIPS), 2021.
[4] Krzysztof Marcin Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane,
Tamas Sarlos, Peter Hawkins, Jared Quincy Davis, Afroz Mohiuddin, Lukasz Kaiser, et al. Rethinking
attention with performers. In International Conference on Learning Representations (ICLR), 2020.
[5] Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and
memory-efficient exact attention with IO-awareness. In Advances in Neural Information Processing
Systems, 2022.
[6] Zhe Jia and Peter Van Sandt. Dissecting the Ampere GPU architecture via microbenchmarking. GPU
Technology Conference, 2021.
[7] Zhe Jia, Marco Maggioni, Benjamin Staiger, and Daniele P Scarpazza. Dissecting the Nvidia Volta GPU
architecture via microbenchmarking. arXiv preprint arXiv:1804.06826, 2018.
[8] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are RNNs:
Fast autoregressive transformers with linear attention. In International Conference on Machine Learning,
pages 5156–5165. PMLR, 2020.
[9] Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. In The
International Conference on Machine Learning (ICML), 2020.
[10] Benjamin Lefaudeux, Francisco Massa, Diana Liskovich, Wenhan Xiong, Vittorio Caggiano, Sean Naren,
Min Xu, Jieru Hu, Marta Tintore, Susan Zhang, Patrick Labatut, and Daniel Haziza. xformers: A modular
and hackable transformer modelling library. https://github.com/facebookresearch/xformers, 2022.
[11] Maxim Milakov and Natalia Gimelshein. Online normalizer calculation for softmax. arXiv preprint
arXiv:1805.02867, 2018.
[12] OpenAI. GPT-4 technical report. ArXiv, abs/2303.08774, 2023.
[13] Markus N Rabe and Charles Staats. Self-attention does not need $O(n^2)$ memory. arXiv preprint
arXiv:2112.05682, 2021.
[14] Aurko Roy, Mohammad Saffar, Ashish Vaswani, and David Grangier. Efficient content-based sparse
attention with routing transformers. Transactions of the Association for Computational Linguistics, 9:
53–68, 2021.
[15] Noam Shazeer. Fast transformer decoding: One write-head is all you need. arXiv preprint
arXiv:1911.02150, 2019.
[16] Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper, and Bryan Catanzaro.
Megatron-LM: Training multi-billion parameter language models using model parallelism. arXiv preprint
arXiv:1909.08053, 2019.
[17] Philippe Tillet, Hsiang-Tsung Kung, and David Cox. Triton: an intermediate language and compiler for
tiled neural network computations. In Proceedings of the 3rd ACM SIGPLAN International Workshop
on Machine Learning and Programming Languages, pages 10–19, 2019.
[18] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz
Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing
systems, 30, 2017.
[19] Sinong Wang, Belinda Z Li, Madian Khabsa, Han Fang, and Hao Ma. Linformer: Self-attention with
linear complexity. arXiv preprint arXiv:2006.04768, 2020.
[20] Manzil Zaheer, Guru Guruganesh, Kumar Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago
Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, et al. Big bird: Transformers for longer
sequences. Advances in Neural Information Processing Systems, 33, 2020.