(beta) Dynamic Quantization on an LSTM Word Language Model#

Created On: Oct 07, 2019 | Last Updated: Jun 04, 2025 | Last Verified: Nov 05, 2024

Author: James Reed

Edited by: Seth Weidman

Introduction#

Quantization involves converting the weights and activations of your model from float to int, which can result in smaller model size and faster inference with only a small hit to accuracy.
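
To build intuition before touching the LSTM, here is a tiny, self-contained sketch (not part of this tutorial's pipeline; the scale and zero point are hand-picked for illustration) that quantizes a float tensor to int8 with torch.quantize_per_tensor and then recovers a float approximation. Dynamic quantization applies the same idea to a module's weights automatically.

import torch

# A float tensor and a hand-picked scale/zero_point, for illustration only.
x = torch.tensor([0.1, -0.2, 0.3, 0.4])
xq = torch.quantize_per_tensor(x, scale=0.01, zero_point=0, dtype=torch.qint8)

print(xq.int_repr())   # the underlying int8 values, roughly tensor([10, -20, 30, 40], dtype=torch.int8)
print(xq.dequantize()) # the float values reconstructed from int8 + scale/zero_point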

In this tutorial, we will apply the easiest form of quantization - dynamic quantization - to an LSTM-based next-word-prediction model, closely following the word language model from the PyTorch examples.

# imports
import os
from io import open
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

1. Define the model#

Here we define the LSTM model architecture, following the model from the word language model example.

class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))
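
Before moving on, it can help to sanity-check the tensor shapes the model expects. The following sketch is not part of the tutorial, and the tiny hyperparameters are made up purely for illustration; it runs one forward pass on random token ids.

# Tiny hypothetical configuration, just to inspect shapes.
tiny_model = LSTMModel(ntoken=100, ninp=16, nhid=32, nlayers=2)
tiny_input = torch.randint(100, (35, 8), dtype=torch.long)  # (seq_len, batch) of token ids
tiny_hidden = tiny_model.init_hidden(8)
tiny_output, tiny_hidden = tiny_model(tiny_input, tiny_hidden)
print(tiny_output.shape)  # torch.Size([35, 8, 100]): one score per vocabulary word at each position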

2. Load in the text data#

Next, we load the Wikitext-2 dataset into a Corpus, again following the preprocessing from the word language model example. The raw train.txt, valid.txt, and test.txt files are expected under data/wikitext-2/.

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')
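
As a quick illustration of what the Dictionary does (this toy example is not part of the pipeline), words are assigned consecutive integer ids in the order they are first seen:

toy_dict = Dictionary()
for word in "the cat sat on the mat <eos>".split():
    toy_dict.add_word(word)

print(toy_dict.word2idx)  # {'the': 0, 'cat': 1, 'sat': 2, 'on': 3, 'mat': 4, '<eos>': 5}
print(len(toy_dict))      # 6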

3. Load the pretrained model#

This is a tutorial on dynamic quantization, a quantization technique that is applied after a model has been trained. Therefore, we’ll simply load some pretrained weights into this model architecture; these weights were obtained by training for five epochs using the default settings in the word language model example.

Before running this tutorial, download the required pre-trained model:

wget https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth

Place the downloaded file in the data directory or update the model_data_filepath accordingly.
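
If you prefer to fetch the checkpoint from within Python rather than with wget, a minimal sketch using only the standard library could look like this (it assumes the same URL and the model_data_filepath directory used throughout the tutorial):

import os
import urllib.request

url = 'https://s3.amazonaws.com/pytorch-tutorial-assets/word_language_model_quantize.pth'
dest = os.path.join(model_data_filepath, 'word_language_model_quantize.pth')
if not os.path.exists(dest):
    urllib.request.urlretrieve(url, dest)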

ntokens = len(corpus.dictionary)

model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu'),
        weights_only=True
        )
    )

model.eval()
print(model)
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): Linear(in_features=256, out_features=33278, bias=True)
)

Now let’s generate some text to ensure that the pretrained model is working properly. As before, we follow the text generation script from the word language model example.

input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000

with open(model_data_filepath + 'out.txt', 'w') as outf:
    with torch.no_grad():  # no tracking history
        for i in range(num_words):
            output, hidden = model(input_, hidden)
            word_weights = output.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input_.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            # The encoded bytes are written out directly, which is why the
            # sample output below shows tokens as b'...' strings.
            outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

            if i % 100 == 0:
                print('| Generated {}/{} words'.format(i, num_words))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)
| Generated 0/1000 words
| Generated 100/1000 words
| Generated 200/1000 words
| Generated 300/1000 words
| Generated 400/1000 words
| Generated 500/1000 words
| Generated 600/1000 words
| Generated 700/1000 words
| Generated 800/1000 words
| Generated 900/1000 words
b'Spalato' b'one' b'Saturday' b'At' b'peerage' b'Battle' b'.' b'The' b'air' b'of' b'these' b'Persian' b'artists' b'has' b'the' b'only' b'interests' b'of' b'hosts' b':'
b'"' b'All' b'of' b'it' b'is' b'true' b'"' b',' b'where' b'two' b'giant' b'politicians' b'would' b'be' b'affected' b'against' b'disturbing' b'2003' b'in' b'some'
b'mud' b'.' b'<eos>' b'In' b'1969' b',' b'the' b'zero' b'newly' b'soldier' b'International' b'Blue' b'used' b'Richard' b'an' b'0' b'@.@' b'0' b'km' b'('
b'3' b'@.@' b'2' b'ft' b')' b'pattern' b'wall' b'from' b'Catholic' b'angles' b'of' b'Ceres' b',' b'servicemen' b',' b'K' b'Ireland' b',' b'and' b'stores'
b'.' b'At' b'a' b'estimated' b'larger' b'increase' b'were' b'introduced' b'into' b'rampart' b',' b'one' b'of' b'the' b'latter' b'seats' b'.' b'Each' b'original' b'species'
b'bears' b'in' b'Strategic' b'Island' b'2' b'p.m.' b'purre' b',' b'27' b'\xc2\xb0' b'fledging' b',' b'ranges' b'and' b'Raeburn' b',' b'692' b'visual' b'or' b'tree'
b',' b'so' b'rotation' b'large' b'puzzles' b',' b'a' b'duration' b',' b'and' b'the' b'eleventh' b'<unk>' b'(' b'promoting' b'Latin' b')' b'red' b'and' b'transportation'
b',' b'and' b'done' b'options' b',' b'which' b'also' b'mention' b'frequency' b'from' b'Argyle' b'orbits' b'(' b'product' b'offside' b'568' b')' b'.' b'A' b'sculptures'
b'report' b'probable' b'history' b'in' b'AIDS' b'race' b',' b'including' b'Grade' b'II' b',' b'"' b'soft' b'"' b'@-@' b'American' b',' b'Keys' b'and' b'<unk>'
b'to' b'prevent' b'Eagles' b'production' b'(' b'Llanilltern' b')' b'.' b'In' b'the' b'\xe1\x83\xab' b',' b'Oslo' b',' b'there' b'are' b'1' b'on' b'December' b'29'
b',' b'1849' b',' b'from' b'Summer' b'ocean' b'predator' b'of' b'1995' b'.' b'In' b'common' b'predecessors' b',' b'Hydnellum' b'species' b'of' b'<unk>' b'is' b'banned'
b'.' b'However' b',' b'Ava' b'as' b'form' b'queries' b'of' b'common' b'starling' b',' b'including' b'Mars' b'saying' b'implies' b'consistent' b'scheme' b'and' b'hosts' b'for'
b'<unk>' b'.' b'There' b'are' b'only' b'two' b'main' b'starlings' b'toward' b'birds' b'over' b'the' b'main' b'length' b'of' b'the' b'decade' b'.' b'The' b'last'
b'half' b'mainly' b'to' b'speculated' b'that' b'species' b'of' b'1440' b'related' b'is' b'longer' b'characteristic' b'at' b'metalwork' b'.' b'When' b'in' b'barge' b',' b'they'
b'play' b'1' b'\xe2\x80\x93' b'24' b'caliber' b'(' b'<unk>' b'above' b'21' b'@.@' b'5' b'in' b')' b'to' b'bowl' b'forces' b'.' b'All' b'species' b'of'
b'looking' b'around' b'average' b'feathers' b',' b'60' b'%' b'of' b'950' b'households' b'which' b'are' b'sufficiently' b'postponed' b'from' b'take' b'since' b'the' b'forests' b'area'
b'at' b'cardiac' b'times' b',' b'will' b'be' b'outside' b'to' b'be' b'their' b'enclosed' b'.' b'A' b'long' b'NHC' b'or' b'females' b'may' b'be' b'released'
b'to' b'strike' b'for' b'the' b'Tallest' b'cross' b'.' b'Since' b'they' b'feature' b'all' b'is' b'limited' b',' b'they' b',' b'and' b'there' b'then' b'have'
b'anyway' b'suffered' b'them' b'deep' b',' b'so' b'it' b'will' b'be' b'distinguished' b'by' b'a' b'white' b'matter' b',' b'so' b'will' b'be' b'distinguished' b'as'
b'au' b'Skye' b'.' b'The' b'kakapo' b'Maximum' b'the' b'eggs' b'a' b'maximum' b'species' b'of' b'ranges' b'over' b'the' b'eastern' b'government' b'and' b'linking' b'removing'
b'power' b'Nelson' b'.' b'<eos>' b'The' b'printers' b'of' b'quality' b'has' b'been' b'<unk>' b'inhabited' b'by' b'Ireland' b'.' b'polymerase' b'no' b'woodland' b',' b'causing'
b'a' b'10' b'@.@' b'8' b'92' b'progression' b'as' b'an' b'accumulated' b'belonging' b'to' b'the' b'colour' b'.' b'Both' b'chicks' b'were' b'Koss' b'guru' b'and'
b'<unk>' b'numerous' b'islands' b',' b'not' b'Cove' b'.' b'One' b'regions' b'usually' b'do' b'in' b'the' b'23' b'hours' b'from' b'Abby' b',' b'or' b'2'
b'@.@' b'5' b'million' b'miles' b'(' b'4' b'@.@' b'8' b'in' b')' b'high' b'.' b'\xc2\xb5g' b'damage' b"'" b'cargo' b'comes' b'in' b'southwestern' b'elevation'
b',' b'which' b'marked' b'everything' b'of' b'sea' b'working' b'at' b'07' b'to' b'6' b'p.m.' b'.' b'It' b'usually' b'is' b'<unk>' b'through' b'strength' b'another'
b'cell' b'except' b'it' b'of' b'and' b'are' b'transmitted' b'to' b'distinguish' b'his' b'body' b'.' b'<unk>' b'(' b'red' b'sound' b',' b'John' b'males' b')'
b'is' b'very' b'much' b'common' b'towards' b'span' b'or' b'other' b'@-@' b'frequency' b',' b'and' b'resembles' b'their' b'length' b'<unk>' b'yield' b'of' b'housing' b'more'
b'than' b'$' b'5' b'million' b',' b'something' b'following' b'the' b'sale' b'of' b'Hilton' b'2' b'on' b'Feak' b'.' b'Males' b'1754' b'eggs' b'because' b'of'
b'their' b'ion' b'1500' b'or' b'powerful' b'snakes' b'.' b'protein' b'white' b'efforts' b'annually' b',' b'typically' b'in' b'birds' b',' b'only' b'high' b'amino' b'miners'
b',' b'are' b'"' b'occasionally' b'white' b'than' b'constant' b'cross' b'and' b'can' b'better' b'be' b'obviously' b'traces' b'to' b'rounded' b'at' b'Lisa' b'in' b'London'
b'where' b'they' b'have' b'killed' b'tribute' b'against' b'numbers' b'to' b'drill' b'that' b'is' b'necessary' b'.' b'"' b'They' b'suspended' b'areas' b',' b'with' b'<unk>'
b'reflect' b'January' b'4' b',' b'1910' b'swinging' b',' b'choose' b'as' b'a' b'single' b'source' b'.' b'These' b'over' b'humans' b'that' b'are' b'off' b'producing'
b'city' b'by' b'lasers' b'that' b'have' b'their' b'successful' b'ability' b'for' b'the' b'feeding' b'models' b'(' b'the' b'Insular' b'Eurasian' b'section' b')' b',' b'coupled'
b'for' b'both' b'different' b'<unk>' b',' b'they' b'can' b'take' b'invertebrates' b'.' b'In' b'males' b',' b'gene' b'@-@' b'tailed' b'LEDs' b'are' b'distributed' b','
b'regardless' b'the' b'planets' b'made' b'to' b'feeding' b'.' b'Also' b'by' b'their' b'smooth' b'orbit' b',' b'they' b'<unk>' b'high' b'or' b'periods' b',' b'up'
b'vanquishing' b'tree' b'muscle' b',' b'proportional' b'vintage' b'or' b'<unk>' b'accessories' b'.' b'This' b'suffered' b'some' b'common' b'hairstyle' b'made' b'as' b'advise' b'at' b'each'
b'powerful' b'vertical' b'low' b'or' b'flightless' b'.' b'While' b'the' b'other' b'of' b'them' b'females' b'is' b'eaten' b',' b'these' b'types' b'of' b'Massive' b'lands'
b'is' b'quite' b'smaller' b'than' b'the' b'bird' b'.' b'They' b'may' b'be' b'largely' b'able' b'to' b'participate' b',' b'with' b'some' b'species' b'of' b'Vickers'
b'starlings' b',' b'which' b'appears' b'as' b'they' b'has' b'sufficiently' b'given' b'North' b'75' b'.' b'<eos>' b'Common' b'starlings' b'are' b'a' b'planet' b'of' b'striped'
b'Archaeological' b'contraception' b'.' b'A' b'feeding' b'extent' b'from' b'large' b'populations' b'of' b'range' b',' b'or' b'women' b',' b'flash' b'cells' b',' b'<unk>' b','
b'starlings' b',' b'<unk>' b',' b'Colorado' b',' b'or' b'black' b'reagent' b',' b'consists' b'of' b'roles' b',' b'such' b'as' b'Cramp' b'as' b'produce' b'.'
b'chemin' b'blind' b'few' b'more' b'correctly' b'immediately' b'or' b'ambitions' b'is' b'greater' b'urban' b',' b'and' b'both' b'them' b',' b'paler' b',' b'motion' b','
b'are' b'also' b'designed' b'to' b'become' b',' b'others' b'doping' b'qualified' b'.' b'publicist' b'<unk>' b'may' b'have' b'any' b'negative' b'brown' b'fixtures' b'due' b'to'
b'both' b'620' b'activity' b',' b'a' b'variety' b'one' b'of' b'the' b'structural' b',' b'<unk>' b',' b'or' b'agricultural' b'low' b',' b'for' b'passionate' b'convective'
b'gleba' b'.' b'Regardless' b',' b'they' b'may' b'have' b'interested' b'in' b'noisy' b'Kong' b',' b'as' b'they' b'can' b'be' b'however' b'on' b'eight' b'or'
b'more' b'careful' b'females' b',' b'equipped' b'by' b'<unk>' b',' b'to' b'<unk>' b'by' b'krypton' b'.' b'Amongst' b'Helen' b'of' b'insects' b',' b'they' b'also'
b'contain' b'cytogenetics' b'throughout' b'males' b',' b'wider' b'depictions' b'such' b'as' b'French' b',' b'mechanical' b'or' b'smell' b',' b'interred' b',' b'structures' b',' b'<unk>'
b',' b'expansion' b',' b'Sabine' b',' b'and' b'Celtic' b'serials' b'<unk>' b'<unk>' b'and' b'Hartington' b'as' b'Principe' b'Latino' b'.' b'\xe2\x88\x92' b'the' b'total' b'agricultural'
b',' b'no' b',' b'marble' b'Burj' b',' b'parks' b',' b'Uh' b',' b'sorts' b',' b'<unk>' b',' b'varies' b',' b'and' b'other' b'birds' b'are'
b'"' b'dwarf' b'organic' b',' b'"' b'partly' b',' b'if' b'how' b'there' b'was' b'slow' b'one' b'of' b'the' b'most' b'important' b',' b'<unk>' b','

It’s no GPT-2, but it looks like the model has started to learn the structure of language!

We’re almost ready to demonstrate dynamic quantization. We just need to define a few more helper functions:

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into ``bsz`` parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the ``bsz`` batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    # Return a chunk of at most ``bptt`` time steps starting at position ``i``,
    # plus the targets: the same tokens shifted forward by one position.
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
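
To make the data layout concrete, here is a toy run of batchify and get_batch (illustration only, not used by the rest of the tutorial): batchify arranges the token stream column-wise, one column per batch element, and get_batch returns an input chunk together with targets shifted forward by one token.

toy = torch.arange(10)
toy_batched = batchify(toy, 2)
print(toy_batched)
# tensor([[0, 5],
#         [1, 6],
#         [2, 7],
#         [3, 8],
#         [4, 9]])

toy_data, toy_targets = get_batch(toy_batched, 0)
print(toy_data.shape, toy_targets.shape)  # torch.Size([4, 2]) torch.Size([8])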

4. Test dynamic quantization#

Finally, we can call torch.quantization.quantize_dynamic on the model! Specifically,

  • We specify that we want the nn.LSTM and nn.Linear modules in our model to be quantized

  • We specify that we want weights to be converted to int8 values

import torch.quantization

quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)
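
As a quick check, the quantized modules really do hold int8 weights. On recent PyTorch releases the dynamically quantized Linear exposes its weight via a weight() method (the exact accessor may differ across versions), so a sketch like the following should print torch.qint8:

qweight = quantized_model.decoder.weight()  # quantized weight tensor of the decoder
print(qweight.dtype)                        # torch.qint8
print(qweight.dequantize()[0, :5])          # float values reconstructed from int8 + scale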

The model looks almost the same; the LSTM and Linear modules have simply been replaced by their dynamically quantized counterparts. How has this benefited us? First, we see a significant reduction in model size (the nn.Embedding weights are not quantized here, which is why the file does not shrink by a full 4x):

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)
Size (MB): 113.944455
Size (MB): 79.738939

Second, we see faster inference time, with no difference in evaluation loss:

Note: we set the number of threads to one so the comparison is single-threaded on both sides, since quantized models run single-threaded.

torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)
loss: 5.167
elapsed time (seconds): 193.2
loss: 5.168
elapsed time (seconds): 114.3

Running this locally on a MacBook Pro, inference takes about 200 seconds without quantization and only about 100 seconds with quantization.

Conclusion#

Dynamic quantization can be an easy way to reduce model size while only having a limited effect on accuracy.

Thanks for reading! As always, we welcome any feedback, so please create an issue on GitHub if you have any.

Total running time of the script: (5 minutes 16.613 seconds)
