Save and load models in TensorFlow
Last Updated: 08 Apr, 2025
Training a machine learning or deep learning model is time-consuming, and shutting down the notebook discards the trained weights because the memory is flushed. We therefore save models for reusability, collaboration, and continuing training later.
- Saving the model lets you reuse it without repeating a lengthy training run.
- It also lets you share the model with others so they can replicate your results.
- When sharing machine learning models, it's common to include:
  - the code to create the model
  - the trained weights for the model
Below are the methods for saving and loading machine learning models in TensorFlow.

Methods to Save and Load Models
The following methods can be used to save and load a model.
1. Using the save() Method
The save() method saves the complete model, including:
- Model architecture
- Model weights
- Optimizer state (if the model was compiled), so you can resume training from where you left off.
model.save('location/model_name.keras')
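Here model can be a Sequential model, a Functional model, or a Model subclass. The location is the path where the model is stored; a bare file name or relative path is resolved against the current working directory. In TensorFlow 2.16 and later, which ship Keras 3, the file name must end in .keras (the native format) or .h5 (the legacy HDF5 format).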
To load the model, use the load_model() function:
tensorflow.keras.models.load_model('location/model_name.keras')
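As a minimal sketch of a full save/load round trip, the example below uses a tiny hypothetical model (the `build_model` helper and the random data are assumptions for illustration, not part of the original article) and a temporary directory in place of `location/`. It compiles and trains the model briefly, saves it with the `.keras` extension, reloads it with `load_model()`, checks that the weights match, and continues training on the restored copy.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical tiny Functional model used only for this illustration
    inputs = tf.keras.Input(shape=(4,))
    x = tf.keras.layers.Dense(8, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(1)(x)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mse")
    return model


# Random stand-in data (assumption, for illustration only)
rng = np.random.default_rng(0)
x = rng.normal(size=(32, 4)).astype("float32")
y = rng.normal(size=(32, 1)).astype("float32")

model = build_model()
model.fit(x, y, epochs=2, verbose=0)

# The .keras extension selects the native Keras format (a single file)
path = os.path.join(tempfile.mkdtemp(), "model_name.keras")
model.save(path)

restored = tf.keras.models.load_model(path)

# Architecture and weights survive the round trip
same_weights = all(
    np.allclose(a, b) for a, b in zip(model.get_weights(), restored.get_weights())
)
print("weights identical:", same_weights)

# The model was compiled before saving, so training can continue directly
restored.fit(x, y, epochs=1, verbose=0)
```

Because the model was compiled before saving, load_model() returns a compiled model, which is why fit() can be called on it without another compile() step.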
2. Using the save_weights() Method
In some cases you might want to save just the weights of the model instead of the entire model. This can be done using the save_weights() method, which saves the weights of all the layers in the model.
model.save_weights('location/weights_name.weights.h5')
Here weights_name is the file name for the saved weights; a bare file name or relative path is resolved against the current working directory. With Keras 3 (TensorFlow 2.16 and later), the file name must end in .weights.h5.
To load the saved weights, use the load_weights() method:
model.load_weights('location/weights_name.weights.h5')
Note: When loading weights, make sure the model's architecture matches the one used to save them. For example, you cannot load the weights of a model with two dense layers into a model with just one dense layer.
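The sketch below illustrates both halves of that note under simple assumptions: the `build_model` helper, its `hidden_layers` parameter and the layer sizes are hypothetical. Weights saved from one model load cleanly into a freshly built model with the same architecture, while loading them into a model with a different number of Dense layers is expected to raise a ValueError.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model(hidden_layers):
    # Hypothetical helper: a small MLP with the given number of hidden Dense layers
    inputs = tf.keras.Input(shape=(4,))
    x = inputs
    for _ in range(hidden_layers):
        x = tf.keras.layers.Dense(8, activation="relu")(x)
    outputs = tf.keras.layers.Dense(1)(x)
    return tf.keras.Model(inputs, outputs)


original = build_model(hidden_layers=2)

# Keras 3 requires the ".weights.h5" suffix for save_weights()
weights_path = os.path.join(tempfile.mkdtemp(), "weights_name.weights.h5")
original.save_weights(weights_path)

# The weights load into a *new* model with the same architecture
clone = build_model(hidden_layers=2)
clone.load_weights(weights_path)

sample = np.ones((1, 4), dtype="float32")
same_output = np.allclose(
    original.predict(sample, verbose=0), clone.predict(sample, verbose=0)
)
print("outputs match:", same_output)

# A model with a different architecture cannot receive these weights
mismatched = build_model(hidden_layers=1)
try:
    mismatched.load_weights(weights_path)
    loaded_ok = True
except ValueError:
    loaded_ok = False
print("mismatched load succeeded:", loaded_ok)
```

Since only the weights are stored, the code that builds the architecture (here build_model) has to be shared alongside the weights file.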
3. HDF5 Format (.h5)
If you save your model with the .h5 extension, it is saved in HDF5 format. This format is portable and commonly used for storing large data and models. Specify the .h5 extension when saving and TensorFlow will automatically use this format.
model.save('my_model.h5')
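To use the native Keras format instead, give the file a .keras extension; Keras 3 treats HDF5 as a legacy format. In TensorFlow 2.15 and earlier, a path without an extension saved the model as a SavedModel directory, whereas Keras 3 (TensorFlow 2.16 and later) requires an explicit .keras or .h5 extension.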
Code to Save and Load Models
Here we will build a neural network, save it, and load it back.
1. Import Necessary Modules
We import TensorFlow to build the model:
- Input, Conv2D, Dense, Flatten, Dropout, MaxPooling2D and BatchNormalization to build the neural network layers.
- Model to define the model architecture.
- load_model to load saved models.
Python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Model, load_model
2. Load and Preprocess Data
- We use the CIFAR-10 dataset, which contains 60,000 color images of size 32x32 in 10 classes: 50,000 for training and 10,000 for testing.
- Pixel values are scaled to the range [0, 1], and the label arrays are flattened from shape (N, 1) to (N,).
Python
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = y_train.flatten(), y_test.flatten()
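The preprocessing step can be checked without downloading CIFAR-10. The sketch below builds a small synthetic batch shaped like the dataset (uint8 pixels and labels of shape (N, 1); the batch size of 8 is an arbitrary assumption) and applies the same two operations.

```python
import numpy as np

# Synthetic stand-in shaped like CIFAR-10: uint8 images and (N, 1) labels
rng = np.random.default_rng(0)
x_batch = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
y_batch = rng.integers(0, 10, size=(8, 1))

x_scaled = x_batch / 255.0  # float values in [0, 1]
y_flat = y_batch.flatten()  # shape (8, 1) -> (8,)

print(x_scaled.dtype, x_scaled.min(), x_scaled.max())
print(y_batch.shape, "->", y_flat.shape)
```

Flattening the labels gives one integer class index per image, which is also what len(set(y_train)) relies on when the model code below counts the classes.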
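3. Defining the Model
The model is built with the Functional API: three blocks of two Conv2D + BatchNormalization layers, each block followed by MaxPooling2D, then Flatten, Dropout, a 1024-unit Dense layer, another Dropout, and a softmax output layer with K units, where K is the number of classes (10 for CIFAR-10).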
Python
K = len(set(y_train))
i = Input(shape=x_train[0].shape)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(i)
x = BatchNormalization()(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2))(x)
x = Flatten()(x)
x = Dropout(0.2)(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.2)(x)
x = Dense(K, activation='softmax')(x)
model = Model(i, x)
model.summary()
Output:
[Figure: Model Summary]
4. Saving and Loading the Model
Python
model.save('my_model.h5')
print("Model saved!")

# load_model raises an exception (it does not return None) if loading fails
try:
    saved_model = load_model('my_model.h5')
    print("Model loaded successfully!")
except (OSError, ValueError) as err:
    print("Failed to load the model:", err)
Output:
Model saved!
Model loaded successfully!
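Printing a message only confirms that loading did not fail. A stronger check is to compare predictions from the original and reloaded models. To keep it fast, the sketch below uses a much smaller hypothetical stand-in for the CIFAR-10 network with the same input shape and 10-class softmax output, plus random input images. It also confirms that load_model() raises an exception for a missing file rather than returning None.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model

# Small hypothetical stand-in for the CIFAR-10 model defined earlier
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(4, (3, 3), activation="relu", padding="same")(inputs)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

path = os.path.join(tempfile.mkdtemp(), "my_model.h5")
model.save(path)
saved_model = load_model(path)

# The same inputs should give the same predictions from both models
batch = np.random.default_rng(0).random((4, 32, 32, 3)).astype("float32")
before = model.predict(batch, verbose=0)
after = saved_model.predict(batch, verbose=0)
print("predictions match:", np.allclose(before, after, atol=1e-6))

# load_model signals failure by raising, not by returning None
try:
    load_model(os.path.join(tempfile.mkdtemp(), "does_not_exist.h5"))
    missing_raised = False
except (OSError, ValueError):
    missing_raised = True
print("missing file raised:", missing_raised)
```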
Saving and loading models is essential for efficient machine learning workflows, enabling you to resume training without starting from scratch and share models with others.