Apply torch.inverse() Function of PyTorch to Every Sample in the Batch
Last Updated: 28 Apr, 2025
PyTorch is a deep learning framework that provides a wide range of operations on tensors. One of them is torch.inverse(), which computes the inverse of a square, invertible matrix.
Sometimes the data arrive as a batch of matrices, for example one matrix per sample, and we want the inverse of each one. No Python loop is needed for this. torch.inverse() treats every dimension except the last two as a batch dimension and inverts each (n, n) matrix independently. The result has the same shape as the input. Each matrix in it is the matrix inverse of the corresponding input matrix, not the element-wise reciprocal (1/x) of its entries.
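To make the batching behaviour concrete, here is a minimal sketch that compares one batched call with an explicit per-sample loop. The batch size, the seed, the float64 dtype and the names `batch`, `batched_inv` and `looped_inv` are illustrative assumptions, not part of the original examples.

```python
import torch

# Hypothetical batch of 4 random 3x3 matrices (random Gaussian matrices are
# invertible with probability 1); float64 is used for a tight comparison
torch.manual_seed(0)
batch = torch.randn(4, 3, 3, dtype=torch.float64)

# One batched call: inverts every matrix along the leading (batch) dimension
batched_inv = torch.inverse(batch)

# Reference: invert each sample on its own, then stack the results into a batch
looped_inv = torch.stack([torch.inverse(m) for m in batch])

print(batched_inv.shape)
print(torch.allclose(batched_inv, looped_inv))
```

Both approaches give the same values, but the batched call avoids the Python-level loop. That is usually the faster choice for large batches.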
The following code demonstrates how to apply the torch.inverse() function to every sample in a batch. We first create a batch of matrices and then use the torch.inverse() function to find the inverse of each matrix in the batch.
Syntax of torch.inverse():
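torch.inverse() computes the inverse of the square matrix input. input can also be a batch of square matrices, in which case the function returns a tensor of the same shape containing the individual inverses. If any of the matrices is not invertible, the call raises an error.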
Syntax: torch.inverse(input, *, out=None)
Parameters:
- input (Tensor) – the input tensor of size (∗, n, n), where ∗ is zero or more batch dimensions
Keyword Arguments:
- out (Tensor, optional) – the output tensor
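The (∗, n, n) shape means there can be more than one batch dimension. The sketch below uses two batch dimensions and writes the result into a pre-allocated tensor through the keyword-only out argument. The shape (2, 5, 3, 3), the seed and the dtype are arbitrary assumptions. The final comparison relies on recent PyTorch documentation describing torch.inverse() as an alias of torch.linalg.inv().

```python
import torch

torch.manual_seed(0)
# Two batch dimensions: 2 groups of 5 matrices, each 3x3, i.e. * = (2, 5)
x = torch.randn(2, 5, 3, 3, dtype=torch.float64)

# Pre-allocate an output tensor with the same shape and dtype as the input
result = torch.empty_like(x)
torch.inverse(x, out=result)  # `out` must be passed by keyword

print(result.shape)
# torch.inverse is documented as an alias of torch.linalg.inv
print(torch.allclose(result, torch.linalg.inv(x)))
```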
Example 1:
Suppose we have a batch of 2 matrices, each of shape (3, 3), so the whole batch has shape (2, 3, 3). We create it with torch.randn(), which fills it with random values; such random matrices are invertible with probability 1. Applying torch.inverse() to the entire input tensor computes the inverse of each 3x3 matrix in the batch. The output tensor also has shape (2, 3, 3), and each 3x3 matrix in it is the inverse of the corresponding matrix in the input tensor. Because the input is random, the values you get will differ from the output shown below.
Python3
import torch
# Create a batch of 2 matrices with shape (2, 3, 3)
batch_size = 2
input_tensor = torch.randn(batch_size, 3, 3)
# Compute the inverse of each matrix in the batch
output_tensor = torch.inverse(input_tensor)
# Print the input and output tensors
print("Input tensor:")
print(input_tensor)
print("Output tensor:")
print(output_tensor)
Output:
Input tensor:
tensor([[[-0.9808, -1.5437, 1.1773],
[-0.8945, -1.2584, 1.6299],
[ 0.8855, 0.3088, -1.4001]],
[[ 0.4860, -0.8735, -1.1052],
[-0.4737, -2.8350, 0.1861],
[ 1.7559, -0.4935, 0.7353]]])
Output tensor:
tensor([[[-2.3209, 3.3154, 1.9079],
[-0.3517, -0.6101, -1.0059],
[-1.5453, 1.9621, 0.2705]],
[[ 0.2723, -0.1623, 0.4503],
[-0.0923, -0.3140, -0.0592],
[-0.7122, 0.1768, 0.2448]]])
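A quick way to confirm that each output matrix really is the inverse of its input is to multiply them back together and compare against the identity. This is a sketch under assumptions: it reuses the Example 1 setup, but the seed and the switch to float64 for a tighter tolerance are additions not in the original.

```python
import torch

torch.manual_seed(0)
# Same shape as Example 1; float64 is an assumption for a tighter check
input_tensor = torch.randn(2, 3, 3, dtype=torch.float64)
output_tensor = torch.inverse(input_tensor)

# The @ operator performs a batched matrix product: each matrix times its own inverse
products = input_tensor @ output_tensor

# Expand the (3, 3) identity to the (2, 3, 3) batch shape before comparing
identity = torch.eye(3, dtype=torch.float64)
print(torch.allclose(products, identity.expand_as(products)))
```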
Example 2:
In this example, we use torch.randn() to generate a batch of three 2x2 matrices, i.e. a tensor of shape (3, 2, 2). input_tensor.inverse() is the method form of torch.inverse() and already inverts every matrix in the batch. The result is then multiplied element-wise by a tensor of ones with shape (batch_size, 1, 1), created with torch.ones(). Broadcasting expands this tensor of ones to (3, 2, 2), so the product has the same shape and the same values as the inverse. The multiplication is therefore not required to apply torch.inverse() to each sample; it only shows that the batched result combines with other tensors through broadcasting. Every 2x2 matrix in the output tensor is the inverse of the corresponding matrix in the input tensor.
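The claim that multiplying by a (batch_size, 1, 1) tensor of ones changes nothing can be checked directly. The sketch below uses a seeded random batch and the hypothetical names `plain` and `scaled` to compare the two results exactly.

```python
import torch

torch.manual_seed(0)
input_tensor = torch.randn(3, 2, 2)
ones = torch.ones(3, 1, 1)

# Method form of torch.inverse(): already inverts all 3 matrices
plain = input_tensor.inverse()

# (3, 2, 2) * (3, 1, 1): each 1x1 slice of ones is broadcast over its 2x2 matrix
scaled = plain * ones

print(scaled.shape)
# Multiplying by 1.0 is exact in floating point, so the tensors are identical
print(torch.equal(plain, scaled))
```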
Python3
import torch
# Create a batch of 3 matrices with shape (3, 2, 2)
batch_size = 3
input_tensor = torch.randn(batch_size, 2, 2)
# Create a tensor of ones with shape (batch_size, 1, 1)
ones = torch.ones(batch_size, 1, 1)
# Tensor.inverse() inverts every matrix in the batch; multiplying by ones
# broadcasts (3, 1, 1) against (3, 2, 2) and leaves the values unchanged
output_tensor = input_tensor.inverse() * ones
# Print the input and output tensors
print("Input tensor:")
print(input_tensor)
print("Output tensor:")
print(output_tensor)
Output:
Input tensor:
tensor([[[-0.1727, -0.5076],
[ 0.9635, 0.0972]],
[[ 1.7375, 1.6074],
[ 0.0697, -0.8704]],
[[-0.6624, 1.8799],
[ 1.1704, -0.1165]]])
Output tensor:
tensor([[[ 0.2057, 1.0747],
[-2.0400, -0.3656]],
[[ 0.5358, 0.9895],
[ 0.0429, -1.0697]],
[[ 0.0549, 0.8855],
[ 0.5513, 0.3120]]])
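Both examples use torch.randn(), which practically never produces a singular matrix. A batch built from real data might contain one, and then torch.inverse() fails for the whole batch. The sketch below shows that failure and one way to find the offending samples with torch.linalg.inv_ex(), which reports a per-matrix info code instead of raising. The two-matrix batch is a hypothetical example, and catching RuntimeError assumes that torch.linalg.LinAlgError (used by recent versions) is a subclass of it.

```python
import torch

# Hypothetical batch: an identity matrix and a singular matrix
# (its second row is exactly 2 times its first row)
batch = torch.stack([
    torch.eye(2),
    torch.tensor([[1.0, 2.0], [2.0, 4.0]]),
])

# torch.inverse raises if any matrix in the batch cannot be inverted
try:
    torch.inverse(batch)
    raised = False
except RuntimeError as err:  # torch.linalg.LinAlgError subclasses RuntimeError
    raised = True
    print("torch.inverse failed:", err)

# inv_ex does not raise by default; info is 0 for every matrix inverted successfully
inv, info = torch.linalg.inv_ex(batch)
print("info per sample:", info)
print("invertible samples:", info == 0)
```

Checking info lets you drop or repair the bad samples instead of losing the whole batch to one exception.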