Transformer Attention Mechanism in NLP
Last Updated: 23 Jul, 2025
The Transformer model is a type of neural network architecture designed to handle sequential data, primarily for tasks such as language translation and text generation. Unlike traditional recurrent neural networks (RNNs) or convolutional neural networks (CNNs), Transformers use an attention mechanism to capture relationships between all words in a sentence, regardless of their distance from each other.
The attention mechanism is a technique that allows models to focus on specific parts of the input sequence when producing each element of the output sequence. It assigns different weights to different input elements, enabling the model to prioritize certain information over others. This is particularly useful in tasks like language translation, where the meaning of a word often depends on its context. In this article we will look at the different types of attention used in the Transformer.
1. Scaled Dot-Product Attention
The Scaled Dot-Product Attention is the fundamental building block of the Transformer's attention mechanism. It involves three main components: queries (Q), keys (K) and values (V). The attention score is computed as the dot product of the query and key vectors, scaled by the square root of the dimension of the key vectors. This score is then passed through a softmax function to obtain the attention weights which are used to compute a weighted sum of the value vectors.
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V
where d_k is the dimension of the key vectors.
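To make the formula concrete, here is a minimal NumPy sketch of the computation above (the sequence length and d_k used here are illustrative assumptions, separate from the TensorFlow implementation later in this article).
Python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # (seq_len, seq_len) similarity scores
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability for softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax -> attention weights
    return weights @ V                               # weighted sum of the value vectors

Q = np.random.rand(3, 4)   # 3 query vectors, d_k = 4
K = np.random.rand(3, 4)
V = np.random.rand(3, 4)
print(scaled_dot_product_attention(Q, K, V).shape)   # (3, 4)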
2. Multi-Head Attention
Multi-Head Attention enhances the model's ability to focus on different parts of the input sequence simultaneously. It involves multiple attention heads each with its own set of query, key and value matrices. The outputs of these heads are concatenated and linearly transformed to produce the final output. This allows the model to capture different features and dependencies in the input sequence.
Formula:
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) W^O
where each \text{head}_i = \text{Attention}\big(Q W_i^{Q}, \; K W_i^{K}, \; V W_i^{V}\big), the W_i matrices are the learned per-head projection matrices and W^O is the output projection matrix.
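To see why concatenating the heads recovers the original model dimension, here is a rough NumPy shape walkthrough (the sizes d_model = 512 and h = 8 are assumptions chosen to match the implementation later in this article).
Python
import numpy as np

batch, seq_len, d_model, h = 2, 5, 512, 8
depth = d_model // h                        # each head works in a 64-dimensional subspace

x = np.random.rand(batch, seq_len, d_model)
# Split: (batch, seq_len, d_model) -> (batch, h, seq_len, depth)
heads = x.reshape(batch, seq_len, h, depth).transpose(0, 2, 1, 3)
print(heads.shape)                          # (2, 8, 5, 64)

# Concat: transpose back and merge the head and depth axes again
concat = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
print(np.allclose(concat, x))               # True -- no information is lost by the split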
3. Self-Attention
Self-Attention, also known as intra-attention, allows the model to attend to other positions of the same sequence when computing the representation of a word. In the Transformer, self-attention is applied in both the encoder and decoder layers. It enables the model to capture long-range dependencies and relationships within the input sequence.
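As a minimal sketch (assuming the imports from Step 1 and the MultiHeadAttention layer defined in Step 3 later in this article), self-attention simply means passing the same tensor as queries, keys and values.
Python
# x stands in for one batch of token embeddings: (batch, seq_len, d_model)
x = tf.random.uniform((64, 50, 512))
self_attn = MultiHeadAttention(d_model=512, num_heads=8)

# Queries, keys and values all come from the same sequence
out, weights = self_attn(x, x, x, None)
print(out.shape)   # (64, 50, 512)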
4. Encoder-Decoder Attention
Encoder-Decoder Attention, also known as cross-attention, is used in the decoder layers of the Transformer. It allows the decoder to focus on relevant parts of the input sequence (encoded by the encoder) when generating each word of the output sequence. This type of attention ensures that the decoder has access to the entire input sequence, helping it produce more accurate and contextually appropriate translations.
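A minimal sketch of cross-attention, again assuming the MultiHeadAttention layer defined in Step 3: the queries come from the decoder states, while the keys and values come from the encoder output (the batch size and sequence lengths here are illustrative assumptions).
Python
enc_output = tf.random.uniform((64, 60, 512))   # encoder output, source length 60
dec_hidden = tf.random.uniform((64, 50, 512))   # decoder states, target length 50

cross_attn = MultiHeadAttention(d_model=512, num_heads=8)
# v and k come from the encoder, q comes from the decoder
out, weights = cross_attn(enc_output, enc_output, dec_hidden, None)
print(out.shape)   # (64, 50, 512) -- one context vector per target position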
5. Causal or Masked Self-Attention
Causal or Masked Self-Attention is used in the decoder to ensure that the prediction for a given position only depends on the known outputs at positions before it. This is crucial for tasks like language modeling where future tokens should not be visible during training. The attention scores for future tokens are masked out, ensuring that the model cannot look ahead.
Formula:
\text{MaskedAttention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T + M}{\sqrt{d_k}}\right) V
where M is the mask matrix containing -\infty (in practice, a large negative value such as -10^9) at positions that should be masked.
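One common way to build such a look-ahead mask in TensorFlow is sketched below (assuming the imports from Step 1; note that the ScaledDotProductAttention implemented later multiplies the mask by -1e9 instead of adding a literal -\infty, so a 1 marks a position to hide).
Python
def create_look_ahead_mask(size):
    # Upper-triangular matrix of 1s above the diagonal:
    # position i is not allowed to attend to any position j > i
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

print(create_look_ahead_mask(4))
# tf.Tensor(
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]], shape=(4, 4), dtype=float32)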
The attention mechanism in Transformers offers several advantages:
- Parallel Processing: Unlike RNNs, Transformers can process all words in a sequence simultaneously, significantly reducing training time.
- Long-Range Dependencies: The attention mechanism can capture relationships between distant words, addressing the limitations of traditional models that struggle with long-range dependencies.
- Scalability: Transformers can handle larger datasets and complex tasks due to their scalable architecture.
Step 1: Import Necessary Libraries
First import TensorFlow and other required libraries.
Python
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense, Dropout, LayerNormalization, Embedding
import numpy as np
Step 2: Scaled Dot-Product Attention
The Scaled Dot-Product Attention layer is the foundational building block. It computes the attention scores as the dot product of the query (Q) and key (K) vectors, scales them by the square root of the dimension of the key vectors and applies a softmax function to calculate the attention weights.
- tf.matmul(q, k, transpose_b=True): Computes dot product between queries and transposed keys.
- tf.math.sqrt(dk): Calculates the square root of key dimension for scaling.
- scaled_attention_logits += (mask * -1e9): Applies mask by adding large negative values to ignored positions.
- tf.nn.softmax(..., axis=-1): Converts scores into attention probabilities.
- tf.matmul(attention_weights, v): Applies attention weights to values to get the final output.
Python
class ScaledDotProductAttention(Layer):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def call(self, q, k, v, mask=None):
        # Similarity scores between every query and every key
        matmul_qk = tf.matmul(q, k, transpose_b=True)

        # Scale by the square root of the key dimension
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # Mask out disallowed positions with a large negative value
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)

        # Softmax over the key axis gives the attention weights
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

        # Weighted sum of the value vectors
        output = tf.matmul(attention_weights, v)
        return output, attention_weights
Step 3: Multi-Head Attention
Define the MultiHeadAttention class. This class uses multiple attention heads to focus on different parts of the sequence simultaneously.
- Dense(d_model) layers (self.wq, self.wk, self.wv): Learnable linear projections to transform inputs into queries (Q), keys (K) and values (V).
- split_heads(x, batch_size): Reshapes and transposes input tensor to separate heads for parallel attention computation.
- ScaledDotProductAttention()(q, k, v, mask): Calculates attention scores and applies masking for each head independently.
- tf.transpose and tf.reshape: Rearranges and concatenates multi-head outputs back into a single tensor.
- self.dense(concat_attention): Final linear layer to combine multi-head attention outputs into original d_model dimensions.
Python
class MultiHeadAttention(Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        # d_model must divide evenly across the heads
        assert d_model % num_heads == 0
        self.depth = d_model // num_heads

        # Linear projections for queries, keys and values
        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)

        # Final linear layer applied after concatenating the heads
        self.dense = Dense(d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        # Project inputs to queries, keys and values
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        # Split into heads for parallel attention computation
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # Apply scaled dot-product attention on each head
        scaled_attention, attention_weights = ScaledDotProductAttention()(q, k, v, mask)

        # Concatenate the heads back into a single (batch, seq_len, d_model) tensor
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))

        # Final linear projection
        output = self.dense(concat_attention)
        return output, attention_weights
Step 4: Testing the Attention Layers
To ensure that our Scaled Dot-Product Attention and Multi-Head Attention work correctly, let's test them with some random inputs:
- tf.random.uniform(...): Generates random input tensors for queries (q), keys (k) and values (v).
- ScaledDotProductAttention()(q, k, v, mask): Computes scaled dot-product attention output and attention weights.
- MultiHeadAttention(d_model, num_heads): Initializes multi-head attention layer with given model dimension and number of heads.
- multi_head_attention(v, k, q, mask): Applies multi-head attention, splitting inputs into heads, computing attention in parallel and concatenating results.
Python
# Random (batch, seq_len, d_model) tensors standing in for token embeddings
q = tf.random.uniform((64, 50, 512))
k = tf.random.uniform((64, 50, 512))
v = tf.random.uniform((64, 50, 512))
mask = None

# Scaled dot-product attention on the raw tensors
attention_output, attention_weights = ScaledDotProductAttention()(q, k, v, mask)
print("Attention Output Shape:", attention_output.shape)

# Multi-head attention with 8 heads of dimension 512 / 8 = 64 each
multi_head_attention = MultiHeadAttention(d_model=512, num_heads=8)
output, attn_weights = multi_head_attention(v, k, q, mask)
print("Multi-Head Attention Output Shape:", output.shape)
Output:
Attention Output Shape: (64, 50, 512)
Multi-Head Attention Output Shape: (64, 50, 512)
- The Attention Output Shape of (64, 50, 512) indicates that for a batch of 64 sequences, each of length 50, the output has a depth of 512, representing context-aware word embeddings.
- The Multi-Head Attention Output Shape of (64, 50, 512) is the same, but the output is built from multiple attention heads, allowing the model to capture different relationships in the sequence simultaneously and providing a richer representation.
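As a follow-up sketch that reuses the tensors and the multi_head_attention layer from Step 4, applying a look-ahead mask like the one from section 5 turns the same layer into causal (masked) self-attention.
Python
seq_len = 50
# 1s above the diagonal mark the future positions to hide
causal_mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

masked_out, masked_weights = multi_head_attention(v, k, q, causal_mask)
print("Causal Self-Attention Output Shape:", masked_out.shape)   # (64, 50, 512)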
The Transformer's attention mechanism is a key innovation that allows it to outperform traditional models on many NLP tasks. By using different types of attention, such as Scaled Dot-Product, Multi-Head, Self-Attention, Encoder-Decoder and Causal Attention, the model can efficiently capture complex relationships between words in a sequence.