3 Coding attention mechanisms


This chapter covers


The reasons for using attention mechanisms in neural networks
A basic self-attention framework, progressing to an enhanced self-attention mechanism
A causal attention module that allows LLMs to generate one token at a time
Masking randomly selected attention weights with dropout to reduce overfitting
Stacking multiple causal attention modules into a multi-head attention module

In the previous chapter, you learned how to prepare the input text for
training LLMs. This involved splitting text into individual word and
subword tokens, which can be encoded into vector representations,
the so-called embeddings, for the LLM.

In this chapter, we will look at an integral part of the LLM architecture itself, attention mechanisms, as illustrated in figure 3.1.

Figure 3.1 A mental model of the three main stages of coding an LLM, pretraining the LLM on a general text dataset, and finetuning it on a labeled dataset. This chapter focuses on attention mechanisms, which are an integral part of an LLM architecture.
Attention mechanisms are a comprehensive topic, which is why we
are devoting a whole chapter to it. We will largely look at these
attention mechanisms in isolation and focus on them at a
mechanistic level. In the next chapter, we will then code the
remaining parts of the LLM surrounding the self-attention
mechanism to see it in action and to create a model to generate text.

Over the course of this chapter, we will implement four different variants of attention mechanisms, as illustrated in figure 3.2.

Figure 3.2 The figure depicts different attention mechanisms we will code in this chapter, starting with a simplified version of self-attention before adding the trainable weights. The causal attention mechanism adds a mask to self-attention that allows the LLM to generate one word at a time. Finally, multi-head attention organizes the attention mechanism into multiple heads, allowing the model to capture various aspects of the input data in parallel.
These different attention variants shown in figure 3.2 build on each other, and the goal is to arrive at a compact and efficient implementation of multi-head attention at the end of this chapter that we can then plug into the LLM architecture we will code in the next chapter.


3.1 The problem with modeling long sequences
Before we dive into the self-attention mechanism at the heart of LLMs
later in this chapter, what is the problem with architectures without
attention mechanisms that predate LLMs? Suppose we want to
develop a language translation model that translates text from one
language into another. As shown in figure 3.3, we can’t simply
translate a text word by word due to the grammatical structures in
the source and target language.

Figure 3.3 When translating text from one language to another, such as German to English, it’s not possible to merely translate word by word. Instead, the translation process requires contextual understanding and grammar alignment.

To address the problem that we cannot translate text word by word, it is common to use a deep neural network with two submodules, a so-called encoder and decoder. The job of the encoder is to first read in and process the entire text, and the decoder then produces the translated text.

We already briefly discussed encoder-decoder networks when we introduced the transformer architecture in chapter 1 (section 1.4). Before the advent of transformers, recurrent neural networks (RNNs) were the most popular encoder-decoder architecture for language translation.

An RNN is a type of neural network where outputs from previous steps are fed as inputs to the current step, making them well-suited for sequential data like text. If you are unfamiliar with RNNs, don't worry; you don't need to know the detailed workings of RNNs to follow this discussion. Our focus here is more on the general concept of the encoder-decoder setup.

In an encoder-decoder RNN, the input text is fed into the encoder, which processes it sequentially. The encoder updates its hidden state (the internal values at the hidden layers) at each step, trying to capture the entire meaning of the input sentence in the final hidden state, as illustrated in figure 3.4. The decoder then takes this final hidden state to start generating the translated sentence, one word at a time. It also updates its hidden state at each step, which is supposed to carry the context necessary for the next-word prediction.

Figure 3.4 Before the advent of transformer models, encoder-decoder RNNs were a popular choice for machine translation. The encoder takes a sequence of tokens from the source language as input, where a hidden state (an intermediate neural network layer) of the encoder encodes a compressed representation of the entire input sequence. Then, the decoder uses its current hidden state to begin the translation, token by token.

While we don't need to know the inner workings of these encoder-decoder RNNs, the key idea here is that the encoder part processes the entire input text into a hidden state (memory cell). The decoder then takes in this hidden state to produce the output. You can think of this hidden state as an embedding vector, a concept we discussed in chapter 2.

The big limitation of encoder-decoder RNNs is that the RNN can't directly access earlier hidden states from the encoder during the decoding phase. Consequently, it relies solely on the current hidden state, which encapsulates all relevant information. This can lead to a loss of context, especially in complex sentences where dependencies might span long distances.

For readers unfamiliar with RNNs, it is not essential to understand or study this architecture, as we will not be using it in this book. The takeaway message of this section is that encoder-decoder RNNs had a shortcoming that motivated the design of attention mechanisms.


3.2 Capturing data dependencies with attention mechanisms

Before transformer LLMs, it was common to use RNNs for language modeling tasks such as language translation, as mentioned previously. RNNs work fine for translating short sentences but don't work well for longer texts, as they don't have direct access to previous words in the input.

One major shortcoming in this approach is that the RNN must remember the entire encoded input in a single hidden state before passing it to the decoder, as illustrated in figure 3.4 in the previous section.

Hence, researchers developed the so-called Bahdanau attention mechanism for RNNs in 2014 (named after the first author of the respective paper), which modifies the encoder-decoder RNN such that the decoder can selectively access different parts of the input sequence at each decoding step, as illustrated in figure 3.5.

Figure 3.5 Using an attention mechanism, the text-generating decoder part of the network can access all input tokens selectively. This means that some input tokens are more important than others for generating a given output token. The importance is determined by the so-called attention weights, which we will compute later. Note that this figure shows the general idea behind attention and does not depict the exact implementation of the Bahdanau mechanism, which is an RNN method outside this book’s scope.

Interestingly, only three years later, researchers found that RNN architectures are not required for building deep neural networks for natural language processing and proposed the original transformer architecture (discussed in chapter 1) with a self-attention mechanism inspired by the Bahdanau attention mechanism.

Self-attention is a mechanism that allows each position in the input sequence to attend to all positions in the same sequence when computing the representation of a sequence. Self-attention is a key component of contemporary LLMs based on the transformer architecture, such as the GPT series.

This chapter focuses on coding and understanding this self-attention mechanism used in GPT-like models, as illustrated in figure 3.6. In the next chapter, we will code the remaining parts of the LLM.

Figure 3.6 Self-attention is a mechanism in transformers that is used to compute more efficient input representations by allowing each position in a sequence to interact with and weigh the importance of all other positions within the same sequence. In this chapter, we will code this self-attention mechanism from the ground up before we code the remaining parts of the GPT-like LLM in the following chapter.

3.3 Attending to different parts of the input with self-attention

We'll now cover the inner workings of the self-attention mechanism and learn how to code it from the ground up. Self-attention serves as the cornerstone of every LLM based on the transformer architecture. It's worth noting that this topic may require a lot of focus and attention (no pun intended), but once you grasp its fundamentals, you will have conquered one of the toughest aspects of this book and of implementing LLMs in general.

The “self” in self-attention

In self-attention, the “self” refers to the mechanism's ability to compute attention weights by relating different positions within a single input sequence. It assesses and learns the relationships and dependencies between various parts of the input itself, such as words in a sentence or pixels in an image. This is in contrast to traditional attention mechanisms, where the focus is on the relationships between elements of two different sequences, such as in sequence-to-sequence models where the attention might be between an input sequence and an output sequence, such as the example depicted in figure 3.5.

Since self-attention can appear complex, especially if you are encountering it for the first time, we will begin by introducing a simplified version of it in the next subsections. Afterward, in section 3.4, we will implement the self-attention mechanism with trainable weights, which is used in LLMs.

3.3.1 A simple self-attention mechanism without trainable weights

In this section, we implement a simplified variant of self-attention, free from any trainable weights, summarized in figure 3.7. The goal of this section is to illustrate a few key concepts in self-attention before adding trainable weights next in section 3.4.

Figure 3.7 The goal of self-attention is to compute a context vector for each input element that combines information from all other input elements. In the example depicted in this figure, we compute the context vector z(2). The importance or contribution of each input element for computing z(2) is determined by the attention weights α21 to α2T. When computing z(2), the attention weights are calculated with respect to input element x(2) and all other inputs. The exact computation of these attention weights is discussed later in this section.

Figure 3.7 shows an input sequence, denoted as x, consisting of T elements represented as x(1) to x(T). This sequence typically represents text, such as a sentence, that has already been transformed into token embeddings, as explained in chapter 2.

For example, consider an input text like “Your journey starts with one step.” In this case, each element of the sequence, such as x(1), corresponds to a d-dimensional embedding vector representing a specific token, like “Your.” In figure 3.7, these input vectors are shown as three-dimensional embeddings.

In self-attention, our goal is to calculate context vectors z(i) for each element x(i) in the input sequence. A context vector can be interpreted as an enriched embedding vector.

To illustrate this concept, let's focus on the embedding vector of the second input element, x(2) (which corresponds to the token “journey”), and the corresponding context vector, z(2), shown at the bottom of figure 3.7. This enhanced context vector, z(2), is an embedding that contains information about x(2) and all other input elements x(1) to x(T).

In self-attention, context vectors play a crucial role. Their purpose is to create enriched representations of each element in an input sequence (like a sentence) by incorporating information from all other elements in the sequence, as illustrated in figure 3.7. This is essential in LLMs, which need to understand the relationship and relevance of words in a sentence to each other. Later, we will add trainable weights that help an LLM learn to construct these context vectors so that they are relevant for the LLM to generate the next token. In this section, we implement a simplified self-attention mechanism to compute these weights and the resulting context vector one step at a time.

Consider the following input sentence, which has already been embedded into three-dimensional vectors, as discussed in chapter 2. We choose a small embedding dimension for illustration purposes to ensure it fits on the page without line breaks:

import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x^6)
)

The first step of implementing self-attention is to compute the intermediate values ω, referred to as attention scores, as illustrated in figure 3.8. (Please note that figure 3.8 displays the values of the preceding inputs tensor in a truncated version; for example, 0.87 is truncated to 0.8 due to spatial constraints. In this truncated version, the embeddings of the words “journey” and “starts” may appear similar by random chance.)

Figure 3.8 The overall goal of this section is to illustrate the computation of the context vector z(2) using the second input element, x(2), as a query. This figure shows the first intermediate step, computing the attention scores ω between the query x(2) and all other input elements as a dot product. (Note that the numbers in the figure are truncated to one digit after the decimal point to reduce visual clutter.)

Figure 3.8 illustrates how we calculate the intermediate attention scores between the query token and each input token. We determine these scores by computing the dot product of the query, x(2), with every other input token:

query = inputs[1] #A
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

The computed attention scores are as follows:

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


Understanding dot products 


A dot product is essentially just a concise way of multiplying two vectors element-wise and then summing the products, which we can demonstrate as follows:

res = 0.
for idx, element in enumerate(inputs[0]):
    res += inputs[0][idx] * query[idx]
print(res)
print(torch.dot(inputs[0], query))

The output confirms that the sum of the element-wise multiplications gives the same result as the dot product:

tensor(0.9544)
tensor(0.9544)

Beyond viewing the dot product operation as a mathematical tool that combines two vectors to yield a scalar value, the dot product is a measure of similarity because it quantifies how much two vectors are aligned: a higher dot product indicates a greater degree of alignment or similarity between the vectors. In the context of self-attention mechanisms, the dot product determines the extent to which elements in a sequence attend to each other: the higher the dot product, the higher the similarity and attention score between two elements.
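For a quick numerical intuition, the following small sketch (not part of the book's code, using made-up vectors a, b, and c) shows that roughly aligned vectors produce a larger dot product than vectors pointing in different directions, and that dividing by the vector norms turns the dot product into the familiar cosine similarity:

a = torch.tensor([1.0, 0.5])
b = torch.tensor([0.9, 0.6])   # roughly aligned with a
c = torch.tensor([-1.0, 0.2])  # points in a very different direction
print(torch.dot(a, b), torch.dot(a, c))           # large positive score vs. negative score
print(torch.dot(a, b) / (a.norm() * b.norm()))    # normalized similarity, close to 1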

In the next step, as shown in figure 3.9, we normalize each of the attention scores that we computed previously.

Figure 3.9 After computing the attention scores ω21 to ω2T with
respect to the input query x(2), the next step is to obtain the
attention weights α21 to α2T by normalizing the attention scores.

The main goal behind the normalization shown in figure 3.9 is to obtain attention weights that sum up to 1. This normalization is a convention that is useful for interpretation and for maintaining training stability in an LLM. Here's a straightforward method for achieving this normalization step:

attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

As the output shows, the attention weights now sum to 1:

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)

In practice, it's more common and advisable to use the softmax function for normalization. This approach is better at managing extreme values and offers more favorable gradient properties during training. The following is a basic implementation of the softmax function for normalizing the attention scores:

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

As the output shows, the softmax function also meets the objective and normalizes the attention weights such that they sum to 1:

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)

In addition, the softmax function ensures that the attention weights are always positive. This makes the output interpretable as probabilities or relative importance, where higher weights indicate greater importance.

Note that this naive softmax implementation ( softmax_naive ) may encounter numerical instability problems, such as overflow and underflow, when dealing with large or small input values. Therefore, in practice, it's advisable to use the PyTorch implementation of softmax, which has been extensively optimized for performance:

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

In this case, we can see that it yields the same results as our previous softmax_naive function:

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
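To see why the naive version can break down and how optimized implementations avoid it, here is a short sketch (not from the book) of the standard trick of subtracting the maximum value before exponentiating, which leaves the result mathematically unchanged because the constant cancels out during normalization:

def softmax_stable(x):
    # Shifting by the maximum keeps torch.exp() from overflowing for large inputs.
    x_shifted = x - x.max()
    return torch.exp(x_shifted) / torch.exp(x_shifted).sum(dim=0)

print(softmax_stable(torch.tensor([1000.0, 1001.0])))  # tensor([0.2689, 0.7311])
print(softmax_naive(torch.tensor([1000.0, 1001.0])))   # tensor([nan, nan]) because exp(1000) overflows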

Now that we have computed the normalized attention weights, we are ready for the final step illustrated in figure 3.10: calculating the context vector z(2) by multiplying the embedded input tokens, x(i), with the corresponding attention weights and then summing the resulting vectors.

Figure 3.10 The final step, after calculating and normalizing the
attention scores to obtain the attention weights for query x(2), is
to compute the context vector z(2). This context vector is a
combination of all input vectors x(1) to x(T) weighted by the
attention weights.
The context vector z(2) depicted in figure 3.10 is calculated as a weighted sum of all input vectors. This involves multiplying each input vector by its corresponding attention weight:

query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

The results of this computation are as follows:

tensor([0.4419, 0.6515, 0.5683])


In the next section, we will generalize this procedure for computing context vectors to calculate all context vectors simultaneously.

3.3.2 Computing attention weights for all input tokens

In the previous section, we computed attention weights and the context vector for input 2, as shown in the highlighted row in figure 3.11. Now we are extending this computation to calculate attention weights and context vectors for all inputs.

Figure 3.11 The highlighted row shows the attention weights for
the second input element as a query, as we computed in the
previous section. This section generalizes the computation to
obtain all other attention weights.
We follow the same three steps as before, as summarized in figure 3.12, except that we make a few modifications in the code to compute all context vectors instead of only the second context vector, z(2).

First, in step 1, as illustrated in figure 3.12, we add an additional for-loop to compute the dot products for all pairs of inputs.

attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)
The resulting attention scores are as follows:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

Each element in the preceding tensor represents an attention score between a pair of inputs, as illustrated in figure 3.11. Note that the values in figure 3.11 are normalized, which is why they differ from the unnormalized attention scores in the preceding tensor. We will take care of the normalization later.

When computing the preceding attention score tensor, we used for-loops in Python. However, for-loops are generally slow, and we can achieve the same results using matrix multiplication:

attn_scores = inputs @ inputs.T
print(attn_scores)

We can visually confirm that the results are the same as before:

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In step 2, as illustrated in figure 3.12, we now normalize each row so that the values in each row sum to 1:

attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

This returns the following attention weight tensor that matches the values shown in figure 3.10:

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

Jn xrp tocxnet vl suign VuAesgt, kqr mbj aatmeprre nj usofctnin fjoo


tcroh.somxaft ieepsifcs rpv donnesmii le xyr input ostrne olagn
wihch rkd notncifu jffw pk dmpuceto. Cd tneitgs gm=j-1, wx ctk
sunctgtriin drx maftsox nncouift kr lppya krp iamntioroanlz aolng
rdv rzzf mnidsoeni kl roy cesrsa_otnt toensr. Jl cotetsrns_a jz s wre-
enlasdoimin tsrnoe (lte leexamp, jwpr z epahs lk [wxat, mnlcuos]),
jy=m-1 fjwf nzemoalir ssarco grx muocsnl ax zrry rvy leavus nj xyss
kwt (immungs vevt vrq cnlumo nodseimni) may hg vr 1.
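As a quick, self-contained illustration (not from the book), applying softmax with dim=-1 to a small two-dimensional tensor normalizes each row independently:

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])
weights = torch.softmax(scores, dim=-1)  # normalize along the last dimension (columns)
print(weights)                           # each row forms its own distribution
print(weights.sum(dim=-1))               # tensor([1., 1.])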

Before we move on to step 3, the final step shown in figure 3.12, let's briefly verify that the rows indeed all sum to 1:
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=-1))


The result is as follows:

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In the third and last step, we now use these attention weights to compute all context vectors via matrix multiplication:

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

In the resulting output tensor, each row contains a three-dimensional context vector:

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
We can double-check that the code is correct by comparing the second row with the context vector z(2) that we computed previously in section 3.3.1:

print("Previous 2nd context vector:", context_vec_2)


Based on the result, we can see that the previously calculated context_vec_2 matches the second row in the previous tensor exactly:

Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])


This concludes the code walkthrough of a simple self-attention mechanism. In the next section, we will add trainable weights, enabling the LLM to learn from data and improve its performance on specific tasks.


3.4 Implementing self-attention with trainable weights

In this section, we are implementing the self-attention mechanism that is used in the original transformer architecture, the GPT models, and most other popular LLMs. This self-attention mechanism is also called scaled dot-product attention. Figure 3.13 provides a mental model illustrating how this self-attention mechanism fits into the broader context of implementing an LLM.

Figure 3.13 A mental model illustrating how the self-attention mechanism we code in this section fits into the broader context of this book and chapter. In the previous section, we coded a simplified attention mechanism to understand the basic mechanism behind attention mechanisms. In this section, we add trainable weights to this attention mechanism. In the upcoming sections, we will extend this self-attention mechanism by adding a causal mask and multiple heads.

As illustrated in figure 3.13, the self-attention mechanism with trainable weights builds on the previous concepts: we want to compute context vectors as weighted sums over the input vectors specific to a certain input element. As you will see, there are only slight differences compared to the basic self-attention mechanism we coded earlier in section 3.3.

The most notable difference is the introduction of weight matrices that are updated during model training. These trainable weight matrices are crucial so that the model (specifically, the attention module inside the model) can learn to produce “good” context vectors. (Note that we will train the LLM in chapter 5.)

We will tackle this self-attention mechanism in two subsections. First, we will code it step by step as before. Second, we will organize the code into a compact Python class that can be imported into an LLM architecture, which we will code in chapter 4.

3.4.1 Computing the attention weights step by step


We will implement the self-attention mechanism step by step by introducing the three trainable weight matrices Wq, Wk, and Wv. These three matrices are used to project the embedded input tokens, x(i), into query, key, and value vectors, as illustrated in figure 3.14.

Figure 3.14 In the first step of the self-attention mechanism with trainable weight matrices, we compute query (q), key (k), and value (v) vectors for input elements x. Similar to previous sections, we designate the second input, x(2), as the query input. The query vector q(2) is obtained via matrix multiplication between the input x(2) and the weight matrix Wq. Similarly, we obtain the key and value vectors via matrix multiplication involving the weight matrices Wk and Wv.

Earlier in section 3.3.1, we defined the second input element x(2) as the query when we computed the simplified attention weights to compute the context vector z(2). Later, in section 3.3.2, we generalized this to compute all context vectors z(1) ... z(T) for the six-word input sentence “Your journey starts with one step.”

Similarly, we will start by computing only one context vector, z(2), for illustration purposes. In the next section, we will modify this code to calculate all context vectors.

Let’s begin by defining a few variables:

x_2 = inputs[1] #A
d_in = inputs.shape[1] #B
d_out = 2 #C


Note that in GPT-like models, the input and output dimensions are usually the same, but for illustration purposes, to better follow the computation, we choose different input ( d_in=3 ) and output ( d_out=2 ) dimensions here.

Next, we initialize the three weight matrices Wq, Wk, and Wv that are shown in figure 3.14:

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

Note that we are setting requires_grad=False to reduce clutter in the outputs for illustration purposes, but if we were to use the weight matrices for model training, we would set requires_grad=True to update these matrices during model training.

Next, we compute the query, key, and value vectors as shown earlier in figure 3.14:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)


As we can see based on the output for the query, this results in a two-dimensional vector since we set the number of columns of the corresponding weight matrix, via d_out , to 2:

tensor([0.4306, 1.4551])


Weight parameters vs. attention weights

Note that in the weight matrices W, the term “weight” is short for “weight parameters,” the values of a neural network that are optimized during training. This is not to be confused with the attention weights. As we already saw in the previous section, attention weights determine the extent to which a context vector depends on the different parts of the input, i.e., to what extent the network focuses on different parts of the input.

In summary, weight parameters are the fundamental, learned coefficients that define the network's connections, while attention weights are dynamic, context-specific values.

Even though our temporary goal is only to compute the one context vector, z(2), we still require the key and value vectors for all input elements, as they are involved in computing the attention weights with respect to the query q(2), as illustrated in figure 3.14.
We can obtain all keys and values via matrix multiplication:

keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

As we can tell from the outputs, we successfully projected the six input tokens from a three-dimensional onto a two-dimensional embedding space:

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])

The second step is to compute the attention scores, as shown in figure 3.15.

Figure 3.15 The attention score computation is a dot-product computation similar to what we used in the simplified self-attention mechanism in section 3.3. The new aspect here is that we are not directly computing the dot-product between the input elements but using the query and key obtained by transforming the inputs via the respective weight matrices.
First, let’s compute the attention score ω22:

keys_2 = keys[1] #A
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)


This results in the following unnormalized attention score:

tensor(1.8524)


Again, we can generalize this computation to all attention scores via matrix multiplication:

attn_scores_2 = query_2 @ keys.T # All attention scores for the given query
print(attn_scores_2)

As a quick check, we can see that the second element in the output matches attn_score_22 we computed previously:

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

The third step is now going from the attention scores to the attention weights, as illustrated in figure 3.16.

Figure 3.16 After computing the attention scores ω, the next step
is to normalize these scores using the softmax function to obtain
the attention weights α.

Next, as illustrated in figure 3.16, we compute the attention weights by scaling the attention scores and using the softmax function we used earlier. The difference from earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys (note that taking the square root is mathematically the same as exponentiating by 0.5):
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)


The resulting attention weights are as follows:

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


The rationale behind scaled dot-product attention

The reason for the normalization by the embedding dimension size is to improve the training performance by avoiding small gradients. For instance, when scaling up the embedding dimension, which is typically greater than 1,000 for GPT-like LLMs, large dot products can result in very small gradients during backpropagation due to the softmax function applied to them. As dot products increase, the softmax function behaves more like a step function, resulting in gradients nearing zero. These small gradients can drastically slow down learning or cause training to stagnate.

The scaling by the square root of the embedding dimension is the reason why this self-attention mechanism is also called scaled dot-product attention.
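To make this concrete, here is a small illustration (not from the book, using arbitrary example scores and a hypothetical key dimension of 64) of how scaling keeps the softmax output from collapsing into a near one-hot distribution when the dot products are large:

raw_scores = torch.tensor([6.0, 12.0, 18.0])        # large, unscaled dot products
d_k_demo = 64                                       # hypothetical key dimension for this example
print(torch.softmax(raw_scores, dim=0))             # ~[0.0000, 0.0025, 0.9975], nearly one-hot
print(torch.softmax(raw_scores / d_k_demo**0.5, dim=0))  # ~[0.13, 0.28, 0.59], a much smoother distribution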

Now, the final step is to compute the context vectors, as illustrated in figure 3.17.

Figure 3.17 In the final step of the self-attention computation, we compute the context vector by combining all value vectors via the attention weights.

Similar to section 3.3, where we computed the context vector as a weighted sum over the input vectors, we now compute the context vector as a weighted sum over the value vectors. Here, the attention weights serve as a weighting factor that weighs the respective importance of each value vector. Also similar to section 3.3, we can use matrix multiplication to obtain the output in one step:

context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

The contents of the resulting vector are as follows:

tensor([0.3061, 0.8210])

So far, we only computed a single context vector, z(2). In the next section, we will generalize the code to compute all context vectors in the input sequence, z(1) to z(T).

Why query, key, and value?

The terms “key,” “query,” and “value” in the context of attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.

A query is analogous to a search query in a database. It represents the current item (e.g., a word or token in a sentence) the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.

The key is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query.

The value in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.

3.4.2 Implementing a compact self-attention Python class

In the previous sections, we have gone through a lot of steps to compute the self-attention outputs. This was mainly done for illustration purposes so we could go through one step at a time. In practice, with the LLM implementation in the next chapter in mind, it is helpful to organize this code into a Python class, as shown in the following listing.
Listing 3.1 A compact self-attention class
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In this PyTorch code, SelfAttention_v1 is a class derived from nn.Module , which is a fundamental building block of PyTorch models that provides the necessary functionality for model layer creation and management.

The __init__ method initializes trainable weight matrices ( W_query , W_key , and W_value ) for queries, keys, and values, each transforming the input dimension d_in to an output dimension d_out .

During the forward pass, using the forward method, we compute the attention scores ( attn_scores ) by multiplying queries and keys, normalizing these scores using softmax. Finally, we create a context vector by weighting the values with these normalized attention scores.

We can use this class as follows:

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

Since inputs contains six embedding vectors, this results in a matrix storing the six context vectors:

tensor([[0.2996, 0.8053],
[0.3061, 0.8210],
[0.3058, 0.8203],
[0.2948, 0.7939],
[0.2927, 0.7891],
[0.2990, 0.8040]], grad_fn=<MmBackward0>)


As a quick check, notice how the second row ( [0.3061, 0.8210] ) matches the contents of context_vec_2 in the previous section. Figure 3.18 summarizes the self-attention mechanism we just implemented.

Figure 3.18 In self-attention, we transform the input vectors in the input matrix X with the three weight matrices, Wq, Wk, and Wv. We then compute the attention weight matrix based on the resulting queries (Q) and keys (K). Using the attention weights and values (V), we then compute the context vectors (Z). (For visual clarity, we focus on a single input text with n tokens in this figure, not a batch of multiple inputs. Consequently, the three-dimensional input tensor is simplified to a two-dimensional matrix in this context. This approach allows for a more straightforward visualization and understanding of the processes involved. Also, for consistency with later figures, the values in the attention matrix do not depict the real attention weights.)
As shown in figure 3.18, self-attention involves the trainable weight matrices Wq, Wk, and Wv. These matrices transform input data into queries, keys, and values, which are crucial components of the attention mechanism. As the model is exposed to more data during training, it adjusts these trainable weights, as we will see in upcoming chapters.

We can improve the SelfAttention_v1 implementation further by utilizing PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled. Additionally, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.

Listing 3.2 A self-attention class using PyTorch's Linear layers

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

You can use SelfAttention_v2 similarly to SelfAttention_v1 :

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))


The output is

tensor([[-0.0739, 0.0713],
[-0.0748, 0.0703],
[-0.0749, 0.0702],
[-0.0760, 0.0685],
[-0.0763, 0.0679],
[-0.0754, 0.0693]], grad_fn=<MmBackward0>)


Note that SelfAttention_v1 and SelfAttention_v2 give different outputs because they use different initial weights for the weight matrices, since nn.Linear uses a more sophisticated weight initialization scheme.

Exercise 3.1 Comparing SelfAttention_v1 and SelfAttention_v2

Note that nn.Linear in SelfAttention_v2 uses a different weight initialization scheme than nn.Parameter(torch.rand(d_in, d_out)) used in SelfAttention_v1 , which causes both mechanisms to produce different results. To check that both implementations, SelfAttention_v1 and SelfAttention_v2 , are otherwise similar, we can transfer the weight matrices from a SelfAttention_v2 object to a SelfAttention_v1 , such that both objects then produce the same results.

Your task is to correctly assign the weights from an instance of SelfAttention_v2 to an instance of SelfAttention_v1 . To do this, you need to understand the relationship between the weights in both versions. (Hint: nn.Linear stores the weight matrix in a transposed form.) After the assignment, you should observe that both instances produce the same outputs.
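One possible way to perform this transfer (a sketch, not necessarily the book's official exercise solution) is to copy each nn.Linear weight matrix, transposed, into the corresponding nn.Parameter :

# nn.Linear stores its weight with shape (d_out, d_in), whereas
# SelfAttention_v1 computes x @ W with W of shape (d_in, d_out),
# so we transpose when copying.
sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)
sa_v1.W_key   = torch.nn.Parameter(sa_v2.W_key.weight.T)
sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)
print(torch.allclose(sa_v1(inputs), sa_v2(inputs)))  # expected: True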

In the next sections, we will make enhancements to the self-attention mechanism, focusing specifically on incorporating causal and multi-head elements. The causal aspect involves modifying the attention mechanism to prevent the model from accessing future information in the sequence, which is crucial for tasks like language modeling, where each word prediction should only depend on previous words.

The multi-head component involves splitting the attention mechanism into multiple “heads.” Each head learns different aspects of the data, allowing the model to simultaneously attend to information from different representation subspaces at different positions. This improves the model's performance in complex tasks.

3.5 Hiding future words with causal attention

In this section, we modify the standard self-attention mechanism to create a causal attention mechanism, which is essential for developing an LLM in the subsequent chapters.

Causal attention, also known as masked attention, is a specialized form of self-attention. It restricts a model to only consider previous and current inputs in a sequence when processing any given token. This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.

Consequently, when computing attention scores, the causal attention mechanism ensures that the model only factors in tokens that occur at or before the current token in the sequence.

To achieve this in GPT-like LLMs, for each token processed, we mask out the future tokens, which come after the current token in the input text, as illustrated in figure 3.19.

Figure 3.19 In causal attention, we mask out the attention weights above the diagonal such that for a given input, the LLM can’t access future tokens when computing the context vectors using the attention weights. For example, for the word “journey” in the second row, we only keep the attention weights for the words before (“Your”) and in the current position (“journey”).

As illustrated in figure 3.19, we mask out the attention weights above the diagonal, and we normalize the non-masked attention weights such that the attention weights sum to 1 in each row. In the next section, we will implement this masking and normalization procedure in code.
3.5.1 Applying a causal attention mask 

In this section, we implement the causal attention mask in code. We start with the procedure summarized in figure 3.20.

Figure 3.20 One way to obtain the masked attention weight matrix in causal attention is to apply the softmax function to the attention scores, zeroing out the elements above the diagonal and normalizing the resulting matrix.

To implement the steps for applying a causal attention mask to obtain the masked attention weights, as summarized in figure 3.20, let's work with the attention scores and weights from the previous section to code the causal attention mechanism.

In the first step illustrated in figure 3.20, we compute the attention weights using the softmax function as we have done in previous sections:

queries = sa_v2.W_query(inputs) #A
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

This results in the following attention weights:


tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
[0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
[0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
[0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
[0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
[0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
grad_fn=<SoftmaxBackward0>)


We can implement step 2 in figure 3.20 using PyTorch's tril function to create a mask where the values above the diagonal are zero:

context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

The resulting mask is as follows:

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

Now, we can multiply this mask with the attention weights to zero out the values above the diagonal:

masked_simple = attn_weights * mask_simple
print(masked_simple)

As we can see, the elements above the diagonal are successfully zeroed out:

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

The third step in figure 3.20 is to renormalize the attention weights to sum up to 1 again in each row. We can achieve this by dividing each element in each row by the sum in each row:

row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

The result is an attention weight matrix where the attention weights above the diagonal are zeroed out and where the rows sum to 1:

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

Information leakage

When we apply a mask and then renormalize the attention weights, it might initially appear that information from future tokens (which we intend to mask) could still influence the current token because their values are part of the softmax calculation. However, the key insight is that when we renormalize the attention weights after masking, what we're essentially doing is recalculating the softmax over a smaller subset (since masked positions don't contribute to the softmax value).

The mathematical elegance of softmax is that despite initially including all positions in the denominator, after masking and renormalizing, the effect of the masked positions is nullified: they don't contribute to the softmax score in any meaningful way.

In simpler terms, after masking and renormalization, the distribution of attention weights is as if it was calculated only among the unmasked positions to begin with. This ensures there's no information leakage from future (or otherwise masked) tokens, as we intended.
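A short numerical check (not from the book, using three arbitrary scores) makes this concrete: renormalizing the masked softmax output over the first two positions gives the same result as computing the softmax over those two scores alone:

scores = torch.tensor([0.5, 1.5, 2.5])        # hypothetical attention scores
full = torch.softmax(scores, dim=0)           # softmax over all three positions
masked = full.clone()
masked[2] = 0.0                               # mask out the "future" position
renormalized = masked / masked.sum()
print(renormalized[:2])
print(torch.softmax(scores[:2], dim=0))       # identical to the renormalized weights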

While we could technically be done with implementing causal attention at this point, we can take advantage of a mathematical property of the softmax function and implement the computation of the masked attention weights more efficiently in fewer steps, as shown in figure 3.21.

Figure 3.21 A more efficient way to obtain the masked attention weight matrix in causal attention is to mask the attention scores with negative infinity values before applying the softmax function.

The softmax function converts its inputs into a probability distribution. When negative infinity values (-∞) are present in a row, the softmax function treats them as zero probability. (Mathematically, this is because e^-∞ approaches 0.)

We can implement this more efficient masking “trick” by creating a mask with 1s above the diagonal and then replacing these 1s with negative infinity ( -inf ) values:

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

This results in the following mask:

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)
Now all we need to do is apply the softmax function to these masked results, and we are done:

attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

As we can see based on the output, the values in each row sum to 1, and no further normalization is necessary:

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

We could now use the modified attention weights to compute the context vectors via context_vec = attn_weights @ values , as in section 3.4. However, in the next section, we first cover another minor tweak to the causal attention mechanism that is useful for reducing overfitting when training LLMs.

3.5.2 Masking additional attention weights with dropout

Dropout in deep learning is a technique where randomly selected hidden layer units are ignored during training, effectively “dropping” them out. This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units. It's important to emphasize that dropout is only used during training and is disabled afterward.
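As a small illustration (not from the book), PyTorch's nn.Dropout module follows this convention automatically: it only drops values while the module is in training mode and acts as an identity function in evaluation mode:

drop = torch.nn.Dropout(0.5)
x = torch.ones(4)
drop.train()            # training mode: roughly half the values are zeroed (and the rest scaled up)
print(drop(x))
drop.eval()             # evaluation mode: dropout is disabled, the input passes through unchanged
print(drop(x))          # tensor([1., 1., 1., 1.])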
In the transformer architecture, including models like GPT, dropout in the attention mechanism is typically applied in two specific areas: after calculating the attention scores or after applying the attention weights to the value vectors.

Here we will apply the dropout mask after computing the attention weights, as illustrated in figure 3.22, because it's the more common variant in practice.

Figure 3.22 Using the causal attention mask (upper left), we apply
an additional dropout mask (upper right) to zero out additional
attention weights to reduce overfitting during training.

In the following code example, we use a dropout rate of 50%, which means masking out half of the attention weights. (When we train the GPT model in later chapters, we will use a lower dropout rate, such as 0.1 or 0.2.)

In the following code, we apply PyTorch's dropout implementation first to a 6 × 6 tensor consisting of ones for illustration purposes:

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) #A
example = torch.ones(6, 6) #B
print(dropout(example))


As we can see, approximately half of the values are zeroed out:

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

When applying dropout to an attention weight matrix with a rate of 50%, half of the elements in the matrix are randomly set to zero. To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of 1/0.5 = 2. This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.
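We can sanity-check this scaling behavior with a quick experiment (not from the book): across a large number of values, the mean of the dropped-and-rescaled outputs stays close to the mean of the original inputs.

check = torch.nn.Dropout(0.5)
x = torch.ones(100_000)
print(check(x).mean())  # close to 1.0: the zeros and the rescaled 2.0s average out to the original mean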

Now let’s apply dropout to the attention weight matrix itself:

torch.manual_seed(123)
print(dropout(attn_weights))

The resulting attention weight matrix now has additional elements zeroed out and the remaining ones rescaled:

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)

Note that the resulting dropout outputs may look different depending on your operating system; you can read more about this inconsistency on the PyTorch issue tracker at https://github.com/pytorch/pytorch/issues/121595.

Having gained an understanding of causal attention and dropout masking, we will develop a concise Python class in the following section. This class is designed to facilitate the efficient application of these two techniques.

3.5.3 Implementing a compact causal attention class

In this section, we will now incorporate the causal attention and dropout modifications into the SelfAttention Python class we developed in section 3.4. This class will then serve as a template for developing multi-head attention in the upcoming section, which is the final attention class we implement in this chapter.

But before we begin, let's ensure that the code can handle batches consisting of more than one input so that the CausalAttention class supports the batch outputs produced by the data loader we implemented in chapter 2.

For simplicity, to simulate such batch inputs, we duplicate the input text example:

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) #A

This results in a three-dimensional tensor consisting of two input texts with six tokens each, where each token is a three-dimensional embedding vector:

torch.Size([2, 6, 3])


The following CausalAttention class is similar to the SelfAttention class we implemented earlier, except that we now added the dropout and causal mask components, as highlighted in the following listing.

Listing 3.3 A compact causal attention class

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout) #A
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        ) #B

    def forward(self, x):
        b, num_tokens, d_in = x.shape #C
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2) #C
        attn_scores.masked_fill_( #D
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

While all added code lines should be familiar from previous sections, we now added a self.register_buffer() call in the __init__ method. The use of register_buffer in PyTorch is not strictly necessary for all use cases but offers several advantages here. For instance, when we use the CausalAttention class in our LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training the LLM in future chapters. This means we don't need to manually ensure these tensors are on the same device as the model parameters, avoiding device mismatch errors.
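As a brief illustration (not from the book, and only meaningful if a CUDA-capable GPU is available), moving the module also moves the registered mask buffer along with the parameters:

ca_demo = CausalAttention(d_in, d_out, context_length=6, dropout=0.0)
print(ca_demo.mask.device)        # cpu
if torch.cuda.is_available():
    ca_demo = ca_demo.to("cuda")
    print(ca_demo.mask.device)    # cuda:0 -- the buffer moved together with the model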

We can use the CausalAttention class as follows, similar to SelfAttention previously:

torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

The resulting context vector is a three-dimensional tensor where each token is now represented by a two-dimensional embedding:

context_vecs.shape: torch.Size([2, 6, 2])
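To make the benefit of register_buffer() more tangible, here is a small optional check, not part of the chapter's listings, that inspects the ca instance created above; the device transfer at the end assumes a CUDA GPU is available and is skipped otherwise:

print("mask" in ca.state_dict())  # True: the mask buffer is saved and loaded together with the weights
print(ca.mask.requires_grad)      # False: buffers are not treated as trainable parameters

if torch.cuda.is_available():
    ca = ca.to("cuda")            # .to() moves buffers along with the parameters
    print(ca.mask.device)         # for example, cuda:0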

Figure 3.23 provides a mental model that summarizes what we have accomplished so far.

Figure 3.23 A mental model summarizing the four different attention modules we are coding in this chapter. We began with a simplified attention mechanism, added trainable weights, and then added a causal attention mask. In the remainder of this chapter, we will extend the causal attention mechanism and code multi-head attention, which is the final module we will use in the LLM implementation in the next chapter.

As illustrated in figure 3.23, in this section we focused on the concept and implementation of causal attention in neural networks. In the next section, we will expand on this concept and implement a multi-head attention module that runs several such causal attention mechanisms in parallel.


3.6 Extending single-head attention to multi-head attention
In this final section of the chapter, we extend the previously implemented causal attention class over multiple heads. This is also called multi-head attention.

Bqx mrtx “timul-bsyo” fesrer xr givddini grx nattnoiet hiemmcsna


nrkj iplultem “dhesa,” cyak nagortepi eyntiendnlpde. Jn zyjr nteocxt,
s snigle csuaal tanntotei omudle nca kp ddiresoecn enlisg-bsbk
tieatnotn, eerhw ereht aj eunf vnv vrc lv ntetontia gtisehw psnierosgc
xrb tuinp itqlelnaesuy.

In the following subsections, we will tackle this expansion from causal attention to multi-head attention. The first subsection will intuitively build a multi-head attention module by stacking multiple CausalAttention modules for illustration purposes. The second subsection will then implement the same multi-head attention module in a more complicated but computationally more efficient way.

3.6.1 Stacking multiple single-head attention layers

In practical terms, implementing multi-head attention involves creating multiple instances of the self-attention mechanism (depicted earlier in figure 3.18 in section 3.4.1), each with its own weights, and then combining their outputs. Using multiple instances of the self-attention mechanism can be computationally intensive, but it's crucial for the kind of complex pattern recognition that models like transformer-based LLMs are known for.

Figure 3.24 illustrates the structure of a multi-head attention module, which consists of multiple single-head attention modules, as previously depicted in figure 3.18, stacked on top of each other.

Figure 3.24 The multi-head attention module in this figure depicts two single-head attention modules stacked on top of each other. So, instead of using a single matrix Wv for computing the value matrices, in a multi-head attention module with two heads, we now have two value weight matrices: Wv1 and Wv2. The same applies to the other weight matrices, Wq and Wk. We obtain two sets of context vectors Z1 and Z2 that we can combine into a single context vector matrix Z.

As mentioned before, the main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections, that is, the results of multiplying the input data (like the query, key, and value vectors in attention mechanisms) by a weight matrix.

In code, we can achieve this by implementing a simple MultiHeadAttentionWrapper class that stacks multiple instances of our previously implemented CausalAttention module, as shown in the following listing.

Listing 3.4 A wrapper class to implement multi-head attention


class MultiHeadAttentionWrapper(nn.Module):
def __init__(self, d_in, d_out, context_length,
dropout, num_heads, qkv_bias=False):
super().__init__()
self.heads = nn.ModuleList(
[CausalAttention(
d_in, d_out, context_length, dropout, qkv_bias
)
for _ in range(num_heads)]
)

def forward(self, x):


return torch.cat([head(x) for head in self.heads], dim=-1)


For example, if we use this MultiHeadAttentionWrapper class with two attention heads (via num_heads=2) and a CausalAttention output dimension of d_out=2, we obtain four-dimensional context vectors (d_out*num_heads=4), as illustrated in figure 3.25.

Figure 3.25 Using the MultiHeadAttentionWrapper, we specified the number of attention heads (num_heads). If we set num_heads=2, as shown in this figure, we obtain a tensor with two sets of context vector matrices. In each context vector matrix, the rows represent the context vectors corresponding to the tokens, and the columns correspond to the embedding dimension specified via d_out=4. We concatenate these context vector matrices along the column dimension. Since we have two attention heads and an embedding dimension of 2, the final embedding dimension is 2 × 2 = 4.
To illustrate figure 3.25 further with a concrete example, we can use the MultiHeadAttentionWrapper class similarly to the CausalAttention class before:

torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)


This results in the following tensor representing the context vectors:

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])

The first dimension of the resulting context_vecs tensor is 2 since we have two input texts (the input texts are duplicated, which is why the context vectors are exactly the same for them). The second dimension refers to the 6 tokens in each input. The third dimension refers to the four-dimensional embedding of each token.
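As an optional sanity check (not part of the book's listings), we can confirm that this last dimension really is just the two heads' outputs concatenated: the first two columns of context_vecs should match the output of the first head computed on its own:

first_head_out = mha.heads[0](batch)  # run only the first CausalAttention head
print(torch.allclose(first_head_out, context_vecs[:, :, :2]))  # True: columns 0-1 come from head 1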

Exercise 3.2 Returning two-dimensional embedding vectors

Change the input arguments for the MultiHeadAttentionWrapper(..., num_heads=2) call such that the output context vectors are two-dimensional instead of four-dimensional while keeping the setting num_heads=2. Hint: You don't have to modify the class implementation; you just have to change one of the other input arguments.

In this section, we implemented a MultiHeadAttentionWrapper that combined multiple single-head attention modules. However, note that these are processed sequentially via [head(x) for head in self.heads] in the forward method. We can improve this implementation by processing the heads in parallel. One way to achieve this is by computing the outputs for all attention heads simultaneously via matrix multiplication, as we will explore in the next section.

3.6.2 Implementing multi-head attention with weight splits

In the previous section, we created a MultiHeadAttentionWrapper to implement multi-head attention by stacking multiple single-head attention modules. This was done by instantiating and combining several CausalAttention objects.

Instead of maintaining two separate classes, MultiHeadAttentionWrapper and CausalAttention, we can combine both of these concepts into a single MultiHeadAttention class. Also, in addition to just merging the MultiHeadAttentionWrapper with the CausalAttention code, we will make some other modifications to implement multi-head attention more efficiently.

In the MultiHeadAttentionWrapper, multiple heads are implemented by creating a list of CausalAttention objects (self.heads), each representing a separate attention head. The CausalAttention class independently performs the attention mechanism, and the results from each head are concatenated. In contrast, the following MultiHeadAttention class integrates the multi-head functionality within a single class. It splits the input into multiple heads by reshaping the projected query, key, and value tensors and then combines the results from these heads after computing attention.

Let's take a look at the MultiHeadAttention class before we discuss it further.

Listing 3.5 An efficient multi-head attention class


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads #A
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) #B
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x) #C
        queries = self.W_query(x) #C
        values = self.W_value(x) #C

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view( #D
            b, num_tokens, self.num_heads, self.head_dim #D
        ) #D

        keys = keys.transpose(1, 2) #E
        queries = queries.transpose(1, 2) #E
        values = values.transpose(1, 2) #E

        attn_scores = queries @ keys.transpose(2, 3) #F

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens] #G

        attn_scores.masked_fill_(mask_bool, -torch.inf) #H

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2) #I

        #J
        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )
        context_vec = self.out_proj(context_vec) #K
        return context_vec

Even though the reshaping (.view) and transposing (.transpose) of tensors inside the MultiHeadAttention class looks very complicated, mathematically, the MultiHeadAttention class implements the same concept as the MultiHeadAttentionWrapper earlier.

On a big-picture level, in the previous MultiHeadAttentionWrapper, we stacked multiple single-head attention layers that we combined into a multi-head attention layer. The MultiHeadAttention class takes an integrated approach. It starts with a multi-head layer and then internally splits this layer into individual attention heads, as illustrated in figure 3.26.
Figure 3.26 In the MultiHeadAttentionWrapper class with two attention heads, we initialized two weight matrices, Wq1 and Wq2, and computed two query matrices, Q1 and Q2, as illustrated at the top of this figure. In the MultiHeadAttention class, we initialize one larger weight matrix Wq, only perform one matrix multiplication with the inputs to obtain a query matrix Q, and then split the query matrix into Q1 and Q2, as shown at the bottom of this figure. We do the same for the keys and values, which are not shown to reduce visual clutter.

The splitting of the query, key, and value tensors, as depicted in figure 3.26, is achieved through tensor reshaping and transposing operations using PyTorch's .view and .transpose methods. The input is first transformed (via the linear layers for queries, keys, and values) and then reshaped to represent multiple heads.

The key operation is to split the d_out dimension into num_heads and head_dim, where head_dim = d_out / num_heads. This splitting is then achieved using the .view method: a tensor of dimensions (b, num_tokens, d_out) is reshaped to dimension (b, num_tokens, num_heads, head_dim).

The tensors are then transposed to bring the num_heads dimension before the num_tokens dimension, resulting in a shape of (b, num_heads, num_tokens, head_dim). This transposition is crucial for correctly aligning the queries, keys, and values across the different heads and performing batched matrix multiplications efficiently.
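To make these shape changes easier to follow, here is a minimal standalone sketch; the dimensions b=2, num_tokens=6, d_out=4, and num_heads=2 are purely illustrative values, not taken from the listing above:

b, num_tokens, d_out, num_heads = 2, 6, 4, 2
head_dim = d_out // num_heads

x = torch.randn(b, num_tokens, d_out)           # output of one of the linear projections
x = x.view(b, num_tokens, num_heads, head_dim)  # split d_out into (num_heads, head_dim)
print(x.shape)                                  # torch.Size([2, 6, 2, 2])

x = x.transpose(1, 2)                           # move num_heads in front of num_tokens
print(x.shape)                                  # torch.Size([2, 2, 6, 2])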

Ye eulraitlts jcrp cthbead xartmi lonutatpmliiic, usseppo wx gxcx rbx


ogilwfnlo mleaepx eotrsn:

a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],  #A The shape of this tensor is (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

Now we perform a batched matrix multiplication between the tensor itself and a view of the tensor in which we transposed the last two dimensions, num_tokens and head_dim:

print(a @ a.transpose(2, 3))

The result is as follows:

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])

In this case, the matrix multiplication implementation in PyTorch handles the four-dimensional input tensor so that the matrix multiplication is carried out between the two last dimensions (num_tokens, head_dim) and then repeated for the individual heads.

For instance, the preceding becomes a more compact way to compute the matrix multiplication for each head separately:

first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)

second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)


The results are exactly the same as the results we obtained when using the batched matrix multiplication print(a @ a.transpose(2, 3)) earlier:

First head:
tensor([[1.3208, 1.1631, 1.2879],
[1.1631, 2.2150, 1.8424],
[1.2879, 1.8424, 2.0402]])

Second head:
tensor([[0.4391, 0.7003, 0.5903],
[0.7003, 1.3737, 1.0620],
[0.5903, 1.0620, 0.9912]])

Continuing with MultiHeadAttention, after computing the attention weights and context vectors, the context vectors from all heads are transposed back to the shape (b, num_tokens, num_heads, head_dim). These vectors are then reshaped (flattened) into the shape (b, num_tokens, d_out), effectively combining the outputs from all heads.
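Here is the mirror image of the earlier shape sketch, again with purely illustrative dimensions rather than values from the listing, showing how the transposed context vectors are flattened back into (b, num_tokens, d_out):

b, num_heads, num_tokens, head_dim = 2, 2, 6, 2
d_out = num_heads * head_dim

context_vec = torch.randn(b, num_heads, num_tokens, head_dim)       # per-head context vectors
context_vec = context_vec.transpose(1, 2)                           # (b, num_tokens, num_heads, head_dim)
context_vec = context_vec.contiguous().view(b, num_tokens, d_out)   # flatten the head dimension
print(context_vec.shape)                                            # torch.Size([2, 6, 4])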

Additionally, we added a so-called output projection layer (self.out_proj) to MultiHeadAttention after combining the heads, which is not present in the CausalAttention class. This output projection layer is not strictly necessary (see the References section in appendix B for more details), but it is commonly used in many LLM architectures, which is why we added it here for completeness.
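Note that the output projection is simply a linear layer mapping from d_out back to d_out, so it can mix information across the concatenated heads without changing the tensor shape; a minimal sketch with illustrative dimensions:

out_proj = nn.Linear(4, 4)      # d_out = 4 in this toy example
z = torch.randn(2, 6, 4)        # (b, num_tokens, d_out), i.e., the concatenated head outputs
print(out_proj(z).shape)        # torch.Size([2, 6, 4]) -- the shape is unchanged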

Even though the MultiHeadAttention class looks more complicated than the MultiHeadAttentionWrapper due to the additional reshaping and transposition of tensors, it is more efficient. The reason is that we only need one matrix multiplication to compute the keys, for instance, keys = self.W_key(x) (the same is true for the queries and values). In the MultiHeadAttentionWrapper, we needed to repeat this matrix multiplication, which is computationally one of the most expensive steps, for each attention head.
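If you are curious about the practical difference, a rough and hardware-dependent way to compare the two implementations is to time a forward pass on a larger random batch. The sizes below are arbitrary illustrative values; the exact numbers will vary from machine to machine, although the fused MultiHeadAttention version is typically the faster of the two:

import time

torch.manual_seed(123)
big_batch = torch.randn(4, 512, 768)  # (batch_size, num_tokens, d_in), illustrative sizes

mha_wrapper = MultiHeadAttentionWrapper(768, 64, 512, 0.0, num_heads=12)  # 12 heads with d_out=64 each
mha_fused = MultiHeadAttention(768, 768, 512, 0.0, num_heads=12)          # one fused module with d_out=768

for name, module in (("wrapper", mha_wrapper), ("fused", mha_fused)):
    start = time.perf_counter()
    with torch.no_grad():
        module(big_batch)
    print(f"{name}: {time.perf_counter() - start:.3f} s")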

The MultiHeadAttention class can be used similarly to the SelfAttention and CausalAttention classes we implemented earlier:

torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

As we can see based on the results, the output dimension is directly controlled by the d_out argument:

tensor([[[0.3190, 0.4858],
[0.2943, 0.3897],
[0.2856, 0.3593],
[0.2693, 0.3873],
[0.2639, 0.3928],
[0.2575, 0.4028]],

[[0.3190, 0.4858],
[0.2943, 0.3897],
[0.2856, 0.3593],
[0.2693, 0.3873],
[0.2639, 0.3928],
[0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


In this section, we implemented the MultiHeadAttention class that we will use in the upcoming sections when implementing and training the LLM itself. Note that while the code is fully functional, we used relatively small embedding sizes and numbers of attention heads to keep the outputs readable.

For comparison, the smallest GPT-2 model (117 million parameters) has 12 attention heads and a context vector embedding size of 768. The largest GPT-2 model (1.5 billion parameters) has 25 attention heads and a context vector embedding size of 1,600. Note that the embedding sizes of the token inputs and context embeddings are the same in GPT models (d_in = d_out).

Exercise 3.3 Initializing GPT-2 size attention modules


Using the MultiHeadAttention class, initialize a multi-head attention module that has the same number of attention heads as the smallest GPT-2 model (12 attention heads). Also ensure that you use the respective input and output embedding sizes similar to GPT-2 (768 dimensions). Note that the smallest GPT-2 model supports a context length of 1,024 tokens.


3.7 Summary
Attention mechanisms transform input elements into
enhanced context vector representations that incorporate
information about all inputs.
A self-attention mechanism computes the context vector
representation as a weighted sum over the inputs.
In a simplified attention mechanism, the attention weights
are computed via dot products.
A dot product is just a concise way of multiplying two
vectors element-wise and then summing the products.
Matrix multiplications, while not strictly required, help us
to implement computations more efficiently and compactly
by replacing nested for-loops.
In self-attention mechanisms used in LLMs, also called scaled dot-product attention, we include trainable weight matrices to compute intermediate transformations of the inputs: queries, values, and keys.
When working with LLMs that read and generate text from
left to right, we add a causal attention mask to prevent the
LLM from accessing future tokens.
In addition to causal attention masks, which zero out attention weights for future tokens, we can also add a dropout mask to reduce overfitting in LLMs.
The attention modules in transformer-based LLMs involve
multiple instances of causal attention, which is called
multi-head attention.
We can create a multi-head attention module by stacking
multiple instances of causal attention modules.
A more efficient way of creating multi-head attention
modules involves batched matrix multiplications.
