
PyTorch Quantization


Quantization is a core method for deploying large neural networks such as Llama 2 efficiently on constrained hardware, especially embedded systems and edge devices. The aim is to reduce computational and memory costs by converting high-precision floating-point representations (like float32) into lower-precision integer types (such as int8). This process significantly reduces inference time and energy usage, often with negligible impact on model accuracy, making it possible to deploy even billion-parameter models on devices where floating point is not natively supported.

Why Quantization?

  • Model Size Reduction: Quantization compresses neural network weights/activations from 32-bit floats to 8-bit integers, reducing storage and memory requirements by up to 4x.
  • Speed: Int8 matrix multiplications are much faster on most hardware (especially CPUs and embedded accelerators), accelerating inference significantly.
  • Hardware Support: Embedded systems (e.g., microcontrollers, NPUs) often lack native float math support, necessitating integer-only arithmetic.
  • Energy Efficiency: Less computation and lower memory bandwidth translate into lower energy draw, which is critical for mobile and IoT devices.

How Quantization Works

1. Representation Mapping

  • Original: Weights W and biases b are stored as 32-bit floats.
  • Quantization: These values are mapped into lower-precision integer representations (usually int8 for weights, int32 for biases).
  • Dequantization: Before results are fed to the next layer (which often still expects float values), they are mapped back to floating point using the scale and zero-point parameters.

Example (as in Llama 2 7B or similar large models):

y = x W + b

Where:

  • W: Quantized to 8-bit integer (int8)
  • b: Quantized to 32-bit integer (int32, for accumulator width)
  • Computation is performed in lower precision, then dequantized for subsequent operations.

2. Formal Quantization Equation

Forward Quantization

q_x = \mathrm{round}\left(\frac{x}{\text{scale}}\right) + \text{zero\_point}

Dequantization

x \approx \text{scale} \times (q_x - \text{zero\_point})

  • scale: The step size that relates consecutive integer values to their floating-point counterparts.
  • zero_point: The integer value that float 0.0 maps to (possibly nonzero), aligning the integer range with the layer's value distribution.
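
To make the mapping concrete, here is a minimal sketch of the two formulas in PyTorch (the scale and zero-point values below are arbitrary illustrative choices, not ones produced by a real observer):

Python
import torch

x = torch.tensor([-1.0, -0.2, 0.0, 0.5, 1.0])
scale, zero_point = 0.0079, 0   # illustrative values; in practice they come from calibration

# Forward quantization: q_x = round(x / scale) + zero_point, clamped to the int8 range
q_x = torch.clamp(torch.round(x / scale) + zero_point, -128, 127).to(torch.int8)

# Dequantization: x ≈ scale * (q_x - zero_point)
x_hat = scale * (q_x.to(torch.float32) - zero_point)
print(q_x)     # int8 codes, e.g. [-127, -25, 0, 63, 127]
print(x_hat)   # close to x, up to rounding error

# PyTorch's built-in per-tensor quantization applies the same affine mapping
q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
print(q.int_repr(), q.dequantize())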

Types of Quantization

  • Dynamic: Weights are quantized post-training; activations are quantized on the fly during inference. Fast, with minimal code changes.
  • Static (PTQ): Requires calibration data; both weights and activations are quantized ahead of inference for the best efficiency.
  • QAT (Quantization-Aware Training): Simulates quantization noise during training for the highest post-quantization accuracy.
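
As a quick illustration of the dynamic method, a minimal sketch on a throwaway model (not the MNIST model defined later) shows that Linear weights can be quantized with a single call:

Python
import torch
import torch.nn as nn

# Toy float model; only its Linear layers are dynamically quantized
float_model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

# Weights become int8 ahead of time; activations are quantized on the fly at inference
dq_model = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)
print(dq_model)   # Linear layers are replaced by DynamicQuantizedLinear modules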

Implementation: PyTorch Workflow (Post-Training Quantization)

Step 1: Data Preparation

  • Loads the MNIST handwritten digits dataset.
  • Converts images to PyTorch tensors.
  • Prepares DataLoaders for batching and iterating through data during training and testing.
Python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor()
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset  = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
test_loader  = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

Step 2: Define the CNN Model (with Quantization Stubs)

Python
import torch
import torch.nn as nn
import torch.quantization

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # For marking where quantization/dequantization happens
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool  = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 14 * 14, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)           # Quantize input to int8
        x = self.relu1(self.conv1(x))
        x = self.pool(self.relu2(self.conv2(x)))
        x = x.reshape(x.size(0), -1)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        x = self.dequant(x)        # Dequantize output back to float32
        return x

model_fp32 = SimpleCNN()

Why QuantStub/DeQuantStub: They mark the input and output boundaries for quantization and dequantization, so PyTorch knows where quantized operations begin and end in the network.

Step 3: (Optional) Quick Training

  • The model trains for a couple of epochs; even for a quantization demo, reasonably trained weights are needed.
  • The code will work even if you skip training (the quantization part is independent), but accuracy will be poor.
Python
import torch.optim as optim
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_fp32.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)

print("Training (just a few epochs for demo)...")
for epoch in range(2):  # Just 2 epochs for time; increase for better accuracy!
    model_fp32.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_fp32(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

Step 4: Fuse Layers

Python
def fuse_model(model):
    # Fuse Conv+ReLU and Linear+ReLU pairs into single modules before quantization
    torch.quantization.fuse_modules(model,
        [['conv1', 'relu1'], ['conv2', 'relu2'], ['fc1', 'relu3']], inplace=True)

model_fp32.eval()   # switch to eval mode before fusing/preparing (recommended for post-training quantization)
fuse_model(model_fp32)

Rationale:

  • Fusing Conv + ReLU (or Conv+BN+ReLU) is important as it combines them into one operation, improving both the accuracy and speed of quantized models.
  • Fusion is performed before the model is prepared for quantization.
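
An optional sanity check confirms the fusion took effect (exact class paths vary across PyTorch versions):

Python
# After fusion, conv1 holds a single fused Conv+ReLU module and relu1 has been
# replaced by an Identity placeholder.
print(type(model_fp32.conv1))   # e.g. torch.nn.intrinsic.ConvReLU2d
print(type(model_fp32.relu1))   # torch.nn.Identity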

Step 5: Set Quantization Configuration

Python
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

What is qconfig?

  • It tells PyTorch how to observe and quantize the model.
  • 'fbgemm' is preferred for x86 CPUs. For ARM CPUs, use 'qnnpack'.
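
A small sketch of keeping the quantized engine consistent with the chosen qconfig (the backend string is the only thing you would change when targeting ARM):

Python
backend = 'fbgemm'   # use 'qnnpack' when targeting ARM CPUs
torch.backends.quantized.engine = backend
model_fp32.qconfig = torch.quantization.get_default_qconfig(backend)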

Step 6: Prepare for Quantization

prepare() inserts observer modules that record the ranges of activations and weights during calibration.

Python
model_fp32.cpu()  # static quantization with fbgemm/qnnpack runs on CPU
torch.quantization.prepare(model_fp32, inplace=True)
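
Optionally, print a prepared module to confirm observers were inserted (the exact observer class depends on the qconfig and PyTorch version):

Python
# The fused conv now carries an activation_post_process observer that will
# record value ranges during calibration.
print(model_fp32.conv1)   # e.g. ConvReLU2d(... (activation_post_process): HistogramObserver(...))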

Step 7: Calibration

  • This step feeds real data through the model.
  • Observers collect min/max values to determine how to map floats to int8 (scale/zero-point).
Python
print("Calibrating...")
model_fp32.eval()
with torch.no_grad():
    for images, _ in train_loader:
        model_fp32(images)
        break  # demo only: a single calibration batch; use more representative data in practice (see below)
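
For anything beyond a demo, calibrate over a representative sample rather than a single batch; a minimal sketch (num_calibration_batches is an arbitrary choice for illustration):

Python
num_calibration_batches = 32   # illustrative; use enough batches to cover the data distribution
model_fp32.eval()
with torch.no_grad():
    for i, (images, _) in enumerate(train_loader):
        model_fp32(images)
        if i + 1 >= num_calibration_batches:
            break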

Step 8: Convert to Quantized Model

  • All eligible layers (Conv, Linear, etc.) are replaced with quantized (int8) modules.
  • Model is now ready for fast, int8 inference on CPU.
Python
quantized_model = torch.quantization.convert(model_fp32, inplace=False)
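
Printing a converted layer shows the scale and zero-point it will use for (de)quantization (the exact repr varies by version):

Python
# conv1 is now a quantized fused Conv+ReLU carrying its own scale and zero_point
print(quantized_model.conv1)
# e.g. QuantizedConvReLU2d(1, 16, kernel_size=(3, 3), stride=(1, 1), scale=..., zero_point=..., padding=(1, 1))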

Step 9: Evaluate Accuracy

  • Runs the test data through each model.
  • Compares float32 and quantized int8 accuracy; the accuracy loss is usually small (under 1%).
Python
def evaluate(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy:', 100.0 * correct / total, '%')
    return correct / total

print("\nEvaluating original (float32) model:")
evaluate(model_fp32, test_loader)  # still float32: convert() was called with inplace=False, so model_fp32 only carries observers from prepare()
print("\nEvaluating quantized model:")
evaluate(quantized_model, test_loader)
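
To see the speed benefit claimed earlier, a rough CPU latency comparison can be added (a minimal sketch; absolute timings depend heavily on hardware and backend, and model_fp32 still carries observers, so treat the numbers as indicative only):

Python
import time

def benchmark(model, loader, n_batches=20):
    # Rough wall-clock timing over a few test batches on CPU
    model.eval()
    start = time.perf_counter()
    with torch.no_grad():
        for i, (images, _) in enumerate(loader):
            model(images)
            if i + 1 >= n_batches:
                break
    return time.perf_counter() - start

print(f"float32 inference time: {benchmark(model_fp32, test_loader):.3f} s")
print(f"int8 inference time:    {benchmark(quantized_model, test_loader):.3f} s")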

Step 10: Check Model File Sizes

  • The quantized model file is roughly one quarter the size of the float32 file.
  • The saved file sizes confirm the storage savings expected from int8 quantization.
Python
import os
torch.save(model_fp32.state_dict(), "float_model.pth")
torch.save(quantized_model.state_dict(), "quantized_model.pth")
float_size = os.path.getsize("float_model.pth") / 1024
quant_size = os.path.getsize("quantized_model.pth") / 1024
print(f"\nModel size (float32): {float_size:.1f} KB")
print(f"Model size (quantized): {quant_size:.1f} KB")

Output: the printed file sizes show the quantized model file at roughly one quarter the size of the float32 model.

Google Colab link: Pytorch Quantisation

Key Technical Points

  • Bias Quantization: Biases typically use int32 to prevent overflow during accumulation, since they are added to sums of many int8 products (see the sketch after this list).
  • Operations on Embedded Devices: Most embedded AI accelerators strictly support integer arithmetic, making quantization the de facto deployment path.
  • Dequantization: At each layer’s output, results are dequantized (float recovered) if further float computation is needed, e.g., for softmax or loss calculation.
  • Precision and Loss: Well-calibrated quantization (especially QAT or with asymmetric scaling and dynamic range selection) can keep final accuracy very close to original float network performance.
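
A small sketch of why accumulators and biases need 32 bits (pure arithmetic, independent of any particular model):

Python
import numpy as np

# A dot product of int8 values quickly exceeds the int8 range [-128, 127],
# so products are accumulated in int32 and biases are stored as int32.
a = np.full(512, 127, dtype=np.int8)
b = np.full(512, 127, dtype=np.int8)
acc = np.dot(a.astype(np.int32), b.astype(np.int32))
print(acc)   # 512 * 127 * 127 = 8,258,048; far beyond int8/int16, comfortably within int32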

Real-World Example: Llama 2 Quantization

  • Llama 2 7B, originally a multi-billion parameter float model, is quantized down (often to 8-bit int for weights and 32-bit int for biases and accumulators), enabling deployment on hardware-constrained servers, mobile or edge devices with minimal accuracy drop.
  • Quantization reduces model size drastically and speeds up inference: y = xW + b (with W and b quantized) becomes a pure int8/int32 computation, and a single per-layer scale factor (plus zero-point) is all that is needed to dequantize the result for subsequent operations.
