Saving and Loading Weights in PyTorch Lightning
Last Updated: 24 Sep, 2024
In machine learning, saving and loading model weights efficiently is important. It lets us preserve the state of a model during training so we can resume later without starting from scratch. In this article, we discuss how to save and load weights in PyTorch Lightning, an easy-to-use library that simplifies PyTorch.
We will cover the steps involved in saving and loading weights, various configurations, and best practices for working with models in PyTorch Lightning.
Why is Saving and Loading Weights Important?
Let's first understand what PyTorch Lightning is: a lightweight wrapper around PyTorch that helps us organize our code and reduce boilerplate. It makes training models simpler and more efficient by providing built-in features for saving and loading weights, managing checkpoints, and more.
Saving and loading model weights is essential for the following reasons:
- Checkpointing: Regularly saving model weights ensures that you can resume training from the last saved state in case of interruptions.
- Inference: Once a model is trained, you can save its weights to disk and load them later for inference without having to retrain the model.
- Model Versioning: It allows you to keep different versions of your model with varying hyperparameters and architectures.
- Transfer Learning: Loading pre-trained weights enables fine-tuning models for different tasks.
Checkpoints in PyTorch Lightning
PyTorch Lightning provides built-in support for saving and loading model checkpoints. These checkpoints store more than just the model weights—they also include information about the optimizer, learning rate scheduler, and current epoch, making it easy to resume training seamlessly.
A checkpoint is essentially a snapshot of our model at a specific point during training. It saves not only the model's weights but also things such as:
- Current training epoch
- Optimizer states
- Learning rate scheduler states
- Hyperparameters used during training
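A saved .ckpt file is an ordinary PyTorch file, so it can be inspected with torch.load. As a rough sketch (the exact keys may differ between Lightning versions, and the path below refers to the checkpoint produced later in this article):
Python
import torch

# A Lightning checkpoint is a plain dictionary saved with torch.save
checkpoint = torch.load("checkpoints/best_model.ckpt", map_location="cpu")

print(checkpoint.keys())         # typically includes 'epoch', 'state_dict', 'optimizer_states', ...
print(checkpoint["epoch"])       # the epoch at which this checkpoint was written
print(checkpoint["state_dict"])  # the model weights themselves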
Saving Model Weights in PyTorch Lightning
The ModelCheckpoint callback in PyTorch Lightning is designed to save the model's state at specified intervals or under certain conditions, such as when the validation accuracy improves.
Install PyTorch Lightning: In our Google Colab or Jupyter notebook, run the following command to install the library:
!pip install pytorch-lightning
Step 1: Import Required Libraries
First, we will import some required libraries:
- PyTorch for building the neural network and managing data.
- PyTorch Lightning to streamline the training process.
- ModelCheckpoint to save the model automatically based on the loss during training.
Python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint
Step 2: Create a Sample Dataset
We will generate a simple dataset where the target y follows the formula y = 2x + 1. PyTorch's TensorDataset will hold the features x and labels y, and the DataLoader will handle batching the data during training.
Python
x = torch.rand(100, 1) # Random 100 data points
y = 2 * x + 1 # Linear relationship
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=10) # Batching the data
Step 3: Define the Model
We define a very simple neural network with just one linear layer using PyTorch's nn.Linear. This is a basic linear regression model, which tries to learn the relationship between input x and output y. In this model, training_step defines the training loop for one batch, computing the Mean Squared Error (MSE) loss, and configure_optimizers specifies the optimizer for the model parameters, in this case Stochastic Gradient Descent (SGD).
Python
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(1, 1)  # Linear layer with 1 input, 1 output

    def forward(self, x):
        return self.linear(x)  # Forward pass

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)  # Model predictions
        loss = nn.MSELoss()(y_hat, y)  # Compute loss
        self.log('train_loss', loss)  # Log the training loss
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)  # Optimizer
Step 4: Add ModelCheckpoint Callback (Model Saving)
We use PyTorch Lightning's ModelCheckpoint callback to automatically save the model's weights during training. It saves a checkpoint every time a new minimum training loss is found. In the code below:
- monitor='train_loss': monitors the training loss.
- filename='best_model': the checkpoint is saved under this filename.
- save_top_k=1: only the best model (in terms of the lowest training loss) is kept.
- mode='min': a checkpoint is saved when the monitored value (train_loss) decreases.
- dirpath='checkpoints/': specifies the directory where the checkpoint is saved.
During training, the model's best weights are saved in the checkpoints/best_model.ckpt file.
Python
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    filename='best_model',
    save_top_k=1,
    mode='min',
    dirpath='checkpoints/'
)
Step 5: Initialize the Trainer
The pl.Trainer is the core of PyTorch Lightning. Here, we pass the checkpoint_callback and set max_epochs=10.
Python
trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=10)
Step 6: Train the Model
Now we train the model using the trainer.fit()
method. The model and data loader are passed as arguments, and training begins.
Python
trainer.fit(SimpleModel(), dataloader)
During training, the model's weights will be saved in the checkpoints/ directory, and the checkpoint with the best training loss will be saved as best_model.ckpt.
Loading Model Weights in PyTorch Lightning
After training is complete, we can load the best model from the checkpoint. This allows us to resume training, fine-tune the model, or use it for inference.
Python
loaded_model = SimpleModel.load_from_checkpoint('checkpoints/best_model.ckpt')
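As a quick sketch (not part of the training script above), the loaded model can then be used for inference; the sample input here is just an illustrative value:
Python
# Put the loaded model in evaluation mode and run a prediction without gradients
loaded_model.eval()
with torch.no_grad():
    # Create the input on the same device as the model to avoid device mismatches
    sample = torch.tensor([[0.5]], device=loaded_model.device)
    prediction = loaded_model(sample)  # should be close to 2 * 0.5 + 1 = 2
    print(prediction)
To resume training instead of running inference, recent versions of Lightning also let you pass the checkpoint path directly to trainer.fit(), for example trainer.fit(SimpleModel(), dataloader, ckpt_path='checkpoints/best_model.ckpt').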
Example: Saving and Loading Weights of a Simple Model
Now, here is the complete code showing how to build, train, and save a simple linear regression model using PyTorch Lightning.
Python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint
# Sample dataset
x = torch.rand(100, 1)
y = 2 * x + 1
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=10)
# Define a simple model
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

# Create a ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    filename='best_model',
    save_top_k=1,
    mode='min',
    dirpath='checkpoints/'
)
# Initialize the trainer with the checkpoint callback
trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=10)
# Train the model
trainer.fit(SimpleModel(), dataloader)
Output:
INFO:pytorch_lightning.callbacks.model_summary:
| Name | Type | Params | Mode
------------------------------------------
0 | linear | Linear | 2 | train
------------------------------------------
2 Trainable params
0 Non-trainable params
2 Total params
0.000 Total estimated model params size (MB)
1 Modules in train mode
0 Modules in eval mode
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 9: 100%
 10/10 [00:00<00:00, 205.39it/s, v_num=0]
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
Python
# After training, load the model
loaded_model = SimpleModel.load_from_checkpoint('checkpoints/best_model.ckpt')
Best Practices for Saving and Loading Weights
1. Monitor the Right Metric
It's important to monitor the most relevant metric for your task when saving checkpoints. For example, for a classification task you might want to monitor validation accuracy (val_acc), while for a regression task you may want to track validation loss (val_loss).
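For instance, a small sketch of a checkpoint callback for a classification model, assuming a val_acc metric is logged in your validation_step:
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',  # must match the name passed to self.log
    mode='max',         # higher accuracy is better, so keep the maximum
    save_top_k=1
)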
2. Use save_top_k Wisely
The save_top_k argument in the ModelCheckpoint callback allows you to save only the best-performing models. This helps reduce storage overhead by not saving every checkpoint.
3. Use GPU/CPU Flexibility
When saving and loading models, PyTorch Lightning takes care of moving your model between CPUs and GPUs automatically. This means you can train your model on a GPU and load it for inference on a CPU without any changes.
# Load the model on a specific device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyModel.load_from_checkpoint("checkpoint.ckpt", map_location=device)
4. Checkpoint Naming Conventions
Use meaningful naming conventions when saving checkpoints. Including metrics such as epoch and validation loss/accuracy in the checkpoint filename helps you to identify the best models easily.
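For example, ModelCheckpoint supports filename templates that interpolate logged metrics; a small sketch, assuming a val_loss metric is logged during validation:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    # epoch and val_loss are filled in from the logged metrics, producing
    # names such as model-epoch=04-val_loss=0.12.ckpt
    filename='model-{epoch:02d}-{val_loss:.2f}'
)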
5. Resume Training with Frozen Weights
If you want to resume training with part of the model frozen (e.g., for fine-tuning), you can achieve this by manually setting the requires_grad flag of the layers you want to freeze.
# Load model and freeze some layers
model = MyModel.load_from_checkpoint("checkpoint.ckpt")
for param in model.feature_extractor.parameters():
    param.requires_grad = False
Common Pitfalls
- Forgetting to Save the Optimizer State: When saving a model for resuming training, ensure that you save the optimizer's state. PyTorch Lightning's checkpointing system automatically saves the optimizer state, but if you are manually handling checkpoints, you need to include the optimizer state yourself (see the sketch after this list).
- Overwriting Checkpoints: If you don’t specify a unique filename or directory for each checkpoint, you might end up overwriting previously saved models. To avoid this, use dynamic file names based on epochs and metrics.
- Loading Weights into a Different Model Architecture: If you attempt to load weights into a model that does not match the architecture of the saved model, PyTorch will throw an error. Always ensure that the architecture is identical when loading weights.
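As a minimal sketch of manual checkpointing with plain PyTorch (outside Lightning's automatic system), where model, optimizer, and epoch are assumed to already exist:
# Save both the model weights and the optimizer state in one file
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
}, 'manual_checkpoint.pt')

# Restore both states before resuming training
checkpoint = torch.load('manual_checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])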
Conclusion
In this article, we have seen a basic workflow for training a model with PyTorch Lightning and using ModelCheckpoint to save the best-performing model. PyTorch Lightning automates many aspects of training, including managing the training loop and saving model checkpoints, which makes it easier to focus on building and fine-tuning the model.