How to Save and Load Your Neural Networks in Python | by Leonie Monigatti | Apr, 2023

How to save and load neural networks in PyTorch and Tensorflow/Keras (Image drawn by the author)

Training a neural network often takes a lot of time and computational resources. It would be a shame to lose a model after putting in that time and computation.

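That’s why you should be able to save and load a Deep Learning model at different stages (during training or after training is complete), depending on your use case: resuming training from a checkpoint, saving the entire trained model, or picking the best checkpoint for inference.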

This article covers how to save and load checkpoints and entire models for the two main Deep Learning frameworks, PyTorch and TensorFlow/Keras:

import torch
import tensorflow as tf
from tensorflow import keras

In general, when loading a saved model, it is important to ensure that the version of the framework you are using matches the version used to save the model.

print('PyTorch version:', torch.__version__)

print('TensorFlow version:', tf.__version__)
print('Keras version:', keras.__version__)

This is particularly important when moving models between different machines or environments. Thus, when saving and versioning models, it is important to store the framework’s version as metadata.
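One way to keep the framework version with a saved model is a small JSON file written next to it. The sketch below is only an illustration: the helper names collect_versions and write_metadata and the ".meta.json" suffix are made up for this example. It reads the versions with the standard library's importlib.metadata, so it also runs when a framework is not installed.

```python
import json
import tempfile
from importlib import metadata
from pathlib import Path


def collect_versions(packages=("torch", "tensorflow", "keras")):
    """Return {package: version or None} for the given distribution names."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # package is not installed in this environment
    return versions


def write_metadata(model_path, versions):
    """Write the versions to '<model_path>.meta.json' and return that path."""
    meta_path = Path(str(model_path) + ".meta.json")
    meta_path.write_text(json.dumps({"framework_versions": versions}, indent=2))
    return meta_path


# Usage: store the metadata next to a (hypothetical) model file
with tempfile.TemporaryDirectory() as tmp_dir:
    meta_path = write_metadata(Path(tmp_dir) / "model.pt", collect_versions())
    loaded = json.loads(meta_path.read_text())
    print(loaded["framework_versions"])
```

When loading the model later, you can compare the stored versions with the installed ones and warn on a mismatch.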

Checkpointing is useful to save a model at specified times during training. This is similar to saving your progress in a video game. It ensures you don’t have to start from the beginning and can resume from a checkpoint if anything goes wrong.


Save and load model checkpoints in PyTorch

PyTorch checkpoints consist of the following components [2]:

  • Model state (weights and biases)
  • Optimizer state
  • Training step or epoch
  • Any additional information you choose to save (e.g., training configuration such as optimizer, metric, or current training loss)

PyTorch models are usually saved in the PyTorch binary format (.pt or .pth). While there is no difference between the two file extensions, the developer community [3] recommends favoring the .pt file extension over the .pth file extension because the latter collides with the file extension of Python path configuration files.

You can save training checkpoints of your model (checkpoint_1.pt, checkpoint_2.pt, …) in PyTorch with the following code snippet:

# Define your model
model = ...
optimizer = ...
criterion = ...

# Train the model
for epoch in range(num_epochs):
    # Train the model for one epoch
    ...

    # Save a checkpoint after each epoch
    PATH = f'checkpoint_{epoch}.pt'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        },
        PATH)

You can load a training checkpoint of your model (e.g., checkpoint_3.pt) in PyTorch with the following code snippet. Make sure to:

  • Set model.train() before continuing model training
  • Continue training only for the remaining epochs (for epoch in range(epoch+1, num_epochs)).

# Define your model
model = ...
optimizer = ...
criterion = ...

# Load a saved checkpoint
checkpoint = torch.load('checkpoint_3.pt')
epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Set dropout and batch normalization layers to train mode
model.train()

# Resume training the model for the remaining epochs
for epoch in range(epoch + 1, num_epochs):
    ...
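To see the save and resume steps work end to end, here is a minimal, self-contained sketch. The toy data, the nn.Linear model, the SGD settings, and the helper names build and train_one_epoch are assumptions chosen only to keep the example short; they are not part of the original snippets.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)

# Toy data and settings (placeholders for a real dataset and configuration)
X = torch.randn(16, 4)
y = torch.randn(16, 1)
num_epochs = 5


def build():
    """Create a fresh model and optimizer (hypothetical helper)."""
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


def train_one_epoch(model, optimizer):
    """Run one full-batch training step and return the loss value."""
    criterion = nn.MSELoss()
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()
    return loss.item()


with tempfile.TemporaryDirectory() as tmp_dir:
    # First run: train for epochs 0, 1, 2 and save a checkpoint after each one
    model, optimizer = build()
    for epoch in range(3):
        train_loss = train_one_epoch(model, optimizer)
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_loss": train_loss,
        }, os.path.join(tmp_dir, f"checkpoint_{epoch}.pt"))

    # Second run: rebuild the objects, then restore the last checkpoint
    resumed_model, resumed_optimizer = build()
    checkpoint = torch.load(os.path.join(tmp_dir, "checkpoint_2.pt"))
    resumed_model.load_state_dict(checkpoint["model_state_dict"])
    resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1

# The restored weights are identical to the weights at checkpoint time
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(),
                    resumed_model.state_dict().values())
)

# Continue with the remaining epochs only
resumed_model.train()
for epoch in range(start_epoch, num_epochs):
    train_one_epoch(resumed_model, resumed_optimizer)
```

Restoring the optimizer state matters here because SGD with momentum keeps a buffer per parameter; without it, the first resumed steps would differ from an uninterrupted run.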

Save and load model checkpoints in TensorFlow/Keras

In contrast to PyTorch, a checkpoint in TensorFlow/Keras only saves the model state (weights and biases) in a checkpoint file (.ckpt) [6].

You can save training checkpoints of your model (checkpoint_1.ckpt, checkpoint_2.ckpt, …) in TensorFlow/Keras by using a callback function as shown below:

# Define and compile your model
model = ...
...
model.compile(...)

# Define the checkpoint callback that saves the model's weights at every epoch
# Use a plain string (not an f-string): the callback fills in {epoch} itself
PATH = 'checkpoint_{epoch}.ckpt'

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath = PATH,
    save_weights_only = True, # If False, saves the full model
    save_freq = 'epoch')

# Train the model with the checkpoint callback
model.fit(X_train,
          y_train,
          epochs = num_epochs,
          validation_data = (X_val, y_val),
          callbacks = [cp_callback])
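The Keras documentation states that filepath may contain named formatting options such as {epoch}, which the callback fills in each time it saves. The snippet below only imitates that substitution with Python's str.format to show which file names result. The 1-based epoch numbering and the :02d padding are assumptions for illustration, not output captured from Keras.

```python
# Template with a named placeholder (a plain string, not an f-string)
path_template = "checkpoint_{epoch:02d}.ckpt"

# Imitate the substitution for three epochs, numbered from 1 (assumption)
file_names = [path_template.format(epoch=epoch) for epoch in range(1, 4)]
print(file_names)
```

An f-string would instead be evaluated immediately, when PATH is defined, which requires a variable named epoch to exist at that point and yields a single fixed file name.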

If you don’t want to use the callback function, you can also use the model.save_weights(PATH) method to save the model weights.

You can load a training checkpoint of your model (e.g., checkpoint_3.ckpt) in TensorFlow/Keras with the following code snippet. Make sure to continue training only for the remaining epochs (num_epochs - epoch).

# Define and compile your model
model = ...
...
model.compile(...)

# Load a saved checkpoint
epoch = 3
model.load_weights(f'checkpoint_{epoch}.ckpt') # Loads the weights in place; do not reassign model

# Define the checkpoint callback that saves the model's weights at every epoch
...

# Resume training the model
model.fit(X_train,
          y_train,
          epochs = (num_epochs - epoch),
          validation_data = (X_val, y_val),
          callbacks = [cp_callback])
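Another option, sketched here as an assumption rather than the author's method, is the initial_epoch argument of model.fit: the epoch counter then continues from the checkpoint, and epochs stays the total number of epochs. The toy data, the build_model helper, and the ".weights.h5" file suffix (chosen to suit recent Keras versions) are made up for this example.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Toy data (placeholder for a real dataset)
rng = np.random.default_rng(0)
X_train = rng.random((32, 4)).astype("float32")
y_train = rng.random((32, 1)).astype("float32")
num_epochs = 5


def build_model():
    """Define and compile a tiny model (hypothetical helper)."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="sgd", loss="mse")
    return model


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint_3.weights.h5")

    # First run: train for 3 epochs and save the weights
    model = build_model()
    model.fit(X_train, y_train, epochs=3, verbose=0)
    model.save_weights(path)

    # Second run: rebuild the model and load the weights in place
    resumed = build_model()
    resumed.load_weights(path)

weights_match = all(
    np.array_equal(a, b)
    for a, b in zip(model.get_weights(), resumed.get_weights())
)

# Epochs 0-2 are done, so training continues with epochs 3 and 4
history = resumed.fit(X_train, y_train,
                      initial_epoch=3, epochs=num_epochs, verbose=0)
print(history.epoch)
```

With this variant, epoch numbers in logs and in checkpoint file names stay aligned with the original run instead of restarting from the beginning.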

You can also save a model when it is finished training. This is useful when you want to deploy your model or if the inference is done anywhere other than your training code.


Save and load an entire model in PyTorch

In contrast to a checkpoint, a saved PyTorch model only contains the model state (weights and biases) after the model is finished training [2].

PyTorch models are also saved in the PyTorch binary format (.pt preferred over .pth [3]).

You can save the trained model (model.pt) in PyTorch with the following code snippet. Make sure to save the model.state_dict() instead of model alone (see the alternative below).

PATH = "model.pt"

# Define your model
model = ...

# Train the model
...

# Save the model
torch.save(model.state_dict(), PATH)

You can load a trained model (model.pt) in PyTorch with the following code snippet. Make sure to:

  • Create an instance of the same model before loading the weights
  • Set model.eval() before using the model for inference

# Define your model architecture
model = ...

# Load the saved model parameters into your model
model.load_state_dict(torch.load(PATH))

# Set dropout and batch normalization layers to evaluation mode before running inference
model.eval()

# Use the model for inference
# ...
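The following sketch runs the same save and load steps on a toy model and checks that the reloaded model reproduces the original outputs. The architecture, the build_model helper, and the random input are assumptions for illustration; the dropout layer is included only so that model.eval() has a visible effect.

```python
import os
import tempfile

import torch
from torch import nn


def build_model():
    """Create the model architecture (hypothetical helper)."""
    return nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 1))


torch.manual_seed(0)
trained_model = build_model()  # stands in for a model that has been trained

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")
    torch.save(trained_model.state_dict(), path)

    model = build_model()  # same architecture, freshly initialized weights
    model.load_state_dict(torch.load(path))

# Evaluation mode disables dropout, so outputs become deterministic
model.eval()
trained_model.eval()

x = torch.randn(2, 4)
with torch.no_grad():  # gradients are not needed for inference
    same_output = torch.equal(model(x), trained_model(x))
    repeatable = torch.equal(model(x), model(x))

print(same_output, repeatable)
```

If model.eval() is skipped, the dropout layer stays active and two forward passes on the same input will generally differ.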

Alternatively, you could pickle the entire model as follows:

PATH = "model.pt"

# Define your model
model = ...

# Train the model
...

# Save the model
torch.save(model, PATH)

With this approach, you don’t have to define the model before loading its weights.

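# Load the entire pickled model (the model class must still be importable)
# Recent PyTorch versions default to weights_only=True, so a full model needs weights_only=False
model = torch.load(PATH, weights_only=False)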

# Set dropout and batch normalization layers to evaluation mode before running inference
model.eval()

# Use the model for inference
# ...

However, this approach is not recommended: This method does not save the model class itself. Instead, it saves a path to the file containing the class. Thus, saving the model with this method can cause issues when you want to repurpose the model in other projects or with a different source code.
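The reason is how Python's pickle module, which torch.save relies on, serializes objects. The standard-library sketch below uses a made-up TinyModel class instead of a real PyTorch model to show the effect: the pickled bytes contain the class name but not the class body, so loading fails once the class definition is no longer available.

```python
import pickle


class TinyModel:
    """Hypothetical stand-in for a model class."""

    def __init__(self):
        self.weights = [0.1, 0.2]


payload = pickle.dumps(TinyModel())

# The bytes hold the class name and the attribute values, not the class body
stores_class_name = b"TinyModel" in payload
stores_docstring = b"Hypothetical stand-in" in payload

# Simulate loading in a project where the class definition is missing
del TinyModel
try:
    pickle.loads(payload)
    load_error = None
except AttributeError as error:
    load_error = error

print(stores_class_name, stores_docstring, type(load_error).__name__)
```

Saving the state_dict avoids this coupling for the weights themselves, although you still need the model code to rebuild the architecture.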

Save and load an entire model in TensorFlow/Keras

A Keras model consists of the following components [5, 6]:

  • Model architecture including optimizer and its state, losses and metrics in saved_model.pb
  • State of the model (weights and biases) in the variables/ directory.
  • Model’s compilation information

The model can be saved in the following file formats [5]:

  • TensorFlow SavedModel format: recommended and default format when no other file extension is given.
  • HDF5 format (.h5): an older and light-weight alternative that does not save external losses and metrics.

You can save the trained model (model) in TensorFlow/Keras with the following code snippet:

PATH = "model" # Will save the model in TensorFlow SavedModel format

# Define and compile your model
model = ...
...

# Train the model
...

# Save the model
model.save(PATH)

You can load a trained model (model) in TensorFlow/Keras with the following code snippet.

# Load the saved model
model = keras.models.load_model(PATH)

# Use the model for inference
# ...

Best checkpoint picking is a technique in Deep Learning which monitors the validation metric during training (without early stopping) and uses the checkpoint with the best validation metric for inference.

A recent article on intermediate Deep Learning techniques found that there is currently no common understanding of the best practices regarding best checkpoint picking. While the Deep Learning Tuning Playbook [1] recommends using best checkpoint picking, Kaggle Grandmasters don’t recommend it because this technique tends to overfit the model to the validation set [4].

Nonetheless, we will cover how to apply best checkpoint picking to your Deep Learning pipeline.

Best checkpoint picking in PyTorch

Saving is similar to saving a model checkpoint in PyTorch but with some alterations:

  • We save only the model but no training information like the epoch, optimizer state, etc. because we don’t intend to continue training with this model
  • Manually add monitoring of the validation metric during training

# Define your model
model = ...
optimizer = ...
criterion = ...

# Best validation metric seen so far (assumes that higher is better)
best_metric = float('-inf')

# Train the model
for epoch in range(num_epochs):
    # Train the model for one epoch and compute the validation metric (current_metric)
    ...

    # Save a checkpoint only when the validation metric improves
    if best_metric < current_metric:
        best_metric = current_metric

        PATH = f'checkpoint_{epoch}.pt'

        # Save the model
        torch.save(model.state_dict(), PATH)

Loading is the same as loading an entire model in PyTorch.
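The selection logic itself is independent of PyTorch. The sketch below replays it on made-up validation accuracies (one per epoch) and records in which epochs a checkpoint would be written; the numbers and the names val_metrics, best_epoch, and saved_epochs are assumptions for illustration.

```python
# Hypothetical validation accuracies, one per epoch (made-up numbers)
val_metrics = [0.71, 0.78, 0.76, 0.83, 0.81]

best_metric = float("-inf")  # start below any reachable value
best_epoch = None
saved_epochs = []

for epoch, current_metric in enumerate(val_metrics):
    if current_metric > best_metric:
        best_metric = current_metric
        best_epoch = epoch
        # In real training code, torch.save(model.state_dict(), PATH) goes here
        saved_epochs.append(epoch)

print(best_epoch, best_metric, saved_epochs)
```

For a metric where lower is better, such as the validation loss, start from float("inf") and flip the comparison.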

Best checkpoint picking in TensorFlow/Keras

Saving is similar to saving a model checkpoint in TensorFlow/Keras but with some alterations:

  • Use save_weights_only = False to save the entire model
  • Remove .ckpt from the PATH to save the model in SavedModel format
  • Add monitor and save_best_only parameters

# Define and compile your model
model = ...
...
model.compile(...)

# Define the checkpoint callback that saves the entire model whenever the monitored metric improves
# Use a plain string (not an f-string): the callback fills in {epoch} itself
PATH = './checkpoints/checkpoint_{epoch}' # Remove .ckpt to save as SavedModel format
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath = PATH,
    monitor = "val_acc", # Metric to monitor; must match the name of a metric logged during training
    save_best_only = True,
    save_weights_only = False, # Save the entire model, not only the weights
    save_freq = 'epoch')

# Train the model with the checkpoint callback
model.fit(X_train,
          y_train,
          epochs = num_epochs,
          validation_data = (X_val, y_val),
          callbacks = [cp_callback])

Loading is the same as loading an entire model in TensorFlow/Keras.

This article reviewed different use cases for saving and loading neural networks with the Deep Learning frameworks PyTorch and TensorFlow/Keras. Below you can see an overview of the comparison between PyTorch and Keras.

Overview of what is saved when saving a Deep Learning checkpoint or model in PyTorch vs. TensorFlow/Keras (Image by the author)
  • Model architecture: In PyTorch, the model architecture is never saved, so you have to keep it separately, for example in source code version control. In TensorFlow/Keras, the model architecture is saved when you save the entire model.
  • Model weights: Both PyTorch and TensorFlow/Keras can save the model weights only. However, in PyTorch, this is done when you save the final trained model, while in TensorFlow/Keras this applies to checkpointing.


[1] V. Godbole, G. E. Dahl, J. Gilmer, C. J. Shallue and Z. Nado (2023). Deep Learning Tuning Playbook (Version 1.0) (accessed February 3rd, 2023).

[2] M. Inkawhich for PyTorch (2023). Saving and Loading Models (accessed March 27th, 2023).

[3] kmario23 in Stackoverflow (2019). What is the difference between .pt, .pth and .pwf extentions in PyTorch? (accessed March 27th, 2023).

[4] P. Singer and Y. Babakhin (2022). Practical Tips for Deep Transfer Learning presented at Kaggle Days Paris in November 2022.

[5] TensorFlow (2023). Guide: Save and load Keras models (accessed March 27th, 2023).

[6] TensorFlow (2023). Tutorials: Save and load models (accessed March 27th, 2023).



