The Annotated Transformer
Alexander M. Rush
[email protected]
Harvard University
1 Introduction
Replication of published results remains a challenging issue in open-source NLP. When a new paper is published with major improvements, it is common for many members of the community to independently reproduce the numbers experimentally, which is often a struggle. In practice, this makes it difficult to improve on reported scores; just as importantly, it is a pedagogical issue if students cannot reproduce results from scientific publications.
The recent turn towards deep learning has exacerbated this issue. New models require extensive hyperparameter tuning and long training times. Small mistakes can cause major issues. Fortunately though, new toolsets have made it possible to write simpler, more mathematically declarative code.

1 Presented at http://nlp.seas.harvard.edu/2018/04/03/attention.html with source code at https://github.com/harvardnlp/annotated-transformer
2 Background
The goal of reducing sequential computation also forms the foundation of the Extended Neural GPU, ByteNet, and ConvS2S, all of which use convolutional neural networks as a basic building block, computing hidden representations in parallel for all input and output positions. In these models, the number of operations required to relate signals from two arbitrary input or output positions grows with the distance between positions, linearly for ConvS2S and logarithmically for ByteNet. This makes it more difficult to learn dependencies between distant positions. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention.
Self-attention, sometimes called intra-attention, is an attention mechanism relating different positions of a single sequence in order to compute a representation of that sequence.

3 Model Architecture
Most competitive neural sequence transduction models have an encoder-decoder structure. The encoder maps an input sequence of symbol representations (x_1, ..., x_n) to a sequence of continuous representations z = (z_1, ..., z_n); given z, the decoder then generates an output sequence (y_1, ..., y_m) of symbols one element at a time. At each step the model is auto-regressive (Graves, 2013), consuming the previously generated symbols as additional input when generating the next.

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture.
    Base for this and many other models.
    """
    def __init__(self, encoder, decoder, src_embed,
                 tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask),
                           src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory,
                            src_mask, tgt_mask)

class Generator(nn.Module):
    "Define standard linear + softmax generation step."
    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        # Forward pass reconstructed: project to the vocabulary and
        # return log-probabilities.
        return F.log_softmax(self.proj(x), dim=-1)
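The stack implementations that follow replicate a single layer N times using a small helper, clones, which does not appear in this excerpt. A minimal sketch consistent with how it is called below:

import copy
import torch.nn as nn

def clones(module, N):
    "Produce a ModuleList containing N deep copies of the given module."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])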
3.1.1 Encoder
The encoder is composed of a stack of N = 6 identical layers.

class Encoder(nn.Module):
    "Core encoder is a stack of N layers"
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Pass the input/mask through each layer in turn."
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

We employ a residual connection (He et al., 2016) around each of the two sub-layers, followed by layer normalization (Ba et al., 2016).

class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return (self.a_2 * (x - mean) /
                (std + self.eps) + self.b_2)

That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself. We apply dropout (Srivastava et al., 2014) to the output of each sub-layer, before it is added to the sub-layer input and normalized.
To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension d_model = 512.

3.1.2 Decoder
The decoder is also composed of a stack of N = 6 identical layers.

class Decoder(nn.Module):
    "Generic N layer decoder with masking."
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization.

class DecoderLayer(nn.Module):
    "Decoder calls self-attn, src-attn, and feed forward."
    def __init__(self, size, self_attn,
                 src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        sublayer = SublayerConnection(size, dropout)
        self.sublayer = clones(sublayer, 3)
        self.size = size

    def forward(self, x, memory, s_mask, t_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        x = self.sublayer[0](x, lambda x:
                             self.self_attn(x, x, x, t_mask))
        x = self.sublayer[1](x, lambda x:
                             self.src_attn(x, m, m, s_mask))
        return self.sublayer[2](x, self.feed_forward)
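The DecoderLayer above (and the corresponding encoder layer, not shown in this excerpt) wraps each sub-layer in a SublayerConnection module, which is also not shown. A minimal sketch that follows the description in Section 3.1.1, i.e. LayerNorm(x + Dropout(Sublayer(x))); the released code may order the normalization differently:

class SublayerConnection(nn.Module):
    "A residual connection followed by a layer norm, applied around any sub-layer (sketch)."
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Apply the sub-layer, drop out its output, add the residual, then normalize.
        return self.norm(x + self.dropout(sublayer(x)))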
3.1.3 Attention
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.
We call our particular attention "Scaled Dot-Product Attention". The input consists of queries and keys of dimension d_k, and values of dimension d_v. We compute the dot products of the query with all keys, divide each by √d_k, and apply a softmax function to obtain the weights on the values.

Attention(Q, K, V) = softmax(QK^T / √d_k) V

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    key_t = key.transpose(-2, -1)
    scores = torch.matmul(query, key_t) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

The two most commonly used attention functions are additive attention (Bahdanau et al., 2014) and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor of 1/√d_k. Additive attention computes the compatibility function using a feed-forward network with a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.
While for small values of d_k the two mechanisms perform similarly, additive attention outperforms dot-product attention without scaling for larger values of d_k (Britz et al., 2017). We suspect that for large values of d_k, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. (To illustrate why the dot products get large, assume that the components of q and k are independent random variables with mean 0 and variance 1. Then their dot product, q · k = Σ_{i=1}^{d_k} q_i k_i, has mean 0 and variance d_k.) To counteract this effect, we scale the dot products by 1/√d_k.
Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
    where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

and the projections are parameter matrices W_i^Q ∈ R^(d_model × d_k), W_i^K ∈ R^(d_model × d_k), W_i^V ∈ R^(d_model × d_v), and W^O ∈ R^(h·d_v × d_model). In this work we employ h = 8 parallel attention layers, or heads. For each of these we use d_k = d_v = d_model / h = 64. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.
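Before turning to the multi-head implementation, a quick shape check of the attention function above may be useful; the tensor sizes here are hypothetical (batch of 2, 4 heads, 10 query and 12 key positions, d_k = 16):

import torch

q = torch.randn(2, 4, 10, 16)   # (batch, heads, query positions, d_k)
k = torch.randn(2, 4, 12, 16)   # (batch, heads, key positions, d_k)
v = torch.randn(2, 4, 12, 16)   # (batch, heads, key positions, d_v = d_k)
out, p_attn = attention(q, k, v)
print(out.shape)     # torch.Size([2, 4, 10, 16]): one d_k-dimensional output per query
print(p_attn.shape)  # torch.Size([2, 4, 10, 12]): attention weights over the keys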
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k.
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nb = query.size(0)
        # 1) Do all the linear projections in batch from d_model => h x d_k.
        query, key, value = [
            l(x).view(nb, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))]
        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(
            nb, -1, self.h * self.d_k)
        return self.linears[-1](x)

3.2 Position-wise Feed-Forward Networks
In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically. This consists of two linear transformations with a ReLU activation in between.

FFN(x) = max(0, xW_1 + b_1) W_2 + b_2

While the linear transformations are the same across different positions, they use different parameters from layer to layer. Another way of describing this is as two convolutions with kernel size 1. The dimensionality of input and output is d_model = 512, and the inner layer has dimensionality d_ff = 2048.

class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

3.3 Embeddings and Softmax
Similarly to other sequence transduction models, we use learned embeddings to convert the input tokens and output tokens to vectors of dimension d_model. We also use the usual learned linear transformation and softmax function to convert the decoder output to predicted next-token probabilities. In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation, similar to (Press and Wolf, 2016). In the embedding layers, we multiply those weights by √d_model.

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)
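The weight sharing described above can be realized once the full model is assembled by pointing both embedding tables and the generator projection at a single parameter. A sketch, under the assumption (not shown in this excerpt) that src_embed and tgt_embed are containers whose first element is an Embeddings module and that the source and target vocabularies are shared:

# Assumption: model.src_embed[0] and model.tgt_embed[0] are Embeddings modules,
# and model.generator is the Generator from Section 3.
model.tgt_embed[0].lut.weight = model.src_embed[0].lut.weight
model.generator.proj.weight = model.tgt_embed[0].lut.weight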
3.4 Positional Encoding
Since our model contains no recurrence and no convolution, in order for the model to make use of the order of the sequence, we must inject some information about the relative or absolute position of the tokens in the sequence. To this end, we add "positional encodings" to the input embeddings at the bottoms of the encoder and decoder stacks. The positional encodings have the same dimension d_model as the embeddings, so that the two can be summed. There are many choices of positional encodings, learned and fixed (Gehring et al., 2017).
In this work, we use sine and cosine functions of different frequencies:

PE_(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE_(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

where pos is the position and i is the dimension. That is, each dimension of the positional encoding corresponds to a sinusoid. The wavelengths form a geometric progression from 2π to 10000 · 2π. We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset k, PE_(pos+k) can be represented as a linear function of PE_pos.
In addition, we apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks. For the base model, we use a rate of P_drop = 0.1.

class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)],
                         requires_grad=False)
        return self.dropout(x)

plt.figure(figsize=(15, 5))
pe = PositionalEncoding(20, 0)
y = pe.forward(Variable(torch.zeros(1, 100, 20)))
plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())
plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])
None

4 Training
This section describes the training regime for our models.

4.1 Batches and Masking

class Batch:
    "Batch of data with mask for training."
    def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & Variable(
            subsequent_mask(tgt.size(-1))
            .type_as(tgt_mask.data))
        return tgt_mask
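make_std_mask above combines the padding mask with subsequent_mask, a helper that does not appear in this excerpt. It hides future positions so that the prediction for position i can depend only on known outputs at positions less than i. A minimal sketch:

import numpy as np
import torch

def subsequent_mask(size):
    "Mask out subsequent positions: zeros above the diagonal block attention to the future."
    attn_shape = (1, size, size)
    mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(mask) == 0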
4.2 Training Loop

def run_epoch(data_iter, model, loss_compute):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    for i, batch in enumerate(data_iter):
        out = model.forward(batch.src, batch.trg,
                            batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 50 == 1:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Tokens / Sec: %f" %
                  (i, loss / batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    return total_loss / total_tokens
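run_epoch is agnostic about how the loss is computed: loss_compute is any callable that takes the decoder output, the gold targets, and a normalization count, and returns the loss, optionally also taking an optimization step. A sketch of such a callable, assuming a generator like the Generator module of Section 3 and a criterion that consumes log-probabilities; the optimizer argument is an assumption:

class SimpleLossCompute:
    "A sketch of a loss-compute-and-train callable for run_epoch."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator   # maps d_model vectors to log-probabilities
        self.criterion = criterion   # e.g. a label-smoothing criterion
        self.opt = opt               # a torch.optim optimizer, or None for evaluation

    def __call__(self, x, y, norm):
        # Project decoder output to log-probabilities and compute the normalized loss.
        x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                              y.contiguous().view(-1)) / norm
        if self.opt is not None:
            loss.backward()
            self.opt.step()
            self.opt.zero_grad()
        return loss.item() * norm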
4.4 Hardware and Schedule
We trained our models on one machine with 8 NVIDIA P100 GPUs. For our base models using the hyperparameters described throughout the paper, each training step took about 0.4 seconds. We trained the base models for a total of 100,000 steps or 12 hours. For our big models, step time was 1.0 seconds. The big models were trained for 300,000 steps (3.5 days).
4.5 Optimizer
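Training uses the Adam optimizer with β1 = 0.9, β2 = 0.98 and ε = 1e-9, and varies the learning rate as in the original Transformer: lrate = d_model^(-0.5) · min(step^(-0.5), step · warmup^(-1.5)), i.e. a linear warmup followed by inverse-square-root decay. A sketch of a wrapper implementing this schedule (the class name and the example warmup value are illustrative):

class NoamOpt:
    "Optimizer wrapper implementing the warmup / inverse-square-root schedule (sketch)."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self._step = 0

    def rate(self, step=None):
        "Learning rate at a given step."
        if step is None:
            step = self._step
        return (self.factor * self.model_size ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def step(self):
        "Update the learning rate of every parameter group, then step."
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

# Illustrative setup, e.g. warmup = 4000 steps:
# opt = NoamOpt(512, 1, 4000,
#               torch.optim.Adam(model.parameters(), lr=0,
#                                betas=(0.9, 0.98), eps=1e-9))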
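Regularization by label smoothing with value 0.1 is also employed during training; it penalizes the model for becoming over-confident in a single choice, which is what the snippet below illustrates. The LabelSmoothing criterion it uses is defined in a portion of the paper not reproduced in this excerpt; a minimal KL-divergence-based sketch consistent with that interface:

class LabelSmoothing(nn.Module):
    "Label smoothing via KL divergence against a smoothed target distribution (sketch)."
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False)
        self.size = size                     # vocabulary size
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing    # probability mass kept on the gold label
        self.smoothing = smoothing

    def forward(self, x, target):
        # x: log-probabilities of shape (batch, size); target: gold indices of shape (batch,).
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.numel() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        return self.criterion(x, Variable(true_dist, requires_grad=False))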
crit = LabelSmoothing(5, 0, 0.1)
def loss(x):
    d = x + 3 * 1
    predict = torch.FloatTensor([[0, x / d, 1 / d,
                                  1 / d, 1 / d]])
    return crit(Variable(predict.log()),
                Variable(torch.LongTensor([1]))).data[0]
plt.plot(np.arange(1, 100),
         [loss(x) for x in range(1, 100)])
None

After training, the model can translate with simple greedy decoding; the following decodes one sentence and prints the result.

# The first line below is reconstructed: greedy_decode, src, src_mask, and the
# TGT vocabulary are defined in parts of the paper not reproduced in this excerpt.
out = greedy_decode(model, src, src_mask,
                    max_len=60,
                    start_symbol=TGT.stoi["<s>"])
print("Translation:", end="\t")
trans = "<s> "
for i in range(1, out.size(1)):
    sym = TGT.itos[out[0, i]]
    if sym == "</s>":
        break
    trans += sym + " "
print(trans)

5.2 Attention Visualization

tgt_sent = trans.split()
def draw(data, x, y, ax):
    seaborn.heatmap(data,
                    xticklabels=x, square=True,
                    yticklabels=y, vmin=0.0, vmax=1.0,
                    cbar=False, ax=ax)
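The MultiHeadedAttention modules above cache their most recent attention weights in self.attn, so draw can be used to inspect the heads of any layer after decoding a sentence. A usage sketch, assuming a trained model whose decoder layers expose their self-attention module as self_attn (as the DecoderLayer class above does); the layer and head choices are illustrative:

# Plot four heads of the decoder self-attention for alternating layers,
# restricted to the tokens of the decoded sentence above.
for layer in range(1, 6, 2):
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    print("Decoder Self-Attention, Layer", layer + 1)
    for h in range(4):
        draw(model.decoder.layers[layer].self_attn.attn[0, h]
             .data[:len(tgt_sent), :len(tgt_sent)],
             tgt_sent, tgt_sent if h == 0 else [], ax=axs[h])
    plt.show()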
References
Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E
Hinton. 2016. Layer normalization. arXiv
preprint arXiv:1607.06450.