'''
Coding Attention Mechanisms
'''
'''
The figure referenced above shows the Bahdanau attention mechanism, in which the
decoder RNN has access to all encoder hidden states together with attention
weights that indicate how important each input token is for the current
decoding step.
'''
'''
Self-attention is a mechanism that allows each position in the input
sequence to attend to all positions in the same sequence when
computing the representation of a sequence.
Self-attention is a key component of contemporary LLMs based on the
transformer architecture, such as the GPT series.
'''
'''
The context vector for x(i) stores information about the element itself with
respect to all the input tokens in the sequence, weighted by how relevant each
token is to x(i).
'''
import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)
'''
The first step of implementing self-attention is to compute the
intermediate values ω, referred to as attention scores, as illustrated
in figure 3.8.
(Please note that figure 3.8 displays the values of the preceding
inputs tensor in a truncated version; for example, 0.87 is truncated
to 0.8 due to spatial constraints.
In this truncated version, the embeddings of the words “journey” and
“starts” may appear similar by random chance.)
'''
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)
'''
Next, we normalize the attention scores we just computed via dot products so
that they sum to 1.
'''
#Normalization 1
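'''
A minimal sketch of this first, simple normalization: divide each score by the
sum of all scores so the weights add up to 1 (the variable name
attn_weights_2_tmp is illustrative).
'''
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())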
'''
In practice, it’s more common and advisable to use the softmax
function for normalization.
This approach is better at managing extreme values and offers more
favorable gradient properties during training.
The following is a basic implementation of the softmax function for
normalizing the attention scores:
'''
#Normalization 2
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())
'''
Now that we have computed the normalized attention weights, we are
ready for the final step illustrated in figure 3.10:
calculating the context vector z(2) by multiplying the embedded input
tokens, x(i), with the corresponding attention weights and then
summing the resulting vectors.
'''
query = inputs[1]  # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    # accumulate the attention-weighted input vectors
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)
'''
The complete algorithm:
1. Compute attention scores - take the dot product between each pair of input
   embedding vectors.
2. Compute attention weights - normalize the attention scores with softmax.
3. Compute context vectors - multiply each input embedding vector by its
   attention weight and sum the results (see the matrix-based sketch after the
   loop version below).
'''
# Method 1: using for loops
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)
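'''
For comparison, a sketch of the same three steps applied to all tokens at once
using matrix operations (the names attn_weights and all_context_vecs are
illustrative):
'''
# Method 2: using matrix multiplication
attn_scores = inputs @ inputs.T                    # step 1: pairwise dot products
attn_weights = torch.softmax(attn_scores, dim=-1)  # step 2: row-wise softmax
all_context_vecs = attn_weights @ inputs           # step 3: weighted sums of inputs
print(all_context_vecs)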
'''
We will implement the self-attention mechanism step by step by
introducing the three trainable weight matrices Wq, Wk, and Wv.
These three matrices are used to project the embedded input tokens,
x(i), into query, key, and value vectors, as illustrated in figure
3.14.
'''
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2
'''
Note that in GPT-like models, the input and output dimensions are
usually the same, but for illustration purposes,
to better follow the computation, we choose different input (d_in=3)
and output (d_out=2) dimensions here.
'''
{"type":"string"}
# We set requires_grad=False here to reduce clutter in the outputs, but it must
# be set to True when we actually train the model.
torch.manual_seed(123)  # for reproducible random weights
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
'''
torch.nn.Parameter():
1. Function: torch.nn.Parameter()
2. Description: wraps a tensor and marks it as a parameter, i.e., a tensor that
   should be optimized (learned) during training.
3. Parameters are special tensors that are automatically added to the list of
   parameters of a torch.nn.Module when assigned as attributes.
4. Parameter: here, the tensor generated by torch.rand(d_in, d_out).
5. Purpose: wrapping a tensor with torch.nn.Parameter() makes it a model
   parameter, so it is available for optimization.
'''
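'''
With the weight matrices in place, the query, key, and value vectors for the
second input token x(2) are obtained by matrix multiplication. A short sketch:
'''
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)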
# Output: tensor([0.4306, 1.4551])
'''
Note that in the weight matrices W, the term “weight” is short for
“weight parameters,” the values of a neural network that are optimized
during training.
This is not to be confused with the attention weights. As we already
saw in the previous section,
attention weights determine the extent to which a context vector
depends on the different parts of the input—i.e.,
to what extent the network focuses on different parts of the input.
In summary, weight parameters are the fundamental, learned
coefficients that define the network’s connections, while attention
weights are dynamic, context-specific values.
'''
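'''
The scaled-dot-product step below needs the key and value vectors of all input
tokens as well as the attention scores of the query x(2) against every key. A
sketch of those intermediate computations:
'''
keys = inputs @ W_key
values = inputs @ W_value
attn_scores_2 = query_2 @ keys.T  # attention scores of query 2 w.r.t. all keys
print(attn_scores_2)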
d_k = keys.shape[-1]
# Divide the attention scores by the square root of the key dimension before the
# softmax; this keeps the scores smaller and the softmax less peaked, which leads
# to better-behaved gradients during training.
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

# The context vector z(2) is the attention-weighted sum of the value vectors
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)
# Output: tensor([0.3061, 0.8210])
'''
Why query, key, and value?
The terms “key,” “query,” and “value” in the context of attention
mechanisms are borrowed from the domain of information retrieval and
databases,
where similar concepts are used to store, search, and retrieve
information.
A query is analogous to a search query in a database. It represents
the current item (e.g., a word or token in a sentence) the model
focuses on or tries to understand.
The query is used to probe the other parts of the input sequence to
determine how much attention to pay to them.
The key is like a database key used for indexing and searching. In the
attention mechanism, each item in the input sequence (e.g., each word
in a sentence) has an associated key.
These keys are used to match the query.
The value in this context is similar to the value in a key-value pair
in a database. It represents the actual content or representation of
the input items.
Once the model determines which keys (and thus which parts of the
input) are most relevant to the query (the current focus item), it
retrieves the corresponding values.
'''
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values  # context vectors
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))
# Output:
# tensor([[0.2996, 0.8053],
#         [0.3061, 0.8210],
#         [0.3058, 0.8203],
#         [0.2948, 0.7939],
#         [0.2927, 0.7891],
#         [0.2990, 0.8040]], grad_fn=<MmBackward0>)
# Improving SelfAttention_v1 using PyTorch's nn.Linear: with bias disabled, a
# Linear layer is just a matrix multiplication.
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values  # context vectors
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))
# Output:
# tensor([[-0.0739, 0.0713],
#         [-0.0748, 0.0703],
#         [-0.0749, 0.0702],
#         [-0.0760, 0.0685],
#         [-0.0763, 0.0679],
#         [-0.0754, 0.0693]], grad_fn=<MmBackward0>)
'''
Note that SelfAttention_v1 and SelfAttention_v2 give different outputs
because they use different initial weights
for the weight matrices since nn.Linear uses a more sophisticated
weight initialization scheme.
'''
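'''
To confirm that both classes implement the same computation, one can (as a
sketch) copy the transposed nn.Linear weights of sa_v2 into the SelfAttention_v1
instance; nn.Linear stores its weight with shape (d_out, d_in), hence the .T:
'''
sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)
sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)
sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)
print(torch.allclose(sa_v1(inputs), sa_v2(inputs)))  # True if the math matches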
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)
print(attn_weights)
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
masked_simple = attn_weights*mask_simple
print(masked_simple)
#Renormalizing
'''
It might seem there is slight information leakage since we populate
the matrix completely before zeroing the upper half.
However, the key insight is that when we renormalize the attention
weights after masking,
what we’re essentially doing is recalculating the softmax over a
smaller subset (since masked positions don’t contribute to the softmax
value).
'''
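'''
A sketch of that renormalization: divide each row of the masked weights by its
row sum so every row again adds up to 1 (masked_simple_norm is an illustrative
name):
'''
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)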
'''
Dropout is a technique where randomly selected hidden units are ignored during
training, effectively "dropping" them out.
This method helps prevent overfitting by ensuring that a model does not become
overly reliant on any specific set of hidden layer units.
It's important to emphasize that dropout is only used during training and is
disabled afterward.
Here we will apply the dropout mask after computing the attention weights, as
illustrated in figure 3.22, because it's the more common variant in practice.
'''
# Implementing dropout
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)  # for GPT-like models, a rate of 0.1 or 0.2 is more typical
example = torch.ones(6, 6)
print(dropout(example))
'''
When applying dropout to an attention weight matrix with a rate of
50%, half of the elements in the matrix are randomly set to zero.
To compensate for the reduction in active elements, the values of the
remaining elements in the matrix are scaled up by a factor of 1/0.5 =
2.
This scaling is crucial to maintain the overall balance of the
attention weights,
ensuring that the average influence of the attention mechanism remains
consistent during both the training and inference phases.
'''
torch.manual_seed(123)
print(dropout(attn_weights))
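'''
The code below operates on batched inputs. As a sketch, we simulate a batch of
two identical texts by stacking the inputs tensor along a new batch dimension
(this duplication is also why the resulting context vectors per text will be
identical later on):
'''
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)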
# Output: torch.Size([2, 6, 3])
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)  # dot products per batch element
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values  # context vectors
'''
While all added code lines should be familiar from previous sections,
we now added a self.register_buffer() call in the __init__ method.
The use of register_buffer in PyTorch is not strictly necessary for
all use cases but offers several advantages here.
For instance, when we use the CausalAttention class in our LLM,
buffers are automatically moved to the appropriate device (CPU or GPU)
along with our model,
which will be relevant when training the LLM in future chapters.
This means we don’t need to manually ensure these tensors are on the
same device as your model parameters, avoiding device mismatch errors.
'''
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
'''
The term “multi-head” refers to dividing the attention mechanism into
multiple “heads,” each operating independently.
In this context, a single causal attention module can be considered
single-head attention, where there is only one set of attention
weights processing the input sequentially.
'''
# First we will implement multi-head attention by stacking multiple causal
# attention modules, and then implement the same multi-head attention in a more
# involved but more efficient way.
'''
In practical terms, implementing multi-head attention involves
creating multiple instances of the self-attention mechanism,
each with its own weights, and then combining their outputs. Using
multiple instances of the self-attention mechanism can be
computationally intensive,
but it’s crucial for the kind of complex pattern recognition that
models like transformer-based LLMs are known for.
'''
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        # run each head independently and concatenate their outputs along the last dim
        return torch.cat([head(x) for head in self.heads], dim=-1)
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
'''
The first dimension of the resulting context_vecs tensor is 2 since we
have two input texts
(the input texts are duplicated, which is why the context vectors are
exactly the same for those).
The second dimension refers to the 6 tokens in each input. The third
dimension refers to the four-dimensional embedding of each token.
'''
'''
In this section, we implemented a MultiHeadAttentionWrapper that
combined multiple single-head attention modules.
However, note that these are processed sequentially via [head(x) for
head in self.heads] in the forward method. We can improve this
implementation by processing the heads in parallel.
One way to achieve this is by computing the outputs for all attention
heads simultaneously via matrix multiplication, as we will explore in
the next section.
'''
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # combines the head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        # split the last dimension into (num_heads, head_dim) ...
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        # ... and move num_heads before num_tokens: (b, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        attn_scores = queries @ keys.transpose(2, 3)  # dot products per head
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # weighted sum of the values, then move num_tokens back before num_heads
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )
        context_vec = self.out_proj(context_vec)
        return context_vec
'''
The key operation is to split the d_out dimension into num_heads and
head_dim, where head_dim = d_out / num_heads.
This splitting is then achieved using the .view method: a tensor of
dimensions (b, num_tokens, d_out) is reshaped to dimension (b,
num_tokens, num_heads, head_dim).
The tensors are then transposed to bring the num_heads dimension
before the num_tokens dimension, resulting in a shape of (b,
num_heads, num_tokens, head_dim).
This transposition is crucial for correctly aligning the queries,
keys, and values across the different heads and performing batched
matrix multiplications efficiently.
'''
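'''
A small shape-only sketch of that split (the demo values b=2, num_tokens=6,
d_out=4, num_heads=2 are illustrative):
'''
b, num_tokens, d_out_demo, num_heads_demo = 2, 6, 4, 2
x_demo = torch.rand(b, num_tokens, d_out_demo)
x_split = x_demo.view(b, num_tokens, num_heads_demo, d_out_demo // num_heads_demo)
x_heads = x_split.transpose(1, 2)
print(x_heads.shape)  # torch.Size([2, 2, 6, 2])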
'''
In this case, the matrix multiplication implementation in PyTorch
handles the four-dimensional input tensor so that
the matrix multiplication is carried out between the two last
dimensions (num_tokens, head_dim) and then repeated for the individual
heads.
For instance, queries @ keys.transpose(2, 3) is a more compact way to compute
the matrix multiplication for each head separately, as the sketch below
illustrates:
'''
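'''
A toy example (the tensor a and its shape are illustrative): batched matrix
multiplication over the last two dimensions gives the same result as multiplying
each head's (num_tokens, head_dim) matrix with its transpose individually.
'''
a = torch.rand(1, 2, 3, 4)            # shape: (b, num_heads, num_tokens, head_dim)
print((a @ a.transpose(2, 3)).shape)  # torch.Size([1, 2, 3, 3]), one score matrix per head
first_head = a[0, 0, :, :]
print(torch.allclose(first_head @ first_head.T, (a @ a.transpose(2, 3))[0, 0]))  # True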
'''
The attention weights are used to compute a weighted sum of the
values, resulting in the context vectors.
The transpose(1, 2) operation reorders the dimensions back to (b,
num_tokens, num_heads, head_dim).
'''
'''
Why transpose instead of stacking causal attention heads?
Splitting a single large projection with .view and .transpose lets all heads be
computed in one batched matrix multiplication, whereas the wrapper runs one
CausalAttention module per head sequentially, which is less efficient.
'''
######################################################################
####################### End of Chapter 3 ##############################
######################################################################