Fast Transformer Decoding - One Write-Head Is All You Need
Noam Shazeer
Google
[email protected]
November 7, 2019
Abstract
Multi-head attention layers, as used in the Transformer neural sequence model, are a powerful alternative to RNNs for moving information across and between sequences. While training these layers is generally fast and simple, due to parallelizability across the length of the sequence, incremental inference (where such parallelization is impossible) is often slow, due to the memory-bandwidth cost of repeatedly loading the large "keys" and "values" tensors. We propose a variant called multi-query attention, where the keys and values are shared across all of the different attention "heads", greatly reducing the size of these tensors and hence the memory bandwidth requirements of incremental decoding. We verify experimentally that the resulting models can indeed be much faster to decode, and incur only minor quality degradation from the baseline.
1 Introduction
The Transformer neural sequence model [Vaswani et al., 2017] has emerged as a popular alternative to
recurrent sequence models. Transformer relies on attention layers to communicate information between
and across sequences. One major challenge with Transformer is the speed of incremental inference. As we
will discuss, the speed of incremental Transformer inference on modern computing hardware is limited by
the memory bandwidth necessary to reload the large "keys" and "values" tensors which encode the state
of the attention layers. In the following sections, we will review the multi-head-attention layers used by
Transformer, provide a performance analysis, and propose an architectural variation (multi-query attention)
which greatly improves inference speed with only minor quality degradation.
def DotProductAttention(q, K, V):
  """Dot-Product Attention on one query.
  Args:
    q: a vector with shape [k]
    K: a matrix with shape [m, k]
    V: a matrix with shape [m, v]
  Returns:
    y: a vector with shape [v]
  """
  logits = tf.einsum("k,mk->m", q, K)
  weights = tf.nn.softmax(logits)
  return tf.einsum("m,mv->v", weights, V)
Our code samples use einsum notation, as defined in TensorFlow and numpy, for generalized contractions
between tensors of arbitrary dimension. In this notation, an equation names the dimensions of the input and
output Tensors. The computation is numerically equivalent to broadcasting each input to have the union of
all dimensions, multiplying component-wise, and summing across all dimensions not in the desired output
shape.
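As a small illustration of this equivalence (using numpy and example shapes of our own choosing), the einsum contraction below matches the explicit broadcast-multiply-sum computation:

import numpy as np

# Example shapes (illustrative only): m = 3 memory positions, k = 4 key channels.
q = np.random.randn(4)       # [k]
K = np.random.randn(3, 4)    # [m, k]

# einsum form: contract over k, keep m.
logits = np.einsum("k,mk->m", q, K)

# Broadcast q to [m, k], multiply component-wise, sum over k.
logits_explicit = (q[np.newaxis, :] * K).sum(axis=1)

assert np.allclose(logits, logits_explicit)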
2.3 Multi-head Attention (Batched)
In practice, it is far more efficient to batch together multiple queries. The code below adds two types of
batching. First, we generate queries from n different positions in a sequence. These queries all interact with
the same keys and values. In addition, we process a batch of b different non-interacting sequences at once.
Following [Vaswani et al., 2017], in an autoregressive model, we can prevent backward-information-flow by
adding a "mask" to the logits containing the value −∞ in the illegal positions.
def MultiheadAttentionBatched(
    X, M, mask, P_q, P_k, P_v, P_o):
  """Multi-head Attention.
  Args:
    X: a tensor with shape [b, n, d]
    M: a tensor with shape [b, m, d]
    mask: a tensor with shape [b, h, n, m]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    Y: a tensor with shape [b, n, d]
  """
  Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
  K = tf.einsum("bmd,hdk->bhmk", M, P_k)
  V = tf.einsum("bmd,hdv->bhmv", M, P_v)
  logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)
  weights = tf.nn.softmax(logits + mask)
  O = tf.einsum("bhnm,bhmv->bhnv", weights, V)
  Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
  return Y
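For concreteness, the additive mask referenced above can be built as follows. This is a sketch of one possible construction; the helper name and the use of a large negative constant in place of −∞ are illustrative choices, not taken from the listings in this paper.

import tensorflow as tf

def causal_mask(b, h, n, m, neg_large=-1e9):
  """Additive mask of shape [b, h, n, m]: query position i may attend to positions j <= i.

  Assumes self-attention with m == n; uses a large negative constant instead of -inf
  so that the softmax stays numerically well-behaved.
  """
  allowed = tf.linalg.band_part(tf.ones([n, m]), -1, 0)   # lower triangle, incl. diagonal
  return tf.broadcast_to(((1.0 - allowed) * neg_large)[None, None, :, :], [b, h, n, m])

With such a mask, a decoder self-attention layer would be invoked as Y = MultiheadAttentionBatched(X, X, causal_mask(b, h, n, n), P_q, P_k, P_v, P_o).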
To analyze the performance of this batched multi-head attention layer, we make the following simplifying assumptions:
• m = n
• k = v = d/h
• n ≤ d
The total number of arithmetic operations is Θ(bnd²), since the complexity of each of the tf.einsum operations above is O(bnd²) given the simplifying assumptions.
The total size of memory to be accessed is equal to the sum of the sizes of all the tensors involved: O(bnd + bhn² + d²). The first term is due to X, M, Q, K, V, O and Y, the second term due to the logits and weights, and the third term due to the projection tensors P_q, P_k, P_v and P_o.
Dividing the two, we find that the ratio of memory access to arithmetic operations is O(1/k + 1/(bn)). This low ratio is necessary for good performance on modern GPU/TPU hardware, where the computational capacity can be two orders of magnitude higher than the memory bandwidth.
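Plugging in hypothetical sizes (chosen only to make the arithmetic concrete; they are not the paper's configurations) shows how small this ratio is in practice:

# Hypothetical sizes, for illustration only.
b, n, d, h = 128, 128, 1024, 8
k = v = d // h                                   # 128

arithmetic = b * n * d * d                       # Θ(b·n·d²), constant factors dropped
memory     = b * n * d + b * h * n * n + d * d   # activations + logits/weights + projections

print(memory / arithmetic)   # ≈ 0.002, within the O(1/k + 1/(bn)) ≈ 0.008 bound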
During training, this layer can be computed with an efficient parallel implementation similar to that in section 2.3. However, when generating from the trained model, the output of the self-attention layer at a particular position affects the token that is generated at the next position, which in turn affects the input to that layer at the next position. This prevents parallel computation. Code for incrementally computing this self-attention layer is shown below.
def MultiheadSelfAttentionIncremental(
    x, prev_K, prev_V, P_q, P_k, P_v, P_o):
  """Multi-head Self-Attention (one step).
  Args:
    x: a tensor with shape [b, d]
    prev_K: tensor with shape [b, h, m, k]
    prev_V: tensor with shape [b, h, m, v]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    y: a tensor with shape [b, d]
    new_K: tensor with shape [b, h, m+1, k]
    new_V: tensor with shape [b, h, m+1, v]
  """
  q = tf.einsum("bd,hdk->bhk", x, P_q)
  new_K = tf.concat(
      [prev_K, tf.expand_dims(tf.einsum("bd,hdk->bhk", x, P_k), axis=2)],
      axis=2)
  new_V = tf.concat(
      [prev_V, tf.expand_dims(tf.einsum("bd,hdv->bhv", x, P_v), axis=2)],
      axis=2)
  logits = tf.einsum("bhk,bhmk->bhm", q, new_K)
  weights = tf.nn.softmax(logits)
  o = tf.einsum("bhm,bhmv->bhv", weights, new_V)
  y = tf.einsum("bhv,hdv->bd", o, P_o)
  return y, new_K, new_V
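As a usage sketch of the incremental layer above (the embedding lookup, output projection and greedy argmax below are hypothetical stand-ins for the rest of a decoder, not code from this paper), a generation loop carries the growing K and V tensors from step to step:

def greedy_decode_sketch(tokens0, embed, project_logits,
                         P_q, P_k, P_v, P_o, steps, b, h, k, v):
  """Greedy decoding with per-step self-attention; embed/project_logits are placeholders."""
  prev_K = tf.zeros([b, h, 0, k])   # empty key memory
  prev_V = tf.zeros([b, h, 0, v])   # empty value memory
  tokens, outputs = tokens0, []     # tokens0: [b] integer token ids
  for _ in range(steps):
    x = embed(tokens)               # [b, d]
    y, prev_K, prev_V = MultiheadSelfAttentionIncremental(
        x, prev_K, prev_V, P_q, P_k, P_v, P_o)
    tokens = tf.argmax(project_logits(y), axis=-1)   # [b]
    outputs.append(tokens)
  return tf.stack(outputs, axis=1)  # [b, steps]

Note that prev_K and prev_V, of shape [b, h, m, k] and [b, h, m, v], must be reloaded from memory at every step; this is the traffic that dominates the cost of incremental decoding.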
3 Multi-Query Attention
We introduce multi-query attention as a variation of multi-head attention as described in [Vaswani et al., 2017]. Multi-head attention consists of multiple attention layers (heads) in parallel with different linear transformations on the queries, keys, values and outputs. Multi-query attention is identical, except that the different heads share a single set of keys and values. The code for (incremental) multi-query (self-)attention is identical to the code listed above for multi-head attention, except that we remove the letter "h" from the tf.einsum equations where it represents the "heads" dimension of K, V, P_k, or P_v.
def MultiqueryAttentionBatched(
    X, M, mask, P_q, P_k, P_v, P_o):
  """Multi-Query Attention.
  Args:
    X: a tensor with shape [b, n, d]
    M: a tensor with shape [b, m, d]
    mask: a tensor with shape [b, h, n, m]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [d, k]
    P_v: a tensor with shape [d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    Y: a tensor with shape [b, n, d]
  """
  Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
  K = tf.einsum("bmd,dk->bmk", M, P_k)
  V = tf.einsum("bmd,dv->bmv", M, P_v)
  logits = tf.einsum("bhnk,bmk->bhnm", Q, K)
  weights = tf.nn.softmax(logits + mask)
  O = tf.einsum("bhnm,bmv->bhnv", weights, V)
  Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
  return Y
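To make the bandwidth saving concrete, here is a back-of-the-envelope comparison of the per-layer key/value memory that incremental decoding must reload toward the end of a 128-token decode (hypothetical sizes and float32 storage, chosen for illustration; these are not figures from the paper):

# Hypothetical per-layer K/V memory at decoding step m (float32 = 4 bytes).
b, h, m, k, v = 1024, 8, 128, 128, 128
bytes_per_elem = 4

multi_head_bytes  = (b * h * m * (k + v)) * bytes_per_elem   # K: [b,h,m,k], V: [b,h,m,v]
multi_query_bytes = (b * m * (k + v)) * bytes_per_elem       # K: [b,m,k],   V: [b,m,v]

print(multi_head_bytes  / 2**20)   # 1024.0 MiB per layer
print(multi_query_bytes / 2**20)   # 128.0 MiB per layer: an h-fold reduction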
def MultiquerySelfAttentionIncremental(
    x, prev_K, prev_V, P_q, P_k, P_v, P_o):
  """Multi-query Self-Attention (one step).
  Args:
    x: a tensor with shape [b, d]
    prev_K: tensor with shape [b, m, k]
    prev_V: tensor with shape [b, m, v]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [d, k]
    P_v: a tensor with shape [d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    y: a tensor with shape [b, d]
    new_K: tensor with shape [b, m+1, k]
    new_V: tensor with shape [b, m+1, v]
  """
  q = tf.einsum("bd,hdk->bhk", x, P_q)
  new_K = tf.concat(
      [prev_K, tf.expand_dims(tf.einsum("bd,dk->bk", x, P_k), axis=1)],
      axis=1)
  new_V = tf.concat(
      [prev_V, tf.expand_dims(tf.einsum("bd,dv->bv", x, P_v), axis=1)],
      axis=1)
  logits = tf.einsum("bhk,bmk->bhm", q, new_K)
  weights = tf.nn.softmax(logits)
  o = tf.einsum("bhm,bmv->bhv", weights, new_V)
  y = tf.einsum("bhv,hdv->bd", o, P_o)
  return y, new_K, new_V
The configurations used can be found at [to be added before publication] , including details about learning
rates, dropout, label smoothing, etc.
In our "multi-query" model, we replace all of the attention layers in the model with multi-query attention.
This includes the encoder-self-attention, decoder-self-attention and encoder-decoder-attention layers. We
widen the feed-forward hidden layers from 4096 to 5440 to make the total parameter-count equal to that of
the baseline.
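The bookkeeping behind this widening can be sketched with symbolic dimensions. The helper functions below are our own simplification: biases, embeddings and layer-normalization parameters are ignored.

def kv_params_saved(d, h, k, v, num_attention_layers):
  """Parameters saved by sharing keys/values across heads."""
  # Multi-head: d*h*k + d*h*v key/value projection parameters per attention layer.
  # Multi-query: d*k + d*v, shared across the h heads.
  return num_attention_layers * d * (h - 1) * (k + v)

def ffn_params_added(d, d_ff_old, d_ff_new, num_ffn_layers):
  """Extra parameters from widening the two [d, d_ff]-sized feed-forward matrices."""
  return num_ffn_layers * 2 * d * (d_ff_new - d_ff_old)

Choosing the new feed-forward width so that ffn_params_added roughly equals kv_params_saved keeps the total parameter count matched to the baseline.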
To demonstrate that local-attention and multi-query attention are orthogonal, we also trained "local"
versions of the baseline and multi-query models, where the decoder-self-attention layers (but not the other
attention layers) restrict attention to the current position and the previous 31 positions.
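A local mask of this kind can be sketched in the same style as the causal mask shown earlier; the helper below is illustrative, with −1e9 again standing in for −∞.

def local_causal_mask(b, h, n, window=32, neg_large=-1e9):
  """Additive mask: position i may attend to positions j with i - window < j <= i."""
  allowed = tf.linalg.band_part(tf.ones([n, n]), window - 1, 0)
  return tf.broadcast_to(((1.0 - allowed) * neg_large)[None, None, :, :], [b, h, n, n])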
A simpler alternative way to reduce the sizes of K and V is to reduce the number of heads h and/or to
reduce the dimensionalities k and v of the keys and values. We trained several such models for comparison,
again widening the feed-forward hidden layers to make the total parameter-count equal to that of the baseline.
We performed a similar set of experiments using "transformer-decoder" language models on the Billion-Word Language Modeling Benchmark [Chelba et al., 2013]. For the baseline, we use a model with 6 layers, d_model = 1024, d_ff = 8192, h = 8, d_k = d_v = 128. The total parameter count is 192 million for the baseline and for all variations. We trained for 136K steps (10 epochs) at a batch size of 64K tokens. Again, we used a 32-core TPUv3 cluster for approximately 3 hours to train each model.
4.3 Speed
Table 2 shows training and inference times for the various models. Both training and inference speeds were
evaluated on one TPUv2 (8 cores). A training step (consisting of 32,768 input tokens and 32,768 target
tokens, as described above) took 433ms for the base model and 425ms for the multi-query model. Dividing
by 32,768, we find that the training time is 13.2µs per (input-token + target-token), as listed in Table 2.
We ran incremental greedy inference on a batch of 1024 sequences (128 per core) using a source-sequence length of 128 tokens and a target sequence length of 128.¹ For the baseline model, the encoder part of the
model took 222ms and each incremental step of the decoder took 47ms. Dividing by the respective numbers
of tokens, we find that the amortized inference time is 1.7µs per token for the encoder and a much larger
46µs per token for the decoder, as listed in Table 2. For the multi-query model, the encoder took 195ms
and the decoder took 3.9ms per step, for amortized per-token costs of 1.5µs and 3.8µs respectively. Table 2
shows these values as well as similar results for beam-search.
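These per-token figures follow directly from the raw timings; a quick arithmetic check, with batch and sequence sizes as stated above:

# Training: 433 ms (baseline) over 32,768 input + 32,768 target tokens.
print(433e-3 / 32768 * 1e6)            # ≈ 13.2 µs per (input-token + target-token)

# Inference: batch of 1024 sequences, 128 source and 128 target tokens each.
encoder_tokens = 1024 * 128
print(222e-3 / encoder_tokens * 1e6)   # baseline encoder: ≈ 1.7 µs per token
print(47e-3 / 1024 * 1e6)              # baseline decoder: ≈ 46 µs per token (1024 tokens/step)
print(195e-3 / encoder_tokens * 1e6)   # multi-query encoder: ≈ 1.5 µs per token
print(3.9e-3 / 1024 * 1e6)             # multi-query decoder: ≈ 3.8 µs per token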
¹ Due to system limitations requiring fixed shapes, we used padding and masking in our decoder-self-attention implementation. The memory tensors were thus padded to the maximum length (128), or to the window size (32) in the case of local attention. Each decoding step thus took the same amount of time. An alternative implementation of incrementally growing the tensors could save time near the beginning of the sequence.
Table 1: WMT14 EN-DE Results.
Table 2: Amortized training and inference costs for WMT14 EN-DE Translation Task with sequence length
128. Values listed are in TPUv2-microseconds per output token.
Attention     h   d_k, d_v   d_ff   dev-PPL
multi-head    8   128        8192   29.9
multi-query   8   128        9088   30.2
multi-head    1   128        9984   31.2
multi-head    2   64         9984   31.1
multi-head    4   32         9984   31.0
multi-head    8   16         9984   30.9
5 Conclusion
We have proposed multi-query attention - an alternative to multi-head attention with much lower memory-
bandwidth requirements in the incremental setting. We believe that this enables wider adoption of attention-
based sequence models in inference-performance-critical applications.
References
Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to
align and translate, 2014.
Ciprian Chelba, Tomas Mikolov, Mike Schuster, Qi Ge, Thorsten Brants, and Phillipp Koehn. One billion
word benchmark for measuring progress in statistical language modeling. CoRR, abs/1312.3005, 2013.
URL http://arxiv.org/abs/1312.3005.
Peter J. Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating wikipedia by summarizing long sequences. In Proceedings of the International Conference on Learning Representations, 2018.
Daniel Povey, Hossein Hadian, Pegah Ghahremani, Ke Li, and Sanjeev Khudanpur. A time-restricted self-attention layer for ASR. In Proceedings of the IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser,
and Illia Polosukhin. Attention is all you need. In NIPS, 2017.
Biao Zhang, Deyi Xiong, and Jinsong Su. Accelerating neural transformer via an average attention network,
2018.