Batch Normalization Implementation in PyTorch
Last Updated: 23 Jul, 2025
Batch Normalization (BN) is a critical technique in the training of neural networks, designed to address issues like vanishing or exploding gradients during training. In this tutorial, we will implement batch normalization using the PyTorch framework.
What is Batch Normalization?
Gradients, which are used to update weights during training, can become unstable or vanish entirely, hindering the network's ability to learn effectively. Batch Normalization (BN) is a powerful technique that addresses these issues by normalizing the inputs of each layer, stabilizing the learning process and accelerating convergence. Implementing batch normalization in PyTorch models requires understanding its concepts and best practices to achieve optimal performance.
Batch Normalization makes training more consistent and faster, improves performance, and prevents problems like gradients becoming too small or too large during training, so the network does not get stuck or make large, destabilizing updates while learning. It is particularly helpful when a neural network suffers from slow training or unstable gradients.
How Does Batch Normalization Work?
- During each training iteration, BN takes a mini-batch of data and normalizes the activations (outputs) of a hidden layer. This normalization transforms the activations to have a mean of 0 and a standard deviation of 1.
- While normalization helps with stability, it can also disrupt the network's learned features. To compensate, BN introduces two learnable parameters: gamma and beta. Gamma rescales the normalized activations, and beta shifts them, allowing the network to recover the information present in the original activations.
Before the inputs are distributed into the next layer, BN ensures that each feature is on a comparable scale: the output of a layer is normalized before being passed to the layer that follows.
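The per-feature computation can be sketched directly with tensor operations. The snippet below is a minimal illustration of what a BatchNorm layer does during training (the epsilon value and the gamma/beta initialization mirror PyTorch's defaults); it is a simplified sketch with arbitrary tensor shapes, not the library implementation itself.
Python3
import torch

# A mini-batch of 8 samples with 4 features each (arbitrary example shapes)
x = torch.randn(8, 4)

# Batch statistics computed per feature, over the batch dimension
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)

# Normalize to zero mean and unit variance (eps avoids division by zero)
eps = 1e-5
x_hat = (x - mean) / torch.sqrt(var + eps)

# Learnable scale (gamma) and shift (beta), initialized to ones and zeros
gamma = torch.ones(4, requires_grad=True)
beta = torch.zeros(4, requires_grad=True)

y = gamma * x_hat + beta          # output passed on to the next layer
print(y.mean(dim=0))              # close to 0 for every feature
print(y.std(dim=0))               # close to 1 for every feature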
Choosing the Batch Size:
- Use reasonably sized mini-batches during training. BN tends to perform better with larger batch sizes because it can compute more accurate batch statistics.
- More accurate statistics lead to more stable gradients and faster convergence, as the short example below illustrates.
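The following small experiment (hypothetical sizes, purely illustrative) shows why larger batches help: the per-feature mean estimated from a large batch is typically much closer to the mean of the full data than the one estimated from a tiny batch.
Python3
import torch

torch.manual_seed(0)
x = torch.randn(1000, 16)      # a hypothetical dataset of 1000 samples with 16 features

tiny_batch = x[:4]             # batch size 4
large_batch = x[:256]          # batch size 256

# Average absolute error of the per-feature batch means vs. the full-data means
print((tiny_batch.mean(dim=0) - x.mean(dim=0)).abs().mean())   # noticeably larger
print((large_batch.mean(dim=0) - x.mean(dim=0)).abs().mean())  # much smaller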
Implementing Batch Normalization in PyTorch
PyTorch provides the nn.BatchNormXd module (where X is 1 for 1D data, 2 for 2D data such as images, and 3 for 3D data) for convenient BN implementation. In this tutorial, we will implement batch normalization and observe its effect on a model. We will train the model on the MNIST dataset, a widely used dataset in machine learning and computer vision, and highlight the training loss when batch normalization is used. The dataset consists of 28x28 pixel grayscale images of handwritten digits from 0 to 9, along with their corresponding labels.
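As a quick illustration of the three variants (the shapes below are arbitrary example shapes, not part of the tutorial's model):
Python3
import torch
from torch import nn

bn1d = nn.BatchNorm1d(64)   # expects (N, 64) or (N, 64, L), e.g. feature vectors or sequences
bn2d = nn.BatchNorm2d(3)    # expects (N, 3, H, W), e.g. RGB images
bn3d = nn.BatchNorm3d(1)    # expects (N, 1, D, H, W), e.g. volumetric data

print(bn1d(torch.randn(16, 64)).shape)            # torch.Size([16, 64])
print(bn2d(torch.randn(16, 3, 28, 28)).shape)     # torch.Size([16, 3, 28, 28])
print(bn3d(torch.randn(16, 1, 8, 28, 28)).shape)  # torch.Size([16, 1, 8, 28, 28])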
Prerequisite: Install the PyTorch library:
pip install torch torchvision
Step 1: Importing necessary libraries
- torch: Imports the PyTorch library for deep learning operations.
- nn: Imports the neural network module from PyTorch for building neural network architectures.
- DataLoader: Imports the DataLoader class from PyTorch, which helps load datasets efficiently for training and testing.
- transforms: Imports the transforms module from torchvision, which provides common image transformations.
- time and datetime: Imported for measuring and formatting the training time.
- os: Imports the os module, which provides functions for interacting with the operating system.
Python3
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import time
import datetime
import os
Step 2: Implementing Batch Normalization to the model
In the code snippet below, Batch Normalization (BN) is incorporated into the neural network architecture using the nn.BatchNorm1d layer; the BN layers are added after the fully connected layers.
- nn.BatchNorm1d(64) is applied after the first fully connected layer (64 neurons).
- nn.BatchNorm1d(32) is applied after the second fully connected layer (32 neurons).
The arguments (64 and 32) represent the number of features (neurons) in the respective layers to which Batch Normalization is applied. Following Batch Normalization, the ReLU activation function is applied to introduce non-linearity. In the forward method, the input tensor x is passed through the layers, including those with Batch Normalization.
Python3
# Define your neural network architecture with batch normalization
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),            # Flatten the input image tensor
            nn.Linear(28 * 28, 64),  # Fully connected layer from 28*28 to 64 neurons
            nn.BatchNorm1d(64),      # Batch normalization for stability and faster convergence
            nn.ReLU(),               # ReLU activation function
            nn.Linear(64, 32),       # Fully connected layer from 64 to 32 neurons
            nn.BatchNorm1d(32),      # Batch normalization for stability and faster convergence
            nn.ReLU(),               # ReLU activation function
            nn.Linear(32, 10)        # Fully connected layer from 32 to 10 neurons (the 10 MNIST classes)
        )

    def forward(self, x):
        return self.layers(x)
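As an optional sanity check (not part of the original walkthrough), a dummy batch can be passed through the model to confirm the output shape; note that BatchNorm1d needs more than one sample per batch while the model is in training mode.
Python3
model = MLP()
dummy = torch.randn(8, 1, 28, 28)   # a fake batch of 8 grayscale 28x28 images
print(model(dummy).shape)           # expected: torch.Size([8, 10])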
Step 3: Load the MNIST dataset and create the DataLoader that will feed mini-batches to the MLP during training.
Python3
if __name__ == '__main__':
    # Set random seed for reproducibility
    torch.manual_seed(47)

    # Load the MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    train_data = MNIST(os.getcwd(), download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
Step 4: Initialize the MLP model and define the loss function (CrossEntropyLoss) and the optimizer (Adam).
Python3
    mlp = MLP()                                               # Initialize MLP model
    loss_function = nn.CrossEntropyLoss()                     # Cross-entropy loss function for classification
    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)   # Adam optimizer with learning rate 0.001
Step 5: Define the Training Loop
We train the model for 3 epochs using a training loop. It iterates over mini-batches of training data, computes the loss, performs backpropagation, and updates the model parameters.
Python3
    start_time = time.time()

    # Training loop
    for epoch in range(3):  # Iterate over 3 epochs
        print(f'Starting epoch {epoch + 1}')
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data

            optimizer.zero_grad()                            # Zero the gradients
            outputs = mlp(inputs.view(inputs.shape[0], -1))  # Flatten the input for the MLP and forward pass
            loss = loss_function(outputs, labels)            # Compute the loss
            loss.backward()                                  # Backpropagation
            optimizer.step()                                 # Optimizer step to update parameters

            running_loss += loss.item()
            if i % 100 == 99:  # Print every 100 mini-batches
                print(f'Epoch {epoch + 1}, Mini-batch {i + 1}, Loss: {running_loss / 100}')
                running_loss = 0.0

    print('Training finished')

    end_time = time.time()  # Record end time
    print('Training process has been completed.')
    training_time = end_time - start_time
    print('Training time:', str(datetime.timedelta(seconds=training_time)))  # Elapsed time in h:mm:ss format
Output:
Starting epoch 1
Epoch 1, Mini-batch 100, Loss: 1.107109518647194
Epoch 1, Mini-batch 200, Loss: 0.48408970028162
Epoch 1, Mini-batch 300, Loss: 0.3104418055713177
Epoch 1, Mini-batch 400, Loss: 0.2633690595626831
Epoch 1, Mini-batch 500, Loss: 0.2228860107809305
Epoch 1, Mini-batch 600, Loss: 0.20098184436559677
Epoch 1, Mini-batch 700, Loss: 0.18423103891313075
Epoch 1, Mini-batch 800, Loss: 0.16403419613838197
Epoch 1, Mini-batch 900, Loss: 0.14670498583465816
Starting epoch 2
Epoch 2, Mini-batch 100, Loss: 0.1223447759822011
Epoch 2, Mini-batch 200, Loss: 0.11535881120711565
Epoch 2, Mini-batch 300, Loss: 0.12264159372076393
Epoch 2, Mini-batch 400, Loss: 0.1274782767519355
Epoch 2, Mini-batch 500, Loss: 0.12688526364043354
Epoch 2, Mini-batch 600, Loss: 0.10709397405385972
Epoch 2, Mini-batch 700, Loss: 0.12462730823084713
Epoch 2, Mini-batch 800, Loss: 0.10854666410945356
Epoch 2, Mini-batch 900, Loss: 0.10740736600011587
Starting epoch 3
Epoch 3, Mini-batch 100, Loss: 0.09494352690875531
Epoch 3, Mini-batch 200, Loss: 0.08548182763159275
Epoch 3, Mini-batch 300, Loss: 0.08944599309004843
Epoch 3, Mini-batch 400, Loss: 0.08315778982825578
Epoch 3, Mini-batch 500, Loss: 0.0855206391401589
Epoch 3, Mini-batch 600, Loss: 0.08882722020149231
Epoch 3, Mini-batch 700, Loss: 0.0896124207880348
Epoch 3, Mini-batch 800, Loss: 0.08545528341084718
Epoch 3, Mini-batch 900, Loss: 0.09168351721018553
Training finished
Training process has been completed.
Training time: 0:00:21.384532
Note: With batch normalization, the loss after mini-batch 900 of epoch 3 is about 0.0917.
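To reproduce the comparison without batch normalization, you can train an otherwise identical model with the BatchNorm1d layers removed and rerun the same training loop. The variant below is a sketch for that experiment; it is not part of the output above, and your exact loss values will differ from run to run.
Python3
# Same architecture as MLP but without the BatchNorm1d layers (for comparison only)
class MLPNoBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10)
        )

    def forward(self, x):
        return self.layers(x)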
Benefits of Batch Normalization
- Faster Convergence: By stabilizing the gradients, BN allows you to use higher learning rates, which can significantly speed up training.
- Reduced Internal Covariate Shift: As the network trains, the distribution of activations within a layer can change (internal covariate shift). BN helps mitigate this by normalizing activations before subsequent layers, making the training process less sensitive to these shifts.
- Initialization Insensitivity: BN makes the network less reliant on the initial weight values, allowing for more robust training and potentially better performance.
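One practical detail worth remembering: during training, BatchNorm normalizes with the statistics of the current mini-batch, while at inference it uses the running estimates accumulated during training, so the model should be switched to evaluation mode before making predictions. A minimal sketch, assuming the mlp and train_data objects from the training script are still in scope:
Python3
mlp.eval()                                        # switch BatchNorm layers to inference behavior
with torch.no_grad():
    sample, _ = train_data[0]                     # one MNIST image, shape (1, 28, 28)
    logits = mlp(sample.unsqueeze(0))             # add a batch dimension -> (1, 1, 28, 28)
    print('Predicted digit:', logits.argmax(dim=1).item())
mlp.train()                                       # switch back if training continues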
Conclusion
The choice between using batch normalization or not depends on factors such as model architecture, dataset characteristics, and computational resources. The practices discussed here are worth taking into account, as their effect is visible in the MLP's training loss. With better generalization and faster convergence, batch normalization remains a standard building block for deeper networks.