.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "advanced/dynamic_quantization_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_advanced_dynamic_quantization_tutorial.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_advanced_dynamic_quantization_tutorial.py:


(beta) Dynamic Quantization on an LSTM Word Language Model
==================================================================

**Author**: `James Reed `_

**Edited by**: `Seth Weidman `_

Introduction
------------

Quantization involves converting the weights and activations of your model from float
to int, which can result in smaller model size and faster inference with only a small
hit to accuracy.

In this tutorial, we will apply the easiest form of quantization -
`dynamic quantization `_ -
to an LSTM-based next word-prediction model, closely following the
`word language model `_
from the PyTorch examples.

.. GENERATED FROM PYTHON SOURCE LINES 22-32

.. code-block:: default


    # imports
    import os
    from io import open
    import time

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

.. GENERATED FROM PYTHON SOURCE LINES 33-39

1. Define the model
-------------------

Here we define the LSTM model architecture, following the
`model `_
from the word language model example.

.. GENERATED FROM PYTHON SOURCE LINES 39-73

.. code-block:: default


    class LSTMModel(nn.Module):
        """Container module with an encoder, a recurrent module, and a decoder."""

        def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
            super(LSTMModel, self).__init__()
            self.drop = nn.Dropout(dropout)
            self.encoder = nn.Embedding(ntoken, ninp)
            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
            self.decoder = nn.Linear(nhid, ntoken)

            self.init_weights()

            self.nhid = nhid
            self.nlayers = nlayers

        def init_weights(self):
            initrange = 0.1
            self.encoder.weight.data.uniform_(-initrange, initrange)
            self.decoder.bias.data.zero_()
            self.decoder.weight.data.uniform_(-initrange, initrange)

        def forward(self, input, hidden):
            emb = self.drop(self.encoder(input))
            output, hidden = self.rnn(emb, hidden)
            output = self.drop(output)
            decoded = self.decoder(output)
            return decoded, hidden

        def init_hidden(self, bsz):
            weight = next(self.parameters())
            return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                    weight.new_zeros(self.nlayers, bsz, self.nhid))

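Before loading any data, it can help to see the shapes this architecture produces.
The snippet below is a minimal sketch with made-up toy sizes (the ``toy_*`` names
are illustrative, not part of the example); it runs one forward pass and checks that
the decoder emits one score per vocabulary word at every time step.

.. code-block:: python


    # Minimal shape check with toy sizes (illustrative only).
    toy_model = LSTMModel(ntoken=10, ninp=8, nhid=16, nlayers=2)
    toy_hidden = toy_model.init_hidden(bsz=1)
    toy_input = torch.randint(10, (5, 1), dtype=torch.long)  # (seq_len, batch)
    toy_scores, toy_hidden = toy_model(toy_input, toy_hidden)
    print(toy_scores.shape)  # torch.Size([5, 1, 10]): one score per vocab word
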
.. GENERATED FROM PYTHON SOURCE LINES 74-82

2. Load in the text data
------------------------

Next, we load the
`Wikitext-2 dataset `_ into a `Corpus`,
again following the
`preprocessing `_
from the word language model example.

.. GENERATED FROM PYTHON SOURCE LINES 82-132

.. code-block:: default


    class Dictionary(object):
        def __init__(self):
            self.word2idx = {}
            self.idx2word = []

        def add_word(self, word):
            if word not in self.word2idx:
                self.idx2word.append(word)
                self.word2idx[word] = len(self.idx2word) - 1
            return self.word2idx[word]

        def __len__(self):
            return len(self.idx2word)


    class Corpus(object):
        def __init__(self, path):
            self.dictionary = Dictionary()
            self.train = self.tokenize(os.path.join(path, 'train.txt'))
            self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
            self.test = self.tokenize(os.path.join(path, 'test.txt'))

        def tokenize(self, path):
            """Tokenizes a text file."""
            assert os.path.exists(path)
            # Add words to the dictionary
            with open(path, 'r', encoding="utf8") as f:
                for line in f:
                    words = line.split() + ['<eos>']
                    for word in words:
                        self.dictionary.add_word(word)

            # Tokenize file content
            with open(path, 'r', encoding="utf8") as f:
                idss = []
                for line in f:
                    words = line.split() + ['<eos>']
                    ids = []
                    for word in words:
                        ids.append(self.dictionary.word2idx[word])
                    idss.append(torch.tensor(ids).type(torch.int64))
                ids = torch.cat(idss)

            return ids

    model_data_filepath = 'data/'

    corpus = Corpus(model_data_filepath + 'wikitext-2')

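To make the bookkeeping concrete, here is how ``Dictionary`` behaves on a toy
vocabulary (the words below are made up for illustration; the real dictionary is
built from the Wikitext-2 files above):

.. code-block:: python


    # Illustrative only: each new word gets the next free index,
    # and repeated words are not added twice.
    d = Dictionary()
    for w in ['the', 'cat', 'sat', 'the']:
        d.add_word(w)
    print(len(d))             # 3
    print(d.word2idx['cat'])  # 1
    print(d.idx2word[0])      # 'the'
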
.. GENERATED FROM PYTHON SOURCE LINES 133-141

3. Load the pretrained model
-----------------------------

This is a tutorial on dynamic quantization, a quantization technique
that is applied after a model has been trained. Therefore, we'll simply
load some pretrained weights into this model architecture; these
weights were obtained by training for five epochs using the default
settings in the word language model example.

.. GENERATED FROM PYTHON SOURCE LINES 141-162

.. code-block:: default


    ntokens = len(corpus.dictionary)

    model = LSTMModel(
        ntoken = ntokens,
        ninp = 512,
        nhid = 256,
        nlayers = 5,
    )

    model.load_state_dict(
        torch.load(
            model_data_filepath + 'word_language_model_quantize.pth',
            map_location=torch.device('cpu'),
            weights_only=True
        )
    )

    model.eval()
    print(model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    LSTMModel(
      (drop): Dropout(p=0.5, inplace=False)
      (encoder): Embedding(33278, 512)
      (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
      (decoder): Linear(in_features=256, out_features=33278, bias=True)
    )

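Those layer sizes (a 33,278-word vocabulary, 512-dimensional embeddings, five
256-unit LSTM layers, and a large decoder) add up quickly. As a rough sketch,
assuming 4 bytes per FP32 parameter, you can estimate the model's size before
quantization; it should land close to the ~114 MB we measure later:

.. code-block:: python


    # Back-of-the-envelope estimate: FP32 parameters take 4 bytes each.
    num_params = sum(p.numel() for p in model.parameters())
    print('parameters: {:,}'.format(num_params))
    print('approx. FP32 size (MB): {:.1f}'.format(num_params * 4 / 1e6))
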
.. GENERATED FROM PYTHON SOURCE LINES 163-166

Now let's generate some text to ensure that the pretrained model is working
properly. As before, we follow the text generation script from the word
language model example (`here `_).

.. GENERATED FROM PYTHON SOURCE LINES 166-191

.. code-block:: default


    input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
    hidden = model.init_hidden(1)
    temperature = 1.0
    num_words = 1000

    with open(model_data_filepath + 'out.txt', 'w') as outf:
        with torch.no_grad():  # no tracking history
            for i in range(num_words):
                output, hidden = model(input_, hidden)
                word_weights = output.squeeze().div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1)[0]
                input_.fill_(word_idx)

                word = corpus.dictionary.idx2word[word_idx]

                outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

                if i % 100 == 0:
                    print('| Generated {}/{} words'.format(i, 1000))

    with open(model_data_filepath + 'out.txt', 'r') as outf:
        all_output = outf.read()
        print(all_output)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    | Generated 0/1000 words
    | Generated 100/1000 words
    | Generated 200/1000 words
    | Generated 300/1000 words
    | Generated 400/1000 words
    | Generated 500/1000 words
    | Generated 600/1000 words
    | Generated 700/1000 words
    | Generated 800/1000 words
    | Generated 900/1000 words
    b'or' b'get' b'a' b'considerable' b'arrangement' b'of' b'iron' b'cell' b'past' b'Greater' b'Exploration' b'but' b'clinics' b'reduced' b'advertisements' b'are' b'eucalypt' b'.' b'Among' b'the'
    b'reason' b'of' b'Gruffudd' b'around' b'people' b'in' b'the' b'language' b',' b'players' b'were' b'suggested' b'"' b'there' b"'s" b'' b'and' b'not' b'' b'for'
    b'looped' b'football' b'(' b'up' b'to' b'mitosis' b'medius' b'"' b'.' b'This' b'chapel' b'is' b'gems' b')' b',' b'or' b'other' b'limit' b'never' b'need'
    b'and' b',' b'other' b'all' b'being' b'installed' b'so' b'.' b'Another' b'casual' b'simplistic' b'addiction' b'won' b'administratively' b'nests' b',' b'developing' b'white' b'foundations' b','
    b'commence' b'the' b'stallions' b',' b'although' b'' b'filmography' b'took' b'place' b'with' b'both' b'more' b'information' b'.' b'These' b'effect' b'+' b'averaged' b'mammals' b'from'
    ...

.. GENERATED FROM PYTHON SOURCE LINES 192-197

It's no GPT-2, but it looks like the model has
started to learn the structure of language!

We're almost ready to demonstrate dynamic quantization. We just need to define a few more
helper functions:

.. GENERATED FROM PYTHON SOURCE LINES 197-242

.. code-block:: default


    bptt = 25
    criterion = nn.CrossEntropyLoss()
    eval_batch_size = 1

    # create test data set
    def batchify(data, bsz):
        # Work out how cleanly we can divide the dataset into ``bsz`` parts.
        nbatch = data.size(0) // bsz
        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, nbatch * bsz)
        # Evenly divide the data across the ``bsz`` batches.
        return data.view(bsz, -1).t().contiguous()

    test_data = batchify(corpus.test, eval_batch_size)

    # Evaluation functions
    def get_batch(source, i):
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i:i+seq_len]
        target = source[i+1:i+1+seq_len].reshape(-1)
        return data, target

    def repackage_hidden(h):
        """Wraps hidden states in new Tensors, to detach them from their history."""

        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(repackage_hidden(v) for v in h)

    def evaluate(model_, data_source):
        # Turn on evaluation mode which disables dropout.
        model_.eval()
        total_loss = 0.
        hidden = model_.init_hidden(eval_batch_size)
        with torch.no_grad():
            for i in range(0, data_source.size(0) - 1, bptt):
                data, targets = get_batch(data_source, i)
                output, hidden = model_(data, hidden)
                hidden = repackage_hidden(hidden)
                output_flat = output.view(-1, ntokens)
                total_loss += len(data) * criterion(output_flat, targets).item()
        return total_loss / (len(data_source) - 1)

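Before running the evaluation, here is a small sketch of what ``batchify`` and
``get_batch`` actually do, using a made-up sequence of ten token ids rather than
the real test set:

.. code-block:: python


    # Illustrative only: batchify lays the sequence out column-by-column.
    toy_ids = torch.arange(10)          # pretend token ids 0..9
    toy_batches = batchify(toy_ids, 2)  # 2 columns of 5 tokens each
    print(toy_batches)
    # tensor([[0, 5],
    #         [1, 6],
    #         [2, 7],
    #         [3, 8],
    #         [4, 9]])

    # get_batch returns up to ``bptt`` rows of input plus the shifted targets.
    toy_data, toy_targets = get_batch(toy_batches, 0)
    print(toy_data.shape, toy_targets.shape)  # torch.Size([4, 2]) torch.Size([8])
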
.. GENERATED FROM PYTHON SOURCE LINES 243-252

4. Test dynamic quantization
----------------------------

Finally, we can call ``torch.quantization.quantize_dynamic`` on the model!
Specifically,

- We specify that we want the ``nn.LSTM`` and ``nn.Linear`` modules in our
  model to be quantized

- We specify that we want weights to be converted to ``int8`` values

.. GENERATED FROM PYTHON SOURCE LINES 252-260

.. code-block:: default


    import torch.quantization

    quantized_model = torch.quantization.quantize_dynamic(
        model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
    )
    print(quantized_model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    LSTMModel(
      (drop): Dropout(p=0.5, inplace=False)
      (encoder): Embedding(33278, 512)
      (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
      (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )

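If you are curious what actually changed inside those modules, you can peek at the
stored weights. This is just a sketch; the accessor methods below come from the
quantized-module API and may differ slightly between PyTorch versions:

.. code-block:: python


    # The dynamically quantized Linear stores an int8 weight tensor
    # together with a per-tensor scale and zero point.
    qweight = quantized_model.decoder.weight()
    print(qweight.dtype)                              # torch.qint8
    print(qweight.q_scale(), qweight.q_zero_point())  # quantization parameters
    print(torch.int_repr(qweight)[:2, :5])            # the raw int8 values
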
.. GENERATED FROM PYTHON SOURCE LINES 261-263

The printed model looks almost the same; only the ``nn.LSTM`` and ``nn.Linear``
modules have been replaced by their dynamically quantized counterparts. So how
has this benefited us? First, we see a significant reduction in model size:

.. GENERATED FROM PYTHON SOURCE LINES 263-272

.. code-block:: default


    def print_size_of_model(model):
        torch.save(model.state_dict(), "temp.p")
        print('Size (MB):', os.path.getsize("temp.p")/1e6)
        os.remove('temp.p')

    print_size_of_model(model)
    print_size_of_model(quantized_model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Size (MB): 113.944455
    Size (MB): 79.738939

.. GENERATED FROM PYTHON SOURCE LINES 273-277

Second, we see faster inference time, with no difference in evaluation loss.

Note: we set the number of threads to one for a single-threaded comparison, since
quantized models run single-threaded.

.. GENERATED FROM PYTHON SOURCE LINES 277-289

.. code-block:: default


    torch.set_num_threads(1)

    def time_model_evaluation(model, test_data):
        s = time.time()
        loss = evaluate(model, test_data)
        elapsed = time.time() - s
        print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

    time_model_evaluation(model, test_data)
    time_model_evaluation(quantized_model, test_data)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    loss: 5.167
    elapsed time (seconds): 208.9
    loss: 5.168
    elapsed time (seconds): 114.5

.. GENERATED FROM PYTHON SOURCE LINES 290-301

Running this locally on a MacBook Pro, without quantization, inference takes
about 200 seconds, and with quantization it takes just about 100 seconds.

Conclusion
----------

Dynamic quantization can be an easy way to reduce model size while only
having a limited effect on accuracy.

Thanks for reading! As always, we welcome any feedback, so please create an issue
`here `_ if you have any.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 5 minutes 32.990 seconds)


.. _sphx_glr_download_advanced_dynamic_quantization_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: dynamic_quantization_tutorial.py <dynamic_quantization_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: dynamic_quantization_tutorial.ipynb <dynamic_quantization_tutorial.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_