
# PyTorch Lightning DataModules

* **Author:** Lightning.ai
* **License:** CC BY-SA
* **Generated:** 2025-05-01T11:55:57.170344

This notebook walks you through getting started with DataModules. Since the release of `pytorch-lightning` version 0.9.0, the `LightningDataModule` class has been available to help you decouple data-related hooks from your `LightningModule`. The most up-to-date documentation on DataModules can be found [here](https://lightning.ai/docs/pytorch/stable/data/datamodule.html).

---
[Open in Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/datamodules.ipynb)

Give us a ⭐ [on GitHub](https://www.github.com/Lightning-AI/lightning/)
| Check out [the documentation](https://lightning.ai/docs/)
| Join us [on Discord](https://discord.com/invite/tfXFetEZxv)

## Setup
This notebook requires some packages besides pytorch-lightning.

In [1]:
! pip install --quiet "pytorch-lightning >=2.0,<2.6" "torchmetrics>=1.0, <1.8" "numpy <3.0" "torchvision" "torch>=1.8.1, <2.8" "matplotlib"


## Introduction

First, we'll go over a regular `LightningModule` implementation that does not use a `LightningDataModule`.

In [2]:
import os

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms

# Note - you must have torchvision installed for this example
from torchvision.datasets import CIFAR10, MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

### Defining the LitMNISTModel

Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST Handwritten Digits.

Unfortunately, we have hardcoded dataset-specific items within the model,
forever limiting it to working with MNIST Data. 😢

This is fine if you don't plan on training/evaluating your model on different datasets.
However, in many cases, this can become bothersome when you want to try out your architecture with different datasets.

In [3]:
class LitMNIST(pl.LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # We hardcode dataset specific stuff here.
        self.data_dir = data_dir
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Build model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

### Training the LitMNIST Model

In [4]:
model = LitMNIST()
trainer = pl.Trainer(
    max_epochs=2,
    accelerator="auto",
    devices=1,
)
trainer.fit(model)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]



  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | Sequential | 55.1 K | train
---------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode



/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.



`Trainer.fit` stopped: `max_epochs=2` reached.


## Using DataModules

DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset-agnostic models.

### Defining The MNISTDataModule

Let's go over each function in the class below and talk about what it does:

1. `__init__`
    - Takes a `data_dir` argument that points to where you have downloaded, or wish to download, the MNIST dataset.
    - Defines a transform that is applied across the train, val, and test dataset splits.
    - Defines the default `self.dims` and `self.num_classes`.


2. `prepare_data`
    - This is where we download the dataset. We point to the desired location and ask torchvision's `MNIST` dataset class to download the data if it isn't found there.
    - **Note that we do not make any state assignments in this function** (i.e. `self.something = ...`), because `prepare_data` is called from a single process and state set here would not be available on the other processes.

3. `setup`
    - Loads the data from file and prepares PyTorch tensor datasets for each split (train, val, test).
    - `setup` expects a `stage` argument, which is used to separate the logic for 'fit' and 'test'.
    - If you don't mind loading all your datasets at once, you can set up a condition so that both the 'fit' and the 'test' setup run whenever `None` is passed to `stage`.
    - **Note that this runs across all GPUs and it *is* safe to make state assignments here.**


4. `x_dataloader`
    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` each return a PyTorch `DataLoader` instance that wraps the corresponding dataset prepared in `setup()`.

In [5]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

### Defining the dataset agnostic `LitModel`

Below, we define the same model as the `LitMNIST` model we made earlier.

However, this time our model has the freedom to use any input data that we'd like 🔥.

In [6]:
class LitModel(pl.LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # We take in input dimensions as parameters and use those to dynamically build model.
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
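
To see what taking the input dimensions as parameters buys us, here is a rough sketch that builds the same layer stack for two input shapes and pushes a random batch through each. `build_mlp` is a hypothetical helper that mirrors `LitModel.model`; it is not part of Lightning. The `(3, 32, 32)` shape is the CIFAR10 image shape, and the batch size of 4 is an arbitrary assumption.

```python
import torch
from torch import nn


def build_mlp(channels, width, height, num_classes, hidden_size=64):
    # Same layer stack as LitModel.model; the first Linear layer is sized from the input dimensions.
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(channels * width * height, hidden_size),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(hidden_size, num_classes),
    )


mnist_like = build_mlp(1, 28, 28, num_classes=10)
cifar_like = build_mlp(3, 32, 32, num_classes=10)

# Random batches of 4 images stand in for real data.
mnist_out = mnist_like(torch.randn(4, 1, 28, 28))
cifar_out = cifar_like(torch.randn(4, 3, 32, 32))
print(mnist_out.shape, cifar_out.shape)
```

Both networks return one score per class for each image in the batch. Only the size of the first `Linear` layer differs between them.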

### Training the `LitModel` using the `MNISTDataModule`

Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders.

In [7]:
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = pl.Trainer(
    max_epochs=3,
    accelerator="auto",
    devices=1,
)
# Pass the datamodule as arg to trainer.fit to override model hooks :)
trainer.fit(model, dm)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]



  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | Sequential | 55.1 K | train
---------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode



`Trainer.fit` stopped: `max_epochs=3` reached.


### Defining the CIFAR10 DataModule

Let's prove that the `LitModel` we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset.

In [8]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = PATH_DATASETS):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE)

### Training the `LitModel` using the `CIFAR10DataModule`

Our model is a small fully connected network, so it will perform poorly on the CIFAR10 dataset.

The point is that our `LitModel` has no problem using a different datamodule as its input data.

In [9]:
dm = CIFAR10DataModule()
model = LitModel(*dm.dims, dm.num_classes, hidden_size=256)
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto",
    devices=1,
)
trainer.fit(model, dm)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.


GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]



  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | Sequential | 855 K  | train
---------------------------------------------
855 K     Trainable params
0         Non-trainable params
855 K     Total params
3.420     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode



`Trainer.fit` stopped: `max_epochs=5` reached.
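
As a sanity check on the model summaries printed above, the parameter counts (55.1 K for the MNIST run, 855 K for the CIFAR10 run) can be reproduced by hand. The sketch below is plain Python; `mlp_param_count` is a hypothetical helper and counts only the three `Linear` layers, since `Flatten`, `ReLU`, and `Dropout` have no parameters.

```python
def mlp_param_count(dims, num_classes, hidden_size):
    """Count the weights and biases of the three Linear layers in LitModel."""
    channels, width, height = dims
    in_features = channels * width * height
    layer_shapes = [
        (in_features, hidden_size),
        (hidden_size, hidden_size),
        (hidden_size, num_classes),
    ]
    # Each Linear layer has n_in * n_out weights plus n_out biases.
    return sum(n_in * n_out + n_out for n_in, n_out in layer_shapes)


mnist_params = mlp_param_count((1, 28, 28), num_classes=10, hidden_size=64)
cifar_params = mlp_param_count((3, 32, 32), num_classes=10, hidden_size=256)
print(mnist_params, cifar_params)  # 55050 855050
```

At 4 bytes per float32 parameter, these counts also match the estimated sizes in the summaries: 55050 × 4 bytes ≈ 0.220 MB and 855050 × 4 bytes ≈ 3.420 MB.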


## Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning
movement, you can do so in the following ways!

### Star [Lightning](https://github.com/Lightning-AI/lightning) on GitHub
The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool
tools we're building.

### Join our [Discord](https://discord.com/invite/tfXFetEZxv)!
The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself
and share your interests in the `#general` channel.


### Contributions!
The best way to contribute to our community is to become a code contributor! At any time you can go to the
[Lightning](https://github.com/Lightning-AI/lightning) or [Bolt](https://github.com/Lightning-AI/lightning-bolts)
GitHub Issues pages and filter for "good first issue".

* [Lightning good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* [Bolt good first issue](https://github.com/Lightning-AI/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* You can also contribute your own notebooks with useful examples!

### Many thanks from the entire PyTorch Lightning team for your interest!

[PyTorch Lightning](https://pytorchlightning.ai)