Transformers (24 Aug)
Transformers
LLMs are built out of transformers
Transformer: a specific kind of network architecture, like a fancier feedforward network, but based on attention
A very approximate timeline
1990 Static Word Embeddings
2003 Neural Language Model
2008 Multi-Task Learning
2015 Attention
2017 Transformer
2018 Contextual Word Embeddings and Pretraining
2019 Prompting
Attention
Transformers
Instead of starting with the big picture, let's consider the embeddings for an individual word from a particular layer.
[Figure: transformer language model. Input tokens ("So long and thanks for ...") pass through an input encoding (token embedding E plus position 1 ... 5), then a stack of transformer blocks producing x1 ... x5, then a language modeling head (U) that maps each position to logits and predicts the next token ("long and thanks for all ...").]
Problem with static embeddings (word2vec)
They are static! The embedding for a word doesn't reflect how its
meaning changes in context.
The chicken didn't cross the road because it was too tired
At this point in the sentence, it's probably referring to either the chicken or the road
Intuition of attention
[Figure: self-attention distribution. The layer k+1 representation of "it" is computed by attending over the layer k representations of the preceding words in "The chicken didn't cross the road because it was too tired".]
Attention definition
Given a sequence of token embeddings x1, ..., xN, attention produces an output ai for each position i by comparing xi to the words that precede it.

Simplified version (Version 1) of attention: ai is a weighted sum of the prior words x1 through xi, weighted by their similarity to the current word xi. For similarity we make use of our old friend the dot product; the result of a dot product is a scalar ranging from −∞ to ∞, and the larger the value, the more similar the vectors being compared:

score(xi, xj) = xi · xj

To make effective use of these scores, we normalize them with a softmax to create a vector of weights αij that indicates the proportional relevance of each input j to the element i that is the current focus of attention:

αij = softmax(score(xi, xj))  ∀ j ≤ i
    = exp(score(xi, xj)) / Σ_{k≤i} exp(score(xi, xk))

Then ai is the sum of the prior inputs, each weighted by its α value:

ai = Σ_{j≤i} αij xj

For example, the first step in computing a3 is to compute three scores: x3 · x1, x3 · x2, and x3 · x3.
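To make this simplified version concrete, here is a minimal NumPy sketch (not from the slides; the function name and sizes are illustrative) that computes ai for every position using dot-product scores over the current and preceding tokens:

```python
import numpy as np

def simplified_attention(X):
    """Version 1 attention: each output a_i is a weighted sum of the
    embeddings x_1..x_i, weighted by softmax of dot-product scores."""
    N, d = X.shape
    A = np.zeros_like(X)
    for i in range(N):
        scores = X[: i + 1] @ X[i]          # score(x_i, x_j) = x_i . x_j for j <= i
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()            # softmax over positions 1..i
        A[i] = weights @ X[: i + 1]         # a_i = sum_j alpha_ij * x_j
    return A

# Example: 5 tokens with 4-dimensional embeddings
X = np.random.randn(5, 4)
print(simplified_attention(X).shape)  # (5, 4)
```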
Intuition of attention: columns corresponding to input tokens

[Figure: the same sentence laid out as columns x1, x2, ..., one per input token; the layer k+1 representation of "it" is built from the layer k columns via the self-attention distribution.]
An Actual Attention Head: slightly more complicated
[Figure: the same attention picture, but now each layer k column provides a key k and a value v, and the current token supplies a query q. The query is compared to the keys to produce the self-attention distribution, which weights the values to form the layer k+1 representation.]
An Actual Attention Head: slightly more complicated
Each input embedding plays three different roles in attention: as the query, the current focus of attention being compared to the preceding inputs; as a key, a preceding input that is compared to the current element to determine a similarity weight; and as a value, a preceding element that gets weighted and summed up to compute the output for the current element.

To capture these three different roles, transformers introduce weight matrices that project each input vector xi into a representation of each role:
• query: W^Q
• key: W^K
• value: W^V

qi = xi W^Q;   ki = xi W^K;   vi = xi W^V

To compute the similarity of the current element xi with some prior element xj, we'll use the dot product between the current element's query vector qi and the preceding element's key vector kj.

And instead of summing up the xj, we'll sum up the vj.

The result of a dot product can be an arbitrarily large positive or negative value, and exponentiating large values can lead to numerical issues and loss of gradients during training. To avoid this, we scale the dot product by a factor related to the size of the embeddings, √dk.
Final equations for one attention head
qi = xi W^Q;   kj = xj W^K;   vj = xj W^V

score(xi, xj) = (qi · kj) / √dk

αij = softmax(score(xi, xj))  ∀ j ≤ i

ai = Σ_{j≤i} αij vj
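As a sanity check on these equations, here is a small NumPy sketch of one attention head (illustrative only; the projection sizes and random initialization are assumptions, not part of the slides):

```python
import numpy as np

def attention_head(X, Wq, Wk, Wv):
    """One attention head: project to q/k/v, score with scaled dot products
    over positions j <= i, softmax, then take the weighted sum of values."""
    N, _ = X.shape
    dk = Wk.shape[1]
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    outputs = []
    for i in range(N):
        scores = (K[: i + 1] @ Q[i]) / np.sqrt(dk)   # score(x_i, x_j) for j <= i
        alpha = np.exp(scores - scores.max())
        alpha /= alpha.sum()                          # softmax
        outputs.append(alpha @ V[: i + 1])            # a_i = sum_j alpha_ij v_j
    return np.stack(outputs)

d, dk, dv = 8, 4, 4
X = np.random.randn(5, d)
Wq, Wk, Wv = np.random.randn(d, dk), np.random.randn(d, dk), np.random.randn(d, dv)
print(attention_head(X, Wq, Wk, Wv).shape)  # (5, 4)
```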
Calculating the value of a3
[Figure: calculating a3. 1. Generate key, query, and value vectors for x1, x2, x3 with W^K, W^Q, W^V. 2. Score q3 against each key. 3. Divide each score by √dk. 4. Turn the scores into α3,j weights via softmax. Output of self-attention: a3, the α-weighted sum of the values.]
Actual Attention: slightly more complicated
• Instead of one attention head, we'll have lots of them!
• Intuition: each head might be attending to the context for different purposes
• Different linguistic relationships or patterns in the context

[Figure: multi-head attention. Each of the heads (Head 1, Head 2, ..., Head 8) has its own W^K, W^V, W^Q and attends differently to the context, producing a [1 × dv] output; the outputs are concatenated into a [1 × h·dv] vector and projected back down with W^O ([h·dv × d]) to give ai of shape [1 × d].]
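A minimal multi-head sketch, assuming h heads with per-head dimensions dk = dv = d/h and an output projection Wo (all names and sizes here are illustrative, not from the slides); the ∀ j ≤ i constraint is implemented with an additive mask:

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, h):
    """Run h heads, concatenate their outputs, project back down with Wo.
    Wq/Wk/Wv are lists of per-head projection matrices."""
    N, d = X.shape
    mask = np.triu(np.full((N, N), -np.inf), k=1)     # block attention to future tokens (j > i)
    head_outputs = []
    for head in range(h):
        Q, K, V = X @ Wq[head], X @ Wk[head], X @ Wv[head]
        dk = Q.shape[1]
        alpha = softmax(Q @ K.T / np.sqrt(dk) + mask)  # [N x N] attention weights per head
        head_outputs.append(alpha @ V)                 # [N x dv]
    return np.concatenate(head_outputs, axis=1) @ Wo   # [N x h*dv] -> [N x d]

d, h = 8, 2
dk = dv = d // h
X = np.random.randn(5, d)
Wq = [np.random.randn(d, dk) for _ in range(h)]
Wk = [np.random.randn(d, dk) for _ in range(h)]
Wv = [np.random.randn(d, dv) for _ in range(h)]
Wo = np.random.randn(h * dv, d)
print(multi_head_attention(X, Wq, Wk, Wv, Wo, h).shape)  # (5, 8)
```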
Transformers
The Transformer Block
Transformers
Reminder: transformer language model
[Figure: the transformer language model again. Input tokens and an input encoding E (plus position) feed a stack of transformer blocks; the language modeling head (U) turns each position's output into logits over the vocabulary and predicts the next token.]
[Figure: one transformer block. The input xi passes through layer norm, multi-head attention, a residual connection (+), a second layer norm, a feedforward layer, and a second residual connection, producing hi.]
We'll need nonlinearities, so each block includes a feedforward layer:

FFN(xi) = ReLU(xi W1 + b1) W2 + b2

The hidden dimensionality dff of the feedforward network is typically larger than the model dimensionality d. (For example, in the original transformer model, d = 512 and dff = 2048.)
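A short sketch of this feedforward layer in NumPy (the dimensions are the example values above; the weights are random placeholders):

```python
import numpy as np

d, d_ff = 512, 2048
W1, b1 = np.random.randn(d, d_ff) * 0.02, np.zeros(d_ff)
W2, b2 = np.random.randn(d_ff, d) * 0.02, np.zeros(d)

def ffn(x):
    """FFN(x) = ReLU(x W1 + b1) W2 + b2, applied independently at each position."""
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

print(ffn(np.random.randn(5, d)).shape)  # (5, 512)
```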
Layer Norm: At two stages in the transformer block we normalize the vector (Ba et al., 2016). This process is called layer norm (short for layer normalization).

[Figure: the transformer block again, with layer norm applied before multi-head attention and before the feedforward layer.]
Layer norm: the vector xi is normalized twice
[Figure: the transformer block, highlighting the two layer norm steps that xi passes through.]
Layer norm is a variation of the z-score from statistics, applied to a single vector in a hidden layer. The input to layer norm is a single vector of dimensionality d, and the output is that vector normalized, again of dimensionality d.

The first step in layer normalization is to calculate the mean μ and standard deviation σ over the elements of the vector to be normalized. Given an embedding vector x of dimensionality d, these values are calculated as follows:

μ = (1/d) Σ_{i=1}^{d} xi

σ = sqrt( (1/d) Σ_{i=1}^{d} (xi − μ)² )

Given these values, the vector components are normalized by subtracting the mean and dividing by the standard deviation. The result of this computation is a new vector with zero mean and a standard deviation of one:

x̂ = (x − μ) / σ

Finally, in the standard implementation of layer normalization, two learnable parameters, g and b, representing gain and offset values, are introduced:

LayerNorm(x) = g ((x − μ) / σ) + b
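A direct translation of these equations into NumPy (a sketch; the small epsilon for numerical stability is my addition, not part of the equations above):

```python
import numpy as np

def layer_norm(x, g, b, eps=1e-5):
    """LayerNorm(x) = g * (x - mu) / sigma + b, over the d components of x."""
    mu = x.mean()                       # mean over the vector's components
    sigma = np.sqrt(((x - mu) ** 2).mean())
    return g * (x - mu) / (sigma + eps) + b

d = 8
x = np.random.randn(d)
g, b = np.ones(d), np.zeros(d)          # learnable gain and offset
out = layer_norm(x, g, b)
print(round(out.mean(), 6), round(out.std(), 6))  # approximately 0 and 1
```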
Putting together a single transformer block

The function computed by a transformer block can be expressed by breaking it down with one equation for each component computation, using t_i (of shape [1 × d]) to stand for transformer and superscripts to demarcate each computation inside the block:

t^1_i = LayerNorm(xi)
t^2_i = MultiHeadAttention(t^1_i, [x^1_1, ..., x^1_N])
t^3_i = t^2_i + xi
t^4_i = LayerNorm(t^3_i)
t^5_i = FFN(t^4_i)
hi    = t^5_i + t^3_i

Notice that the only component that takes as input information from other tokens (other residual streams) is multi-head attention.
A transformer is a stack of these blocks
so all the vectors are of the same dimensionality d
[Figure: two transformer blocks stacked. The outputs hi of Block 1 become the inputs xi of Block 2, all of dimensionality d.]
Residual streams and attention
Notice that all parts of the transformer block apply to a single residual stream (one token).

Except attention, which takes information from other tokens.

Elhage et al. (2021) show that we can view attention heads as literally moving information from the residual stream of a neighboring token into the current stream.

[Figure: two residual streams, one for Token A and one for Token B, with an attention head moving information from one into the other.]
The Transformer Block
Transformers
Parallelizing Attention
Computation
Transformers
Parallelizing computation using X
Pack the N input embeddings into a single matrix X. Given these matrices, we can compute all the requisite query-key comparisons simultaneously by multiplying Q and K^T in a single matrix multiplication. The product QK^T is an [N × N] matrix containing all the qi · kj comparisons.
Parallelizing attention
• Once we have this QK^T matrix, we can very efficiently scale the scores, take the softmax, and then multiply the result by V, resulting in a matrix of shape [N × d]: an attention vector for each input token.
• We've reduced the entire self-attention step for a sequence of N tokens, for one head, to the following computation:

A = softmax( mask( QK^T / √dk ) ) V
Masking out the future
A = softmax( mask( QK^T / √dk ) ) V

• What is this mask function?
• QK^T has a score for each query dotted with every key, including keys for tokens that follow the query.
• That's a problem in the setting of language modeling: guessing the next word is pretty simple if you already know it!
• To fix this, the elements in the upper-triangular portion of the matrix are zeroed out (set to −∞), thus eliminating any knowledge of words that follow in the sequence.
Masking out the future

Add −∞ to the cells in the upper triangle of QK^T; the softmax will turn them into 0:

q1·k1   −∞      −∞      −∞
q2·k1   q2·k2   −∞      −∞
q3·k1   q3·k2   q3·k3   −∞
q4·k1   q4·k2   q4·k3   q4·k4

In practice this is done by adding a mask matrix M in which Mij = −∞ for j > i (the upper-triangular portion) and Mij = 0 otherwise.
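Here is the whole parallelized, masked computation for one head as a NumPy sketch (matrix names follow the slides; the sizes are illustrative):

```python
import numpy as np

def causal_attention(X, Wq, Wk, Wv):
    """A = softmax(mask(Q K^T / sqrt(d_k))) V for one head, over all N tokens at once."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv           # [N x d_k], [N x d_k], [N x d_v]
    N, dk = Q.shape
    scores = Q @ K.T / np.sqrt(dk)             # [N x N]: all q_i . k_j comparisons
    mask = np.triu(np.full((N, N), -np.inf), k=1)
    scores = scores + mask                     # M_ij = -inf for j > i, 0 otherwise
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax; future weights become 0
    return weights @ V                         # [N x d_v]

d, dk, dv, N = 8, 4, 4, 5
X = np.random.randn(N, d)
out = causal_attention(X, np.random.randn(d, dk), np.random.randn(d, dk), np.random.randn(d, dv))
print(out.shape)  # (5, 4)
```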
Another point: Attention is quadratic in length

The QK^T matrix has an entry for every pair of positions, so it is [N × N]: the cost of self-attention grows quadratically with the length N of the input sequence.
[Figure: computing Q, K, and V in parallel. Each projection multiplies the input matrix X ([N × d], one row per input token) by a weight matrix: Q = X W^Q and K = X W^K with W^Q, W^K of shape [d × dk], giving Q and K of shape [N × dk]; V = X W^V with W^V of shape [d × dv], giving V of shape [N × dv].]

[Figure: the full parallel computation for one head. Q K^T gives an [N × N] score matrix of all qi · kj entries; the mask sets the upper-triangular entries (j > i) to −∞, which the row-wise softmax turns into 0; multiplying by V ([N × dv]) gives the outputs a1 ... aN, i.e. A of shape [N × dv].]
Parallelizing Multi-head Attention

Putting it all together with the parallel input matrix X: the function computed in parallel by an entire transformer block layer over the entire set of N input tokens can be expressed as:

O = LayerNorm(X + MultiHeadAttention(X))
H = LayerNorm(O + FFN(O))

Or we can break it down with one equation for each component computation, using T (of shape [N × d]) to stand for transformer and superscripts to demarcate each computation inside the block:

T^1 = MultiHeadAttention(X)
T^2 = X + T^1
T^3 = LayerNorm(T^2)
T^4 = FFN(T^3)
T^5 = T^4 + T^3
H   = LayerNorm(T^5)
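A self-contained NumPy sketch of this matrix-form block, assuming a single attention head for brevity (the real block uses multi-head attention) and random placeholder weights:

```python
import numpy as np

d, d_ff, N = 8, 32, 5
rng = np.random.default_rng(0)
Wq, Wk, Wv = rng.normal(size=(d, d)), rng.normal(size=(d, d)), rng.normal(size=(d, d))
W1, b1 = rng.normal(size=(d, d_ff)) * 0.02, np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d)) * 0.02, np.zeros(d)
g, b = np.ones(d), np.zeros(d)          # layer-norm gain and offset

def layer_norm(X):
    mu = X.mean(axis=-1, keepdims=True)
    sigma = X.std(axis=-1, keepdims=True)
    return g * (X - mu) / (sigma + 1e-5) + b

def attention(X):                        # single-head stand-in for MultiHeadAttention(X)
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(d) + np.triu(np.full((N, N), -np.inf), k=1)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def ffn(X):
    return np.maximum(0, X @ W1 + b1) @ W2 + b2

def transformer_block(X):
    O = layer_norm(X + attention(X))     # O = LayerNorm(X + MultiHeadAttention(X))
    return layer_norm(O + ffn(O))        # H = LayerNorm(O + FFN(O))

print(transformer_block(rng.normal(size=(N, d))).shape)  # (5, 8)
```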
Parallelizing Attention
Computation
Transformers
Input and output: Position
embeddings and the Language
Model Head
Transformers
Token and Position Embeddings
The input X to the transformer block is a matrix of composite embeddings: for each input token, its word embedding plus a position embedding.

[Figure: for the input "Janet will back the bill", the word embeddings for Janet, will, back, the, bill are added (+) to the position embeddings for positions 1 through 5, forming the composite embeddings X that feed the transformer block.]
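A sketch of this composite input encoding, assuming learned position embeddings stored in a lookup table (the vocabulary size, dimensions, and token ids are made up for the example):

```python
import numpy as np

vocab_size, max_len, d = 10000, 512, 8
rng = np.random.default_rng(0)
E = rng.normal(size=(vocab_size, d))      # word (token) embedding table
P = rng.normal(size=(max_len, d))         # position embedding table

def input_encoding(token_ids):
    """X = word embedding + position embedding, one row per input token."""
    positions = np.arange(len(token_ids))
    return E[token_ids] + P[positions]

token_ids = [42, 7, 99, 3, 250]           # stand-ins for "Janet will back the bill"
print(input_encoding(token_ids).shape)    # (5, 8)
```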
Language modeling head
Unembedding layer: a linear layer that projects from h^L_N (shape [1 × d]), the output of the final transformer block at the last token position, to the logit vector u (shape [1 × |V|]), one score per vocabulary word.

The embedding matrix E has to be good at doing both mappings, from words to embeddings and back. We therefore sometimes call the transpose E^T (shape [d × |V|]) the unembedding layer, because it is performing the reverse mapping.

A softmax layer turns the logits u into probabilities y over the vocabulary (shape [1 × |V|]):

u = h^L_N E^T
y = softmax(u)
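A sketch of the language modeling head with the tied (unembedding) weights E^T; all the numbers and the greedy choice at the end are illustrative:

```python
import numpy as np

vocab_size, d = 10000, 8
rng = np.random.default_rng(0)
E = rng.normal(size=(vocab_size, d))      # shared token embedding matrix
h_last = rng.normal(size=(1, d))          # h_N^L: final-layer output at the last position

u = h_last @ E.T                          # logits, shape [1 x |V|]
y = np.exp(u - u.max())
y /= y.sum()                              # softmax: probabilities over the vocabulary

next_token = int(y.argmax())              # greedy choice of the next token id
print(u.shape, y.sum(), next_token)
```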
The final transformer model

[Figure: the complete model. The input token wi is mapped to its input encoding (embedding E plus position i), passes through L layers of attention, layer norm, and feedforward, producing h^L_i; the language modeling head turns h^L_i into logits u1 ... u|V|, a softmax gives token probabilities, and the token to generate at position i+1 is sampled from them.]
Input and output: Position
embeddings and the Language
Model Head
Transformers