
Fast Transformer Decoding: One Write-Head is All You Need

Noam Shazeer
Google
[email protected]

November 7, 2019

Abstract
Multi-head attention layers, as used in the Transformer neural sequence model, are a powerful alternative to RNNs for moving information across and between sequences. While training these layers is generally fast and simple, due to parallelizability across the length of the sequence, incremental inference (where such parallelization is impossible) is often slow, due to the memory-bandwidth cost of repeatedly loading the large "keys" and "values" tensors. We propose a variant called multi-query attention, where the keys and values are shared across all of the different attention "heads", greatly reducing the size of these tensors and hence the memory bandwidth requirements of incremental decoding. We verify experimentally that the resulting models can indeed be much faster to decode, and incur only minor quality degradation from the baseline.

1 Introduction
The Transformer neural sequence model [Vaswani et al., 2017] has emerged as a popular alternative to
recurrent sequence models. Transformer relies on attention layers to communicate information between
and across sequences. One major challenge with Transformer is the speed of incremental inference. As we
will discuss, the speed of incremental Transformer inference on modern computing hardware is limited by
the memory bandwidth necessary to reload the large "keys" and "values" tensors which encode the state
of the attention layers. In the following sections, we will review the multi-head-attention layers used by
Transformer, provide a performance analysis, and propose an architectural variation (multi-query attention)
which greatly improves inference speed with only minor quality degradation.

2 Background: Neural Attention


Neural Attention, introduced by [Bahdanau et al., 2014], is a powerful tool for manipulating variable-length
representations. A neural attention function takes a single query-vector q and a set of m different (key-vector,
value-vector) pairs (represented by the matrices K and V ), and produces an output vector y. The output y
is computed as a weighted sum of the different value vectors, where the weights are derived by comparing
the query to the keys.

2.1 Dot-Product Attention


The following code describes a common formulation, where the weights are computed as the softmax of the
dot-products of the query with the different keys.

def DotProductAttention(q, K, V):
  """Dot-Product Attention on one query.
  Args:
    q: a vector with shape [k]
    K: a matrix with shape [m, k]
    V: a matrix with shape [m, v]
  Returns:
    y: a vector with shape [v]
  """
  logits = tf.einsum("k,mk->m", q, K)
  weights = tf.softmax(logits)
  return tf.einsum("m,mv->v", weights, V)
Our code samples use einsum notation, as defined in TensorFlow and numpy, for generalized contractions
between tensors of arbitrary dimension. In this notation, an equation names the dimensions of the input and
output Tensors. The computation is numerically equivalent to broadcasting each input to have the union of
all dimensions, multiplying component-wise, and summing across all dimensions not in the desired output
shape.
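As a small illustration of this notation (ours, not the paper's), the equation "k,mk->m" used above is simply a matrix-vector product written dimension-wise; the tensor values below are arbitrary:

import tensorflow as tf

# Illustration only: "k,mk->m" computes logits[i] = sum_j q[j] * K[i, j],
# i.e. the dot product of the query with each of the m keys.
q = tf.random.normal([4])      # a query with k = 4
K = tf.random.normal([3, 4])   # m = 3 keys, each of dimension k = 4
logits_einsum = tf.einsum("k,mk->m", q, K)
logits_matvec = tf.linalg.matvec(K, q)   # the same contraction as an explicit matrix-vector product
# The two results agree up to floating-point error.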

2.2 Multi-head Attention


The "Transformer" seuqence-to-sequence model [Vaswani et al., 2017] uses h different attention layers (heads)
in parallel, which the authors refer to as "Multi-head attention". The query vectors for the h different layers
are derived from h different learned linear projections Pq of an input vector x. Similarly, the keys and
values are derived from h different learned linear projections Pk , Pv of a collection M of m different input
vectors. The outputs of the h layers are themselves passed through different learned linear projections Po ,
then summed. For simplicity, we give the input and output vectors identical dimensionality d. The The
computation can be expressed as follows:
def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
  """Multi-head Attention on one query.
  Args:
    x: a vector with shape [d]
    M: a matrix with shape [m, d]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    y: a vector with shape [d]
  """
  q = tf.einsum("d,hdk->hk", x, P_q)
  K = tf.einsum("md,hdk->hmk", M, P_k)
  V = tf.einsum("md,hdv->hmv", M, P_v)
  logits = tf.einsum("hk,hmk->hm", q, K)
  weights = tf.softmax(logits)
  o = tf.einsum("hm,hmv->hv", weights, V)
  y = tf.einsum("hv,hdv->d", o, P_o)
  return y
Note: [Vaswani et al., 2017] include a constant scaling factor on the logits. We omit this in our code, as
it can be folded into the linear projections Pq or Pk .
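As a minimal sketch of what this folding looks like (our illustration with arbitrary sizes, not from the paper), scaling P_q by k**-0.5 produces exactly the scaled logits of [Vaswani et al., 2017]:

import tensorflow as tf

# Illustration only: folding the 1/sqrt(k) logit scaling into the query projection P_q.
d, h, k, m = 16, 4, 8, 5
x = tf.random.normal([d])
P_q = tf.random.normal([h, d, k])
K = tf.random.normal([h, m, k])

# (a) scale the logits explicitly, as in [Vaswani et al., 2017]
q = tf.einsum("d,hdk->hk", x, P_q)
logits_scaled = tf.einsum("hk,hmk->hm", q, K) * k ** -0.5

# (b) fold the constant into P_q instead; the resulting logits are identical
q_folded = tf.einsum("d,hdk->hk", x, P_q * k ** -0.5)
logits_folded = tf.einsum("hk,hmk->hm", q_folded, K)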

2.3 Multi-head Attention (Batched)
In practice, it is far more efficient to batch together multiple queries. The code below adds two types of
batching. First, we generate queries from n different positions in a sequence. These queries all interact with
the same keys and values. In addition, we process a batch of b different non-interacting sequences at once.
Following [Vaswani et al., 2017], in an autoregressive model, we can prevent backward-information-flow by
adding a "mask" to the logits containing the value −∞ in the illegal positions.
def MultiheadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
  """Multi-head Attention.
  Args:
    X: a tensor with shape [b, n, d]
    M: a tensor with shape [b, m, d]
    mask: a tensor with shape [b, h, n, m]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    Y: a tensor with shape [b, n, d]
  """
  Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
  K = tf.einsum("bmd,hdk->bhmk", M, P_k)
  V = tf.einsum("bmd,hdv->bhmv", M, P_v)
  logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)
  weights = tf.softmax(logits + mask)
  O = tf.einsum("bhnm,bhmv->bhnv", weights, V)
  Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
  return Y
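The listing above takes the mask as an argument without constructing it. A minimal sketch of a causal mask for the self-attention case n = m (our construction, not from the paper; a large negative constant stands in for −∞):

import tensorflow as tf

# Illustration only: a causal mask of shape [b, h, n, n]. Position i may attend to
# positions j <= i; illegal positions receive a large negative value so that the
# softmax assigns them (numerically) zero weight.
def causal_mask(b, h, n, neg_inf=-1e9):
  allowed = tf.linalg.band_part(tf.ones([n, n]), -1, 0)   # lower-triangular ones, diagonal included
  mask = (1.0 - allowed) * neg_inf                        # 0.0 where allowed, neg_inf where not
  return tf.broadcast_to(mask[None, None, :, :], [b, h, n, n])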

2.3.1 Performance Analysis of Batched Multi-head Attention


To simplify the performance analysis, we will make several simplifying assumptions:
• m = n
• k = v = d/h, as suggested by [Vaswani et al., 2017]
• n ≤ d
The total number of arithmetic operations is Θ(bnd²), since the complexity of each of the tf.einsum operations above is O(bnd²) given the simplifying assumptions.
The total size of memory to be accessed is equal to the sum of the sizes of all the tensors involved: O(bnd + bhn² + d²). The first term is due to X, M, Q, K, V, O and Y, the second term is due to the logits and weights, and the third term is due to the projection tensors Pq, Pk, Pv and Po.
Dividing the two, we find that the ratio of memory access to arithmetic operations is O(1/k + 1/(bn)). This low ratio is necessary for good performance on modern GPU/TPU hardware, where the computational capacity can be two orders of magnitude higher than the memory bandwidth.
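To make the ratio concrete (our arithmetic, using the training configuration of Section 4.1 purely as an illustration):

# Illustration only: memory-to-arithmetic ratio for batched multi-head attention
# under the assumptions above (m = n, k = v = d/h, n <= d).
b, n, d, h = 128, 256, 1024, 8
k = d // h

arithmetic = b * n * d * d                  # Theta(b n d^2)
memory = b * n * d + b * h * n * n + d * d  # O(b n d + b h n^2 + d^2)
print(memory / arithmetic)                  # about 0.003, consistent with O(1/k + 1/(bn)) and far below 1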

2.4 Multihead Attention (Incremental)


In some settings, data dependencies make it impossible to process queries from multiple positions in parallel.
An example is a self-attention layer in an autoregressive language model such as Transformer [Vaswani et al.,
2017]. The queries produced at each position attend to key-value pairs produced at all positions up to and
including that position. During training, the ground-truth target sequence is known, and we can use an

efficient parallel implementation similar to that in section 2.3. However, when generating from the trained
model, the output of the self-attention layer at a particular position affects the token that is generated at
the next position, which in turn affects the input to that layer at the next position. This prevents parallel
computation. Code for incrementally computing this self-attention layer is shown below.
def MultiheadSelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
  """Multi-head Self-Attention (one step).
  Args:
    x: a tensor with shape [b, d]
    prev_K: tensor with shape [b, h, m, k]
    prev_V: tensor with shape [b, h, m, v]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [h, d, k]
    P_v: a tensor with shape [h, d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    y: a tensor with shape [b, d]
    new_K: tensor with shape [b, h, m+1, k]
    new_V: tensor with shape [b, h, m+1, v]
  """
  q = tf.einsum("bd,hdk->bhk", x, P_q)
  new_K = tf.concat(
      [prev_K, tf.expand_dims(tf.einsum("bd,hdk->bhk", x, P_k), axis=2)],
      axis=2)
  new_V = tf.concat(
      [prev_V, tf.expand_dims(tf.einsum("bd,hdv->bhv", x, P_v), axis=2)],
      axis=2)
  logits = tf.einsum("bhk,bhmk->bhm", q, new_K)
  weights = tf.softmax(logits)
  o = tf.einsum("bhm,bhmv->bhv", weights, new_V)
  y = tf.einsum("bhv,hdv->bd", o, P_o)
  return y, new_K, new_V
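For concreteness, here is a sketch (not from the paper) of how this function might be driven during greedy decoding. The names greedy_decode and sample_next_token are hypothetical: the latter stands in for the rest of the decoder (remaining layers, output projection, sampling, and re-embedding of the chosen token).

import tensorflow as tf

# Illustration only: step-by-step decoding with an incrementally grown K/V memory,
# reusing MultiheadSelfAttentionIncremental as defined above.
def greedy_decode(x0, P_q, P_k, P_v, P_o, steps, sample_next_token):
  b, d = x0.shape
  h, _, k = P_q.shape
  v = P_v.shape[2]
  K = tf.zeros([b, h, 0, k])   # start with an empty memory (m = 0)
  V = tf.zeros([b, h, 0, v])
  x, outputs = x0, []
  for _ in range(steps):
    y, K, V = MultiheadSelfAttentionIncremental(x, K, V, P_q, P_k, P_v, P_o)
    x = sample_next_token(y)   # [b, d]: the next input depends on this step's output
    outputs.append(x)
  return outputs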

2.4.1 Performance Analysis


We make the same simplifying assumptions as in section 2.3.1.
Across n calls, the total number of arithmetic operations is again Θ(bnd²).
Across n calls, the total amount of memory access is Θ(bn²d + nd²), the first term due to K and V and the second term due to Pq, Pk, Pv and Po.
Dividing the memory by the computations, we find that the ratio of memory access to arithmetic operations is Θ(n/d + 1/b). When n ≈ d or b ≈ 1, the ratio is close to 1, causing memory bandwidth to be a major performance bottleneck on modern computing hardware. In order to make incremental generation efficient, we must reduce both of these terms to be ≪ 1. The 1/b term is the easier one: we can just use a larger batch size, memory size permitting.
Reducing the n/d term is harder. This term reflects the expense of reloading at each step the K and V tensors representing the memory; under our assumptions they grow to bhmk = bnd entries each, giving Θ(bn²d) memory traffic across the n steps. One solution is to limit the sequence length n.
Another is to reduce the number of positions being attended-to, either by attending to a local neighborhood,
or by otherwise compressing the number of memory positions, as in [Liu et al., 2018], [Zhang et al., 2018],
[Povey et al., 2018]. In this paper we present an orthogonal approach to reducing the size of the K and V
tensors - namely removing their "heads" dimension, while maintaining the "heads" dimension in the queries.
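A quick back-of-the-envelope check of this term (our arithmetic; n = 128 here simply matches the target length used later in Section 4):

# Illustration only: K/V reload traffic for incremental multi-head decoding
# under the assumptions above (m grows to n, k = d/h).
b, n, d, h = 128, 128, 1024, 8
k = d // h

kv_entries_final = 2 * b * h * n * k                          # K and V together at the last step: 2*b*n*d
kv_traffic = sum(2 * b * h * i * k for i in range(1, n + 1))  # reloaded at every step: Theta(b n^2 d)
arithmetic = b * n * d * d                                    # Theta(b n d^2) across the n steps
print(kv_traffic / arithmetic)                                # about n/d = 0.13, the term multi-query attention attacks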

3 Multi-Query Attention
We introduce multi-query attention as a variation of the multi-head attention described in [Vaswani et al.,
2017]. Multi-head attention consists of multiple attention layers (heads) in parallel with different linear
transformations on the queries, keys, values and outputs. Multi-query attention is identical except that the
different heads share a single set of keys and values. The code for (incremental) multi-query (self) attention
is identical to the code listed above for multi-head attention, except that we remove the letter "h" from the
tf.einsum equations where it represents the "heads" dimension of K, V , Pk , or Pv .
def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
  """Multi-Query Attention.
  Args:
    X: a tensor with shape [b, n, d]
    M: a tensor with shape [b, m, d]
    mask: a tensor with shape [b, h, n, m]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [d, k]
    P_v: a tensor with shape [d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    Y: a tensor with shape [b, n, d]
  """
  Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
  K = tf.einsum("bmd,dk->bmk", M, P_k)
  V = tf.einsum("bmd,dv->bmv", M, P_v)
  logits = tf.einsum("bhnk,bmk->bhnm", Q, K)
  weights = tf.softmax(logits + mask)
  O = tf.einsum("bhnm,bmv->bhnv", weights, V)
  Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
  return Y

def MultiquerySelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
  """Multi-query Self-Attention (one step).
  Args:
    x: a tensor with shape [b, d]
    prev_K: tensor with shape [b, m, k]
    prev_V: tensor with shape [b, m, v]
    P_q: a tensor with shape [h, d, k]
    P_k: a tensor with shape [d, k]
    P_v: a tensor with shape [d, v]
    P_o: a tensor with shape [h, d, v]
  Returns:
    y: a tensor with shape [b, d]
    new_K: tensor with shape [b, m+1, k]
    new_V: tensor with shape [b, m+1, v]
  """
  q = tf.einsum("bd,hdk->bhk", x, P_q)
  new_K = tf.concat(
      [prev_K, tf.expand_dims(tf.einsum("bd,dk->bk", x, P_k), axis=1)],
      axis=1)
  new_V = tf.concat(
      [prev_V, tf.expand_dims(tf.einsum("bd,dv->bv", x, P_v), axis=1)],
      axis=1)
  logits = tf.einsum("bhk,bmk->bhm", q, new_K)
  weights = tf.softmax(logits)
  o = tf.einsum("bhm,bmv->bhv", weights, new_V)
  y = tf.einsum("bhv,hdv->bd", o, P_o)
  return y, new_K, new_V

3.1 Performance Analysis for Incremental Multi-Query Attention


We make the same simplifying assumptions as in section 2.3.1.
Across n calls, the total number of arithmetic operations is again Θ(bnd²).
Across n calls, the total amount of memory access is Θ(bnd + bn²k + nd²), the first term due to x, q, o and y, the second term due to K and V, and the third term due to Pq, Pk, Pv and Po.
Dividing the memory by the computations, we find that the ratio of memory access to arithmetic operations is Θ(1/d + n/(dh) + 1/b). We have reduced the offensive n/d term by a factor of h. Theoretically, given large batch size b, this should dramatically improve performance of incremental generation. In our experimental section, we will show that the performance gains are real and that model quality remains high.
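As an illustrative comparison (our numbers, using d = 1024, h = 8 and n = 128 from the experiments, together with a large batch):

# Illustration only: incremental memory-to-arithmetic ratios,
# multi-head (section 2.4.1) versus multi-query (this section).
b, n, d, h = 1024, 128, 1024, 8

multi_head_ratio = n / d + 1 / b                   # Theta(n/d + 1/b)
multi_query_ratio = 1 / d + n / (d * h) + 1 / b    # Theta(1/d + n/(dh) + 1/b)
print(multi_head_ratio, multi_query_ratio)         # roughly 0.126 vs. 0.018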

4 Experiments and Results


4.1 Experimental Setup
Following [Vaswani et al., 2017], we evaluate on the WMT 2014 English-German translation task. As a
baseline, we use an encoder-decoder Transformer model with 6 layers, using dmodel = 1024, dff = 4096,
h = 8, dk = dv = 128, learned positional embeddings, and weight-sharing between the token-embedding and
output layers. The baseline model and all variations have 211 million parameters. All models were trained
for 100,000 steps (about 20 epochs). Each training batch consisted of 128 examples, each of which consisted of
a 256-token input sequence and a 256-token target sequence (multiple training sentences were concatenated
together to reach this length). Models were trained on a 32-core TPUv3 cluster, with each model taking
about 2 hours to train. We used an implementation from the tensor2tensor and mesh-tensorflow libraries.

The configurations used can be found at [to be added before publication] , including details about learning
rates, dropout, label smoothing, etc.
In our "multi-query" model, we replace all of the attention layers in the model to multi-query attention.
This includes the encoder-self-attention, decoder-self-attention and encoder-decoder-attention layers. We
widen the feed-forward hidden layers from 4096 to 5440 to make the total parameter-count equal to that of
the baseline.
To demonstrate that local-attention and multi-query attention are orthogonal, we also trained "local"
versions of the baseline and multi-query models, where the decoder-self-attention layers (but not the other
attention layers) restrict attention to the current position and the previous 31 positions.
A simpler alternative way to reduce the sizes of K and V is to reduce the number of heads h and/or to
reduce the dimensionalities k and v of the keys and values. We trained several such models for comparison,
again widening the feed-forward hidden layers to make the total parameter-count equal to that of the baseline.
We performed a similar set of experiments using "transformer-decoder" language models on the Billion-Word Language Modeling Benchmark [Chelba et al., 2013]. For the baseline, we use a model with 6 layers, dmodel = 1024, dff = 8192, h = 8, dk = dv = 128. The total parameter count is 192 million for the baseline
and for all variations. We trained for 136K steps (10 epochs) at a batch size of 64K tokens. Again, we used
a 32-core TPUv3 cluster for approximately 3 hours to train each model.

4.2 Model Quality


Table 1 shows results for the machine-translation experiments. We decoded the dev set using greedy
maximum-likelihood decoding and computed BLEU score with sacrebleu "sacrebleu -t wmt13 -l en-de
-tok intl". We also list per-subword-token perplexity on the dev set. According to both of these metrics,
the multi-query attention model seems to be slightly worse than the baseline, but much closer than any of
the alternatives involving decreasing h, dk and dv .
We validated the results by decoding the test set using both greedy decoding and beam search (beam
4, α = 0.6), and evaluated with sacrebleu "sacrebleu -t wmt14 -l en-de -tok intl". Again, the multi-
query model performed similarly to the baseline, and actually had the highest BLEU score (28.5) with
beam-4 decoding.
Table 3 shows results for the billion-word language modeling benchmark. Models were evaluated by per-
word (not per-subword-token) perplexity on the dev set. The results paint a similar picture to the translation
results. The multi-query attention model was slightly worse than the baseline, but significantly better than
any of the alternatives involving decreasing h, dk and dv .

4.3 Speed
Table 2 shows training and inference times for the various models. Both training and inference speeds were
evaluated on one TPUv2 (8 cores). A training step (consisting of 32,768 input tokens and 32,768 target
tokens, as described above) took 433ms for the base model and 425ms for the multi-query model. Dividing
by 32,768, we find that the training time is 13.2µs per (input-token + target-token), as listed in Table 2.
We ran incremental greedy inference on a batch of 1024 sequences (128 per core) using a source-sequence
length of 128 tokens and a target sequence length of 128. 1 For the baseline model, the encoder part of the
model took 222ms and each incremental step of the decoder took 47ms. Dividing by the respective numbers
of tokens, we find that the amortized inference time is 1.7µs per token for the encoder and a much larger
46µs per token for the decoder, as listed in Table 2. For the multi-query model, the encoder took 195ms
and the decoder took 3.9ms per step, for amortized per-token costs of 1.5µs and 3.8µs respectively. Table 2
shows these values as well as similar results for beam-search.
1 Due to system limitations requiring fixed shapes, we used padding and masking in our decoder-self-attention implementation.

The memory tensors were thus padded to the maximum length (128), or to the window-size (32) in the case of local attention.
Each decoding step thus took the same amount of time. An alternative implementation of incrementally growing the tensors
could save time near the beginning of the sequence.
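The per-token figures in Table 2 follow directly from the wall-clock numbers quoted above; as a quick check (our arithmetic):

# Illustration only: recovering the amortized per-token costs from the quoted times.
print(0.433 / 32768 * 1e6)          # ~13.2 us per (input-token + target-token), baseline training step
print(0.222 / (1024 * 128) * 1e6)   # ~1.7 us per source token, baseline encoder
print(0.047 / 1024 * 1e6)           # ~46 us per target token per step, baseline decoder
print(0.0039 / 1024 * 1e6)          # ~3.8 us per target token per step, multi-query decoder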

Table 1: WMT14 EN-DE Results.

Attention Type      h   dk, dv   dff    ln(PPL) (dev)   BLEU (dev)   BLEU (test), beam 1 / 4
multi-head          8   128      4096   1.424           26.7         27.7 / 28.4
multi-query         8   128      5440   1.439           26.5         27.5 / 28.5
multi-head local    8   128      4096   1.427           26.6         27.5 / 28.3
multi-query local   8   128      5440   1.437           26.5         27.6 / 28.2
multi-head          1   128      6784   1.518           25.8
multi-head          2   64       6784   1.480           26.2         26.8 / 27.9
multi-head          4   32       6784   1.488           26.1
multi-head          8   16       6784   1.513           25.8

Table 2: Amortized training and inference costs for the WMT14 EN-DE translation task with sequence length 128. Values listed are in TPUv2-microseconds per output token.

Attention Type      Training   Inference (enc. + dec.)   Beam-4 Search (enc. + dec.)
multi-head          13.2       1.7 + 46                  2.0 + 203
multi-query         13.0       1.5 + 3.8                 1.6 + 32
multi-head local    13.2       1.7 + 23                  1.9 + 47
multi-query local   13.0       1.5 + 3.3                 1.6 + 16

Table 3: Billion-Word LM Benchmark Results.

Attention Type   h   dk, dv   dff    dev-PPL
multi-head       8   128      8192   29.9
multi-query      8   128      9088   30.2
multi-head       1   128      9984   31.2
multi-head       2   64       9984   31.1
multi-head       4   32       9984   31.0
multi-head       8   16       9984   30.9

5 Conclusion
We have proposed multi-query attention - an alternative to multi-head attention with much lower memory-
bandwidth requirements in the incremental setting. We believe that this enables wider adoption of attention-
based sequence models in inference-performance-critical applications.

References
Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to
align and translate, 2014.
Ciprian Chelba, Tomas Mikolov, Mike Schuster, Qi Ge, Thorsten Brants, and Phillipp Koehn. One billion word benchmark for measuring progress in statistical language modeling. CoRR, abs/1312.3005, 2013. URL http://arxiv.org/abs/1312.3005.
Peter J. Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating Wikipedia by summarizing long sequences. In Proceedings of the International Conference on Learning Representations, 2018.
Daniel Povey, Hossein Hadian, Pegah Ghahremani, Ke Li, and Sanjeev Khudanpur. A time-restricted self-attention layer for ASR. In Proceedings of the IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser,
and Illia Polosukhin. Attention is all you need. In NIPS, 2017.
Biao Zhang, Deyi Xiong, and Jinsong Su. Accelerating neural transformer via an average attention network,
2018.
