Save and Load Models in PyTorch
We often need to reuse an already-trained model in a development environment. In that case, would you recreate and retrain the model every time, or save it once and load it whenever it is needed? You would definitely choose the second option. In this article, we will see how to save and load models using PyTorch.
What is PyTorch?
PyTorch is an open-source machine learning library built around a dynamic computation graph. In the static approach, the graph is fully defined before execution; in the dynamic approach that PyTorch follows, the structure of the graph can change during execution depending on the input data. This makes it straightforward to create and train neural networks that extract hidden patterns from data.
You might wonder what a neural network is. In simple words, a neural network is a collection of interconnected layers of nodes. Each node processes the data it receives and passes the result on to the next, and through this process the network as a whole learns to extract insights from the data.
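To make the dynamic-graph idea concrete, here is a minimal sketch; the DynamicNet module and its data-dependent loop are hypothetical, written purely for illustration. The number of times the layer is applied depends on the input itself, which a static, predefined graph cannot express directly.
Python
import torch
import torch.nn as nn

# Toy module whose forward pass branches on the data itself --
# possible because PyTorch builds the computation graph at run time
class DynamicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        # Apply the layer a data-dependent number of times (1 to 3)
        repeats = int(x.abs().sum().item()) % 3 + 1
        for _ in range(repeats):
            x = torch.relu(self.fc(x))
        return x

net = DynamicNet()
print(net(torch.randn(2, 4)).shape)  # torch.Size([2, 4])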
Stepwise Guide to Save and Load Models in PyTorch
First, we will see how to create a model using PyTorch.
Creating a Model in PyTorch
To save and load a model, we will first create a deep learning model for image classification. This model will classify images of handwritten digits from the MNIST dataset. The code below implements a convolutional neural network (CNN) for this task. The data is loaded and transformed into PyTorch tensors, which are the containers PyTorch uses to store data.
The following sections walk through the creation of the PyTorch model.
Importing Necessary Libraries
Python3
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
Data Transformation
The given code defines a transformation pipeline using torchvision.transforms.Compose for preprocessing image data before feeding it into a PyTorch model.
- transforms.ToTensor(): Converts the input image (assumed to be in PIL Image format) to a PyTorch tensor. It converts the image data to torch.FloatTensor and scales the pixel values to the range [0.0, 1.0].
- transforms.Normalize((0.5,), (0.5,)): Normalizes the tensor image with the given mean and standard deviation. With mean (0.5,) and standard deviation (0.5,), each channel is mapped from [0.0, 1.0] to [-1.0, 1.0].
Python3
# Define the transformations to apply to the data
data_transform = transforms.Compose([
    transforms.ToTensor(),                 # Convert images to PyTorch tensors
    transforms.Normalize((0.5,), (0.5,))   # Normalize pixel values to the range [-1, 1]
])

# Download the MNIST dataset and apply the transformation
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=data_transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=data_transform, download=True)

# Define data loaders to load the data in batches during training and testing
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
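As a quick sanity check (an aside, not part of the original pipeline), pulling one batch from train_loader confirms the tensor shapes and the normalized value range:
Python
# Inspect one batch: 64 grayscale images of size 28x28
images, labels = next(iter(train_loader))
print(images.shape)                               # torch.Size([64, 1, 28, 28])
print(images.min().item(), images.max().item())   # roughly -1.0 and 1.0 after normalization
print(labels[:5])                                 # first five digit labels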
Defining the Neural Network Architecture
- Class definition: The code defines a class SimpleCNN that inherits from nn.Module, the base class for all neural network modules in PyTorch. This class represents a simple convolutional neural network (CNN).
- Initialization: In the __init__ method, the code defines the layers of the CNN: two convolutional layers (conv1_layer and conv2_layer) with the specified kernel sizes and padding, and two fully connected layers (fc1_layer and fc2_layer) with the specified input and output sizes.
- Forward pass: The forward method defines the forward pass of the network. It applies a ReLU activation after each convolutional layer and uses max pooling with a kernel size of 2 and a stride of 2 to downsample the feature maps. The output of the second convolutional layer is flattened before being passed to the fully connected layers.
- View operation: The view operation reshapes the output of the second convolutional layer to match the input size of the first fully connected layer. The -1 argument in view tells PyTorch to infer that dimension from the others. (Each 1x28x28 MNIST image becomes 16x14x14 after the first convolution and pooling, then 32x7x7 after the second, hence the flattened size 32 * 7 * 7 = 1568.)
- Model instance: Finally, an instance of the SimpleCNN class is created and assigned to the variable cnn_model. This instance is the actual network that can be trained and used for inference.
Python3
# Define a CNN with two convolutional layers and two fully connected layers
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1_layer = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2_layer = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1_layer = nn.Linear(32 * 7 * 7, 128)
        self.fc2_layer = nn.Linear(128, 10)

    # Forward pass: ReLU activation and 2x2 max pooling after each convolution
    def forward(self, inputs):
        new_input = torch.relu(self.conv1_layer(inputs))
        new_input = torch.max_pool2d(new_input, kernel_size=2, stride=2)
        new_input = torch.relu(self.conv2_layer(new_input))
        new_input = torch.max_pool2d(new_input, kernel_size=2, stride=2)
        new_input = new_input.view(-1, 32 * 7 * 7)  # Flatten to (batch, 1568)
        new_input = torch.relu(self.fc1_layer(new_input))
        new_input = self.fc2_layer(new_input)
        return new_input

# Create the model instance
cnn_model = SimpleCNN()
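Before training, it is worth verifying the shape arithmetic behind the 32 * 7 * 7 flattening (each 28x28 image becomes 14x14 and then 7x7 after the two pooling steps). A minimal sanity check with a dummy input:
Python
# Pass one fake MNIST-sized image through the untrained model
dummy = torch.zeros(1, 1, 28, 28)
print(cnn_model(dummy).shape)  # torch.Size([1, 10]) -- one score per digit class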
Loss Function and Optimizer
- Loss function: nn.CrossEntropyLoss() is used as the loss function. It is commonly used for classification problems with multiple classes and calculates the cross-entropy loss between the predicted class scores and the actual class labels.
- Optimizer: optim.Adam is used as the optimizer. Adam is a popular optimization algorithm that computes adaptive learning rates for each parameter and is well suited for training deep neural networks. The optimizer is initialized with the parameters of cnn_model and a learning rate of 0.001.
Python3
# Define loss function and optimizer
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)
Training the model
The code implements the following steps:
- Outer loop (epochs): The code iterates over 5 epochs using a for loop. An epoch is a single pass through the entire dataset.
- Inner loop (batches): Within each epoch, the code iterates over batches of data from train_loader, which yields batches of input data (inputs) and their corresponding labels (labels).
- Zero gradients: Before the backward pass (loss.backward()), optimizer.zero_grad() is called to zero out the gradients of the model parameters. This is necessary because PyTorch accumulates gradients by default.
- Forward and backward pass: outputs = cnn_model(inputs) performs the forward pass, where the model processes the input data to generate predictions (outputs). loss = loss_func(outputs, labels) calculates the loss between the predicted outputs and the actual labels. loss.backward() computes the gradients of the loss with respect to the model parameters (backpropagation), and optimizer.step() updates the parameters based on those gradients using the Adam algorithm.
- Loss tracking: Within the inner loop, running_loss accumulates the total loss across batches. At the end of each epoch, the average loss per batch is printed to monitor training progress.
Python3
# Train the model
for epoch in range(5):  # Train for 5 epochs
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()               # Zero the gradients
        outputs = cnn_model(inputs)         # Forward pass
        loss = loss_func(outputs, labels)   # Calculate the loss
        loss.backward()                     # Backward pass
        optimizer.step()                    # Update weights
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")
Output:
Epoch 1, Loss: 0.22154594235159933
Epoch 2, Loss: 0.05766747533348697
Epoch 3, Loss: 0.04144403319505514
Epoch 4, Loss: 0.029859573355312946
Epoch 5, Loss: 0.024109310584392515
Testing The Model
Python
# Evaluate the model on the test set
correct_predictions = 0
total_samples = 0
with torch.no_grad():  # Disable gradient tracking during evaluation
    for inputs, labels in test_loader:
        outputs = cnn_model(inputs)
        _, predicted_labels = torch.max(outputs.data, 1)  # Class with the highest score
        total_samples += labels.size(0)
        correct_predictions += (predicted_labels == labels).sum().item()
print(f"Accuracy of test set: {100 * correct_predictions / total_samples}%")
Output:
Accuracy of test set: 99.16%
Saving and Loading Model
Method 1: Using torch.save() and torch.load()
The following code shows how to save and load the model using the built-in functions provided by the torch module. torch.save() serializes the entire model object to a file, and torch.load() loads it back into memory. Since the whole object is pickled, the class definition (here, SimpleCNN) must still be available in the environment where the file is loaded.
Python
# Save the entire model object
torch.save(cnn_model, 'cnn_model_full.pth')

# Load the model back into memory
# (weights_only=False is needed on PyTorch 2.6+, where the default only allows plain weights)
loaded_model = torch.load('cnn_model_full.pth', weights_only=False)

# Set the model to evaluation mode
loaded_model.eval()
Output:
SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
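A practical note on torch.load(): if the model was saved on a GPU machine and needs to run on a CPU-only one, the map_location argument remaps its tensors to the target device. A small sketch, assuming the cnn_model_full.pth file saved above:
Python
# Load the pickled model onto the CPU, regardless of the device it was saved on
loaded_model = torch.load('cnn_model_full.pth',
                          map_location=torch.device('cpu'),
                          weights_only=False)  # full objects need this on PyTorch 2.6+
loaded_model.eval()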
Method 2: Using model.state_dict()
Now, let us look at another way to save and load the model, using the state_dict() method. A state_dict stores only the learned parameters of the model. When loading, a new model with the same architecture is created first, and its parameters are then replaced with the stored ones. Since only the parameters are stored, this method is memory efficient and is the approach the PyTorch documentation recommends. The following code snippet illustrates it.
Python
# Saving the model
torch.save(cnn_model.state_dict(), 'cnn_model.pth')
# Loading the model
loaded_model = SimpleCNN()
loaded_model.load_state_dict(torch.load('cnn_model.pth'))
print(loaded_model)
Output:
SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
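Once loaded, the model behaves exactly like the original. As a quick check (an aside, not part of the original flow), we can classify one batch from the test_loader defined earlier:
Python
# Classify one batch from the test set with the restored model
loaded_model.eval()
with torch.no_grad():
    images, labels = next(iter(test_loader))
    predicted = loaded_model(images).argmax(dim=1)
print("Predicted:", predicted[:10].tolist())
print("Actual:   ", labels[:10].tolist())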
Method 3: Saving and Loading using the Checkpoints
The checkpoint method saves the model by creating a dictionary that contains all the information needed to resume training: the model state_dict, the optimizer state_dict, the current epoch, the last loss, and so on. To load the model, the checkpoint file is read and each piece of information is restored. Note that after recreating the model instance, the optimizer must be rebuilt around the new model's parameters before its state is restored. This method is demonstrated below:
Python
# Saving a checkpoint
checkpoint = {
    'epoch': epoch,
    'model_state_dict': cnn_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # add any other information you want to restore later
}
torch.save(checkpoint, 'checkpoint.pth')

# Loading the checkpoint
checkpoint = torch.load('checkpoint.pth')
cnn_model = SimpleCNN()
cnn_model.load_state_dict(checkpoint['model_state_dict'])

# Rebuild the optimizer for the new model instance, then restore its state
optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(cnn_model)
Output:
SimpleCNN(
(conv1_layer): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2_layer): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(fc1_layer): Linear(in_features=1568, out_features=128, bias=True)
(fc2_layer): Linear(in_features=128, out_features=10, bias=True)
)
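The main payoff of a checkpoint is resuming training where it stopped. A minimal sketch, continuing for two more epochs (the extra epoch count is arbitrary, chosen for illustration):
Python
# Resume training from the restored epoch
start_epoch = checkpoint['epoch'] + 1
for epoch in range(start_epoch, start_epoch + 2):
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = loss_func(cnn_model(inputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")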
Conclusion
PyTorch offers several ways to save and load models. torch.save() and torch.load() can serialize and restore an entire model object in one step. The model.state_dict() approach saves only the parameters, which is more memory efficient and is the usual recommendation. Finally, if you want to store all the relevant information about a training run in one place, a checkpoint dictionary lets you save and restore the model, the optimizer state, the epoch, and more. Choosing the right method comes down to your memory constraints, whether you need information beyond the model parameters, and the use case at hand.