Text Generation using FNet
Last Updated: 31 Jul, 2025
Text generation in natural language processing (NLP) has improved significantly with Transformer-based models like GPT and BERT. These models rely on self-attention to capture how the words in a sequence relate to each other, but self-attention becomes slow and costly on long sequences because its cost grows quadratically with sequence length. FNet addresses this by replacing the self-attention layer with a Fourier Transform. This mixing operation needs no extra parameters, making the model faster and lighter while still providing good results.
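For intuition, the core mixing step that FNet substitutes for self-attention can be sketched in a few lines of PyTorch. This is only a toy illustration on random tensors (not the model built later in this article): a 2D Fourier Transform is applied across the sequence and hidden dimensions and only the real part is kept, which is how FNet mixes information across tokens without any learned attention weights.
Python
import torch

# Toy sketch of FNet-style token mixing: 2D FFT over the last two dims, keep the real part
x = torch.randn(2, 16, 64)        # (batch, seq_len, hidden) random "embeddings"
mixed = torch.fft.fft2(x).real    # parameter-free mixing across tokens and features
print(mixed.shape)                # torch.Size([2, 16, 64]) - same shape as the input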
Why FNet is Effective for Text Generation
- Reduced Complexity: Self-attention in traditional Transformer models becomes computationally expensive as input sequences grow. FNet's parameter-free Fourier mixing reduces this complexity without a large sacrifice in performance.
- Improved Efficiency: It can handle longer input sequences more efficiently, making it useful for applications that process large texts such as generating entire articles or scripts.
- Versatility: It is not limited to text generation; it can also be applied to tasks like language translation and text classification, making it a versatile tool in NLP.
Implementing FNet for Text Generation in Python
Let's walk through an implementation of FNet for text generation:
Step 1: Installing and Importing Libraries
We will install the libraries below if they are not already available in our environment:
!pip install datasets
!pip install transformers torch
Here we will be using the PyTorch, NumPy and Pandas libraries, along with Hugging Face transformers and datasets, for the implementation.
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft as fft
import numpy as np
import pandas as pd
import re
Additionally, we define a device variable so that computation runs on the GPU if one is available and falls back to the CPU otherwise.
Python
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
Output:
cuda
Step 2: Loading Data
Here we will load the wikitext-2-raw-v1 version of the WikiText dataset, which contains raw text from Wikipedia articles without any additional processing applied to it. We'll use the datasets library, which makes it easy to access and work with datasets from Hugging Face.
Python
from datasets import load_dataset
datasets = load_dataset('wikitext','wikitext-2-raw-v1')
Step 3: Data Preprocessing
Before feeding the raw text into the model, it's important to clean and preprocess the data. Here we declare a preprocess_text function which will:
- Convert all characters in the sentence to lowercase
- Remove any special characters
- Collapse multiple whitespace characters into a single space
After defining the preprocess_text function, we apply it to each text sample in the dataset using the map function from the datasets library. Additionally, we use the filter function to keep only those samples whose text is longer than 20 characters, discarding short or empty lines.
Python
def preprocess_text(sentence):
    text = sentence['text'].lower()
    text = re.sub(r'[^a-z?!.,]', ' ', text)
    text = re.sub(r'\s\s+', ' ', text)
    sentence['text'] = text
    return sentence

datasets['train'] = datasets['train'].map(preprocess_text)
datasets['test'] = datasets['test'].map(preprocess_text)
datasets['validation'] = datasets['validation'].map(preprocess_text)

datasets['train'] = datasets['train'].filter(lambda x: len(x['text']) > 20)
datasets['test'] = datasets['test'].filter(lambda x: len(x['text']) > 20)
datasets['validation'] = datasets['validation'].filter(lambda x: len(x['text']) > 20)
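As a quick sanity check, you can run preprocess_text on a single made-up sample (the sentence below is purely illustrative and not taken from the dataset) to see the lowercasing, character filtering and whitespace collapsing in action.
Python
sample = {'text': 'The   Quick Brown Fox (2021) -- jumps over 13 lazy dogs!'}
print(preprocess_text(sample)['text'])
# roughly: 'the quick brown fox jumps over lazy dogs!'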
Step 4: Tokenization
For tokenization, we use a pretrained tokenizer from Hugging Face. The distilbert-base-uncased-finetuned-sst-2-english tokenizer is loaded using AutoTokenizer.from_pretrained(). This converts raw text into tokenized sequences suitable for model training.
- Tokenization Function: Define a function to tokenize each sentence.
- Apply Tokenizer: Use the map() function to apply the tokenizer across the dataset (the code below tokenizes the smaller test split to keep training time manageable).
- Remove Original Text: Remove the original text column using remove_columns() to retain only tokenized inputs.
- Padding: Ensure consistent input lengths across batches with DataCollatorWithPadding.
Python
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize(sentence):
    sentence = tokenizer(sentence['text'], truncation=True)
    return sentence

tokenized_inputs = datasets['test'].map(tokenize)
tokenized_inputs = tokenized_inputs.remove_columns(['text'])

batch = 16
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, padding=True, return_tensors="pt")
dataloader = DataLoader(
    tokenized_inputs, batch_size=batch, collate_fn=data_collator)
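Before moving on, it can help to pull a single batch from the dataloader and confirm what the collator produces: input_ids and attention_mask padded to the longest sequence in the batch. The exact shapes below will vary with your data and batch.
Python
batch_example = next(iter(dataloader))
print(batch_example.keys())                   # expect 'input_ids' and 'attention_mask'
print(batch_example['input_ids'].shape)       # (batch_size, longest_sequence_in_batch)
print(batch_example['attention_mask'].shape)  # same shape; 1 marks real tokens, 0 marks padding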
Step 5: Embedding and Positional Encoding
Here we create two classes, one for positional encoding and one for embedding.
- Positional Encoding: Generate positional encodings to provide the model with information about token positions.
- Embedding: The PositionalEmbedding class takes tokenized inputs, embeds them and adds the positional encoding to capture sequential information effectively.
Python
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.d_model = d_model
        self.max_sequence_length = max_sequence_length
        self.positional_encoding = self.create_positional_encoding().to(device)

    def create_positional_encoding(self):
        positional_encoding = np.zeros((self.max_sequence_length, self.d_model))
        for pos in range(self.max_sequence_length):
            for i in range(0, self.d_model, 2):
                positional_encoding[pos, i] = np.sin(pos / (10000 ** ((2 * i) / self.d_model)))
                if i + 1 < self.d_model:
                    positional_encoding[pos, i + 1] = np.cos(pos / (10000 ** ((2 * i) / self.d_model)))
        return torch.from_numpy(positional_encoding).float()

    def forward(self, x):
        expanded_tensor = torch.unsqueeze(self.positional_encoding, 0).expand(x.size(0), -1, -1).to(device)
        return x.to(device) + expanded_tensor[:, :x.size(1), :]


class PositionalEmbedding(nn.Module):
    def __init__(self, sequence_length, vocab_size, embed_dim):
        super(PositionalEmbedding, self).__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.position_embeddings = PositionalEncoding(embed_dim, sequence_length)

    def forward(self, inputs):
        embedded_tokens = self.token_embeddings(inputs).to(device)
        embedded_positions = self.position_embeddings(embedded_tokens).to(device)
        return embedded_positions.to(device)
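A short check with a random batch of token ids (the sizes here are arbitrary toy values, with 30522 standing in for the tokenizer's vocabulary size) confirms that PositionalEmbedding maps integer ids to embedding vectors of the expected shape, with positional information added.
Python
embedding_layer = PositionalEmbedding(sequence_length=512, vocab_size=30522, embed_dim=256).to(device)
dummy_ids = torch.randint(0, 30522, (2, 20)).to(device)  # (batch, seq_len) of random token ids
print(embedding_layer(dummy_ids).shape)                  # torch.Size([2, 20, 256])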
Step 6: Create FNet Encoder
The FNet Encoder is designed based on the FNet architecture, using Fourier Transforms to process the input sequence.
- Fourier Transform: Applies fft.fft2 to the input; the real part of the result is added back to the original input through a residual connection.
- Normalization: After applying Fourier Transform, layer normalization (self.layernorm_1) is used.
- Dense Projection: Two linear layers with ReLU activation (self.dense_proj) project the input into a different dimension.
- Final Normalization: A second layer normalization (self.layernorm_2) is applied to the output.
Python
class FNetEncoder(nn.Module):
    def __init__(self, embed_dim, dense_dim):
        super(FNetEncoder, self).__init__()
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.dense_proj = nn.Sequential(
            nn.Linear(self.embed_dim, self.dense_dim),
            nn.ReLU(),
            nn.Linear(self.dense_dim, self.embed_dim))
        self.layernorm_1 = nn.LayerNorm(self.embed_dim)
        self.layernorm_2 = nn.LayerNorm(self.embed_dim)

    def forward(self, inputs):
        # Parameter-free token mixing: 2D FFT over sequence and hidden dims, keep the real part
        fft_result = fft.fft2(inputs)
        fft_real = fft_result.real.float()
        proj_input = self.layernorm_1(inputs + fft_real)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)
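As a quick sanity check with arbitrary toy dimensions, the encoder block leaves the input shape unchanged; the Fourier mixing and feed-forward projection only transform the values.
Python
encoder_block = FNetEncoder(embed_dim=256, dense_dim=100).to(device)
dummy_embeddings = torch.randn(2, 20, 256).to(device)  # (batch, seq_len, embed_dim)
print(encoder_block(dummy_embeddings).shape)           # torch.Size([2, 20, 256])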
Step 7: Create FNetDecoder
The FNet Decoder is designed based on the FNet architecture and includes multi-head attention mechanisms to process the input sequence.
- Multi-Head Attention: self.attention_1 attends to the decoder inputs with a causal mask so the model cannot see future tokens, while self.attention_2 attends to the encoder outputs with an optional key padding mask.
- Normalization: Layer normalization is applied after each attention mechanism to stabilize intermediate representations.
- Dense Projection: Two linear layers with ReLU activation (self.dense_proj) project the output to a different dimension.
- Final Normalization: A second layer normalization (self.layernorm_3) is applied to the final output.
Python
class FNetDecoder(nn.Module):
    def __init__(self, embed_dim, dense_dim, num_heads):
        super(FNetDecoder, self).__init__()
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention_1 = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.attention_2 = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.dense_proj = nn.Sequential(
            nn.Linear(embed_dim, dense_dim),
            nn.ReLU(),
            nn.Linear(dense_dim, embed_dim))
        self.layernorm_1 = nn.LayerNorm(embed_dim)
        self.layernorm_2 = nn.LayerNorm(embed_dim)
        self.layernorm_3 = nn.LayerNorm(embed_dim)

    def forward(self, inputs, encoder_outputs, mask=None):
        # Causal self-attention over the decoder inputs
        causal_mask = nn.Transformer.generate_square_subsequent_mask(inputs.size(1)).to(device)
        attention_output_1, _ = self.attention_1(inputs, inputs, inputs, attn_mask=causal_mask)
        out_1 = self.layernorm_1(inputs + attention_output_1)
        # Cross-attention over the encoder outputs; key_padding_mask must be True at padded positions
        if mask is not None:
            attention_output_2, _ = self.attention_2(
                out_1, encoder_outputs, encoder_outputs, key_padding_mask=~mask.to(device))
        else:
            attention_output_2, _ = self.attention_2(out_1, encoder_outputs, encoder_outputs)
        out_2 = self.layernorm_2(out_1 + attention_output_2)
        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)
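The decoder block can be exercised the same way. The boolean mask below marks every encoder position as valid (all True), mirroring the attention_mask format used later during training; all shapes and values here are toy examples.
Python
decoder_block = FNetDecoder(embed_dim=256, dense_dim=100, num_heads=4).to(device)
decoder_embeds = torch.randn(2, 15, 256).to(device)          # (batch, target_len, embed_dim)
encoder_output = torch.randn(2, 20, 256).to(device)          # (batch, source_len, embed_dim)
valid_mask = torch.ones(2, 20, dtype=torch.bool).to(device)  # True for real (non-padded) positions
print(decoder_block(decoder_embeds, encoder_output, valid_mask).shape)  # torch.Size([2, 15, 256])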
Step 8: FNet Model
The FNet Model combines positional encoding, FNet encoder and FNet decoder components.
- Initialization (__init__ method): Initializes model with parameters like embed_dim, latent_dim, num_heads and vocab_size.
- Encoder: Processes encoder_inputs through positional encoding and four FNetEncoder layers sequentially.
- Decoder: Processes decoder_inputs, encoder_output and attention mask through four FNetDecoder layers.
- Forward Pass: Takes encoder_inputs, decoder_inputs and attention mask and passes them through encoder and decoder layers to get the final output.
Python
class FNetModel(nn.Module):
    def __init__(self, max_length, vocab_size, embed_dim, latent_dim, num_heads):
        super(FNetModel, self).__init__()
        self.encoder_inputs = PositionalEmbedding(max_length, vocab_size, embed_dim)
        self.encoder1 = FNetEncoder(embed_dim, latent_dim)
        self.encoder2 = FNetEncoder(embed_dim, latent_dim)
        self.encoder3 = FNetEncoder(embed_dim, latent_dim)
        self.encoder4 = FNetEncoder(embed_dim, latent_dim)
        self.decoder_inputs = PositionalEmbedding(max_length, vocab_size, embed_dim)
        self.decoder1 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.decoder2 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.decoder3 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.decoder4 = FNetDecoder(embed_dim, latent_dim, num_heads)
        self.dropout = nn.Dropout(0.5)
        self.dense = nn.Linear(embed_dim, vocab_size)

    def encoder(self, encoder_inputs):
        x_encoder = self.encoder_inputs(encoder_inputs)
        x_encoder = self.encoder1(x_encoder)
        x_encoder = self.encoder2(x_encoder)
        x_encoder = self.encoder3(x_encoder)
        x_encoder = self.encoder4(x_encoder)
        return x_encoder

    def decoder(self, decoder_inputs, encoder_output, att_mask):
        x_decoder = self.decoder_inputs(decoder_inputs)
        x_decoder = self.decoder1(x_decoder, encoder_output, att_mask)
        x_decoder = self.decoder2(x_decoder, encoder_output, att_mask)
        x_decoder = self.decoder3(x_decoder, encoder_output, att_mask)
        x_decoder = self.decoder4(x_decoder, encoder_output, att_mask)
        decoder_outputs = self.dense(x_decoder)
        return decoder_outputs

    def forward(self, encoder_inputs, decoder_inputs, att_mask=None):
        encoder_output = self.encoder(encoder_inputs)
        # Pass the padding mask through to the decoder's cross-attention
        decoder_output = self.decoder(decoder_inputs, encoder_output, att_mask=att_mask)
        return decoder_output
Step 9: Initialize Model
In this step, we initialize the model by declaring the necessary hyperparameters and passing them to the model class.
- max_length: Maximum sequence length for inputs.
- vocab_size: Size of the vocabulary.
- embed_dim: Embedding dimension for tokens.
- latent_dim: Dimension of the latent space.
- num_heads: Number of attention heads in multi-head attention.
- Model Initialization: Instantiate the FNet model with the defined hyperparameters.
Python
MAX_LENGTH = 512
VOCAB_SIZE = len(tokenizer.vocab)
EMBED_DIM = 256
LATENT_DIM = 100
NUM_HEADS = 4
fnet_model = FNetModel(MAX_LENGTH, VOCAB_SIZE, EMBED_DIM, LATENT_DIM, NUM_HEADS).to(device)
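Before training, it is worth confirming that the model runs end to end and checking its size. The forward pass below uses random token ids purely as a smoke test; the exact parameter count depends on the hyperparameters chosen above.
Python
num_params = sum(p.numel() for p in fnet_model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params:,}")

dummy_src = torch.randint(0, VOCAB_SIZE, (2, 20)).to(device)  # random encoder token ids
dummy_tgt = torch.randint(0, VOCAB_SIZE, (2, 20)).to(device)  # random decoder token ids
print(fnet_model(dummy_src, dummy_tgt).shape)                 # (batch, seq_len, vocab_size)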
Step 10: Train the Model
Here we train the model by defining the optimizer, loss function and iterating through the training data.
- Optimizer: We use the Adam optimizer to update the model's parameters during training; it adapts the step size for each parameter based on estimates of the gradient's first and second moments.
- Loss Function: Cross-entropy loss is used, which suits sequence generation since predicting the next token is effectively a classification problem over the vocabulary.
- Gradient Calculation: Before each step, gradients are zeroed using optimizer.zero_grad().
- Backpropagation: Gradients are calculated using loss.backward() and the optimizer updates the model's weights with optimizer.step().
- Training Loop: The training process is repeated for 100 epochs (as set in the code below), during which the model learns to predict the output sequences more accurately.
Python
optimizer = torch.optim.Adam(fnet_model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=0)

epochs = 100
for epoch in range(epochs):
    train_loss = 0
    for batch in dataloader:
        # Teacher forcing: the decoder target is the input sequence shifted by one token
        encoder_inputs_tensor = batch['input_ids'][:, :-1].to(device)
        decoder_inputs_tensor = batch['input_ids'][:, 1:].to(device)
        att_mask = batch['attention_mask'][:, :-1].to(device).to(dtype=bool)

        optimizer.zero_grad()
        outputs = fnet_model(encoder_inputs_tensor, decoder_inputs_tensor, att_mask)

        # Set padded target positions to 0 so they are skipped via ignore_index=0
        labels = decoder_inputs_tensor.masked_fill(
            batch['attention_mask'][:, 1:].ne(1).to(device), 0)
        loss = criterion(outputs.view(-1, VOCAB_SIZE), labels.reshape(-1))
        train_loss = train_loss + loss.item()
        loss.backward()
        optimizer.step()
    print(f"epoch: {epoch}, train_loss: {train_loss}")
Output:
Training the model (screenshot of the per-epoch training loss output)
Step 11: Use Model for Text Generation
To generate text with the trained model, we use autoregressive decoding: the model produces one token at a time, we pick the next token from its output distribution (greedy argmax in the code below) and feed it back into the decoder input for the next step. The encoder part of the model produces the context vectors for the given input sentence.
Python
MAX_LENGTH = 100

def decode_sentence(input_sentence, fnet_model):
    fnet_model.eval()
    with torch.no_grad():
        tokenized_input_sentence = torch.tensor(
            tokenizer(preprocess_text(input_sentence)['text'])['input_ids']).to(device)
        # Start the target sequence with the [CLS] token (id 101)
        tokenized_target_sentence = torch.tensor([101]).to(device)
        current_text = preprocess_text(input_sentence)['text']
        for i in range(MAX_LENGTH):
            predictions = fnet_model(tokenized_input_sentence[:-1].unsqueeze(0),
                                     tokenized_target_sentence.unsqueeze(0))
            # Greedy decoding: pick the most likely next token
            predicted_index = torch.argmax(predictions[0, -1, :]).item()
            predicted_token = tokenizer.decode(predicted_index)
            if predicted_token == "[SEP]":
                break
            current_text += " " + predicted_token
            tokenized_target_sentence = torch.cat(
                [tokenized_target_sentence, torch.tensor([predicted_index]).to(device)], 0).to(device)
            tokenized_input_sentence = torch.tensor(tokenizer(current_text)['input_ids']).to(device)
    return current_text

decode_sentence({'text': 'How are you ?'}, fnet_model)
Output:
how are you ? ufc ufc imp ufc ufc ufc ufc ufc ufc ufc own hey own own own own ufc
To get better output, the model needs to be trained on a much larger amount of data and for a significantly longer time, which requires substantial GPU resources.
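Greedy argmax decoding also tends to collapse into repetitive tokens like the output above. A common alternative is to sample from a temperature-scaled softmax over the model's logits. The helper below is only a sketch of that idea (the function name and temperature value are arbitrary choices, not part of the original code) and could replace the torch.argmax line inside decode_sentence.
Python
def sample_next_token(logits, temperature=0.8):
    # Scale the logits by the temperature and sample from the resulting distribution;
    # lower temperatures behave closer to greedy argmax, higher ones add diversity.
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# Inside the generation loop, instead of torch.argmax(...).item():
# predicted_index = sample_next_token(predictions[0, -1, :])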
Applications of Text Generation using FNet
- Long-Form Content Generation: FNet's ability to handle long sequences efficiently makes it ideal for generating large amounts of text such as articles, blogs or reports where traditional Transformer models may face performance issues.
- Machine Translation: The efficiency of FNet allows it to handle long text sequences in translation tasks where capturing global context is important. It can be applied to translate long paragraphs or documents effectively.
- Text Summarization: It can be applied to extractive or abstractive summarization, processing long documents and summarizing them into shorter, meaningful content with reduced computational cost.
- Sentiment Analysis: By using the Fourier Transform for efficient sequence processing, it can be applied to analyze sentiment over longer contexts such as reviews or feedback that may span multiple sentences.
- Speech-to-Text: FNet's scalability can be extended to applications in speech recognition and transcription, processing long audio sequences that are converted to text and enabling real-time speech-to-text services.
Challenges of Text Generation using FNet
- Limited Interpretability: Unlike self-attention, where each token’s relevance to the others is explicitly captured, the Fourier Transform lacks clear interpretability, making it harder to understand how the model arrives at its outputs.
- Adaptability to Complex Contexts: While it performs well with long sequences, it may struggle with capturing complex relationships in highly contextual or domain-specific tasks where self-attention excels in modeling local dependencies.
- Loss of Fine-Grained Attention: By replacing self-attention, it may miss out on fine-grained relationships between tokens that would typically be highlighted in traditional attention mechanisms which could impact text generation quality in certain cases.
- Smaller Community Support and Research: As a relatively newer architecture, it lacks the extensive research and community support that Transformers like BERT and GPT have accumulated over time which may limit available resources and practical use cases.