
Using SHAP to Debug a PyTorch Image Regression Model | by Conor O’Sullivan | Jan, 2023



The first column is two images of a track for a mini-automated car. The other columns are the SHAP values from the prediction made using those images.
(source: author)

Autonomous cars terrify me. Big hunks of metal flying around with no humans to stop them if something goes wrong. To reduce this risk it is not enough to evaluate the models powering these beasts. We also need to understand how they are making predictions. This is to avoid any edge cases that would cause unforeseen accidents.

Okay, so our application is not so consequential. We will be debugging the model used to power a mini-automated car (the worst you could expect is a bruised ankle). Still, interpretable machine learning (IML) methods can be useful. We will see how they can even improve the performance of the model.

Specifically, we will:

  • Fine-tune ResNet-18 using PyTorch with image data and a continuous target variable
  • Evaluate the model using MSE and scatter plots
  • Interpret the model using DeepSHAP
  • Correct the model through better data collection
  • Discuss how image augmentation could further improve the model

Along the way, we will discuss some key pieces of Python code. You can also find the full project on GitHub.

If you are new to SHAP and want the basics, see the article below. Otherwise, check out this SHAP course. You can get free access if you sign up for my Newsletter 🙂

# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import glob
import random

from PIL import Image
import cv2

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

import shap
from sklearn.metrics import mean_squared_error

We start the project by collecting data in one room only (this will come back to haunt us). As mentioned, we use images to power an automated car. You can find examples of these on Kaggle. These images are all 224 x 224 pixels.

We display one of them with the code below. Take note of the image name (line 2). The first two numbers are x and y coordinates within the 224 x 224 frame. In Figure 1, you can see we have displayed these coordinates using a green circle (line 8).

#Load example image
name = "32_50_c78164b4-40d2-11ed-a47b-a46bb6070c92.jpg"
x = int(name.split("_")[0])
y = int(name.split("_")[1])

img = Image.open("../data/room_1/" + name)
img = np.array(img)
cv2.circle(img, (x, y), 8, (0, 255, 0), 3)

plt.imshow(img)

Example of an input image of the track. There is a green circle that gives the target variable.
Figure 1: example of input image of track (source: author)

These coordinates are the target variable. The model predicts them using the image as input. This prediction is then used to direct the car. In this case, you can see the car is coming up to a left turn. The ideal direction is to go towards the coordinates given by the green circle.

I want to focus on SHAP so we won’t go into too much depth on the modelling code. If you have any questions, feel free to ask them in the comments.

We start by creating the ImageDataset class. This is used to load our image data and target variables. It does this using the paths to our images. One thing to point out is how the target variables are scaled — both x and y will be between -1 and 1.

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, paths, transform):

        self.transform = transform
        self.paths = paths

    def __getitem__(self, idx):
        """Get image and target (x, y) coordinates"""

        # Read image
        path = self.paths[idx]
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        image = Image.fromarray(image)

        # Transform image
        image = self.transform(image)

        # Get target
        target = self.get_target(path)
        target = torch.Tensor(target)

        return image, target

    def get_target(self, path):
        """Get the target (x, y) coordinates from path"""

        name = os.path.basename(path)
        items = name.split('_')
        x = items[0]
        y = items[1]

        # Scale between -1 and 1
        x = 2.0 * (int(x) / 224 - 0.5)  # -1 left, +1 right
        y = 2.0 * (int(y) / 224 - 0.5)  # -1 top, +1 bottom

        return [x, y]

    def __len__(self):
        return len(self.paths)

In fact, when the model is deployed only the x predictions are used to direct the car. Because of scaling, the sign of the x prediction will determine the car’s direction. When x < 0, the car should turn left. Similarly, when x > 0 the car should turn right. The larger the x value the sharper the turn.
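The scaling and the sign convention described above can be sketched as a few small helpers. The names scale_coord, unscale_coord and direction are hypothetical and not part of the original project. They assume a 224-pixel frame.

```python
SIZE = 224  # assumed frame width/height in pixels


def scale_coord(pixel, size=SIZE):
    """Map a pixel coordinate in [0, size] to [-1, 1]."""
    return 2.0 * (pixel / size - 0.5)


def unscale_coord(value, size=SIZE):
    """Map a scaled value in [-1, 1] back to a pixel coordinate."""
    return (value / 2.0 + 0.5) * size


def direction(x_scaled):
    """Interpret the sign of a scaled x value (hypothetical helper)."""
    if x_scaled < 0:
        return "left"
    if x_scaled > 0:
        return "right"
    return "straight"


# x taken from the example file name "32_50_..."
x = scale_coord(32)
print(x, direction(x))
```

The example image has x = 32, which scales to a negative value. This matches the left turn seen in Figure 1.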

We use the ImageDataset class to create training and validation data loaders. This is done by doing a random 80/20 split of all the image paths from room 1. In the end, we have 1,217 and 305 images in the training and validation set respectively.

TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

paths = glob.glob('../data/room_1/*')

# Shuffle the paths
random.shuffle(paths)

# Create datasets for training and validation
split = int(0.8 * len(paths))
train_data = ImageDataset(paths[:split], TRANSFORMS)
valid_data = ImageDataset(paths[split:], TRANSFORMS)

# Prepare data for PyTorch model
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=len(valid_data))

Notice the batch_size of the valid_loader. We are using the length of the validation dataset (i.e. 305). This allows us to load all validation data in one iteration. If you are working with larger datasets you may need to use a smaller batch size.
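If the validation set does not fit in memory as one batch, the MSE can be accumulated over smaller batches instead. Below is a minimal sketch of that idea using NumPy arrays as stand-ins for model outputs and targets. The batched_mse function and the random data are assumptions for illustration, not code from the project.

```python
import numpy as np


def batched_mse(pred_batches, target_batches):
    """MSE accumulated over batches; equals the MSE over all samples."""
    total_sq_err = 0.0
    count = 0
    for pred, target in zip(pred_batches, target_batches):
        total_sq_err += float(np.sum((pred - target) ** 2))
        count += pred.size
    return total_sq_err / count


rng = np.random.default_rng(0)
target = rng.uniform(-1, 1, size=(305, 2))          # stand-in for scaled (x, y) targets
pred = target + rng.normal(0, 0.1, size=(305, 2))   # stand-in for model outputs

# Split into chunks of at most 64 rows, similar to a loader with batch_size=64
idx = np.arange(64, 305, 64)
full = float(np.mean((pred - target) ** 2))
chunked = batched_mse(np.split(pred, idx), np.split(target, idx))
print(full, chunked)
```

Summing squared errors and dividing once at the end avoids the bias you would get from averaging per-batch means when the last batch is smaller than the rest.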

We load a pretrained ResNet18 model (line 5). By setting model.fc, we update the final layer (line 6). It is a fully connected layer from 512 nodes to our 2 target variable nodes. We will be using the Adam optimizer to fine-tune this model (line 9).

output_dim = 2 # x, y
device = torch.device('mps') # or 'cuda' if you have a GPU

# RESNET 18
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(512, output_dim)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters())

I’ve trained the model using a GPU (line 2). You will still be able to run the code on a CPU. Fine-tuning is not as computationally expensive as training from scratch!
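If you are unsure which device is available on your machine, a small fallback can pick one for you. This is a sketch and not part of the original project. The pick_device name is hypothetical.

```python
import torch


def pick_device():
    """Return the first available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older PyTorch builds may not have the MPS backend at all
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(device)
```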

Finally, we have our model training code. We train for 10 epochs using MSE as our loss function. Our final model is the one that has the lowest MSE on the validation set.

name = "direction_model_1" # Change this to save a new model

# Train the model
min_loss = np.inf
for epoch in range(10):

    model = model.train()
    for images, target in iter(train_loader):

        images = images.to(device)
        target = target.to(device)

        # Zero gradients of parameters
        optimizer.zero_grad()

        # Execute model to get outputs
        output = model(images)

        # Calculate loss
        loss = torch.nn.functional.mse_loss(output, target)

        # Run backpropagation to accumulate gradients
        loss.backward()

        # Update model parameters
        optimizer.step()

    # Calculate validation loss
    model = model.eval()

    images, target = next(iter(valid_loader))
    images = images.to(device)
    target = target.to(device)

    output = model(images)
    valid_loss = torch.nn.functional.mse_loss(output, target)

    print("Epoch: {}, Validation Loss: {}".format(epoch, valid_loss.item()))

    if valid_loss < min_loss:
        print("Saving model")
        torch.save(model, '../models/{}.pth'.format(name))

        min_loss = valid_loss

At this point, we want to understand how our model is doing. We look at MSE and scatter plots of actual vs predicted x values. We ignore y for now as it does not impact the direction of the car.

Training and validation set

Figure 2 gives these metrics on the training and validation set. The diagonal red line gives perfect predictions. There is a similar variation around this line for x < 0 and x > 0. In other words, the model is able to predict left and right turns with similar accuracy. Similar performance on the training and validation set also indicates that the model is not overfitted.

Two scatter plots of predicted vs actual x. The first is for the training set and the second is for the validation set. The MSE for each set is given at the top.
Figure 2: model evaluation on training and validation set (source: author)

To create the above plot, we use the model_evaluation function. Note, the data loaders should be created so that they will load all data in the first iteration.

def model_evaluation(loaders, labels, save_path=None):

    """Evaluate direction models with mse and scatter plots
    loaders: list of data loaders
    labels: list of labels for plot title"""

    n = len(loaders)
    fig, axs = plt.subplots(1, n, figsize=(7*n, 6))

    # Evaluation metrics
    for i, loader in enumerate(loaders):

        # Load all data
        images, target = next(iter(loader))
        images = images.to(device)
        target = target.to(device)

        output = model(images)

        # Get x predictions
        x_pred = output.detach().cpu().numpy()[:, 0]
        x_target = target.cpu().numpy()[:, 0]

        # Calculate MSE
        mse = mean_squared_error(x_target, x_pred)

        # Plot predictions
        axs[i].scatter(x_target, x_pred)
        axs[i].plot([-1, 1],
                    [-1, 1],
                    color='r',
                    linestyle='-',
                    linewidth=2)

        axs[i].set_ylabel('Predicted x', size=15)
        axs[i].set_xlabel('Actual x', size=15)
        axs[i].set_title("{0} MSE: {1:.4f}".format(labels[i], mse), size=18)

    if save_path is not None:
        fig.savefig(save_path)

You can see what we mean when we use the function below. We have created a new train_loader setting the batch size to the length of the training dataset. It is also important to load the saved model (line 2). Otherwise, you will end up using the model trained during the last epoch.

# Load saved model
model = torch.load('../models/direction_model_1.pth')
model.eval()
model.to(device)

# Create new loader for all data
train_loader = DataLoader(train_data, batch_size=len(train_data))

# Loaders and plot titles for the training and validation set
loaders = [train_loader, valid_loader]
labels = ["Train", "Validation"]

# Evaluate on training and validation set
model_evaluation(loaders, labels)
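The comparison of left and right turns made from the scatter plots can also be checked numerically. Here is a minimal sketch that computes the MSE separately for x < 0 and x > 0. The mse_by_side function is hypothetical and the values are made up for illustration.

```python
import numpy as np


def mse_by_side(x_target, x_pred):
    """Return MSE for left turns (x < 0) and right turns (x > 0)."""
    x_target = np.asarray(x_target, dtype=float)
    x_pred = np.asarray(x_pred, dtype=float)
    sq_err = (x_pred - x_target) ** 2
    left = x_target < 0
    right = x_target > 0
    return {
        "left": float(sq_err[left].mean()),
        "right": float(sq_err[right].mean()),
    }


# Made-up values for illustration only
x_target = [-0.8, -0.4, 0.3, 0.9]
x_pred = [-0.7, -0.5, 0.5, 0.9]
result = mse_by_side(x_target, x_pred)
print(result)
```

In practice you would pass the x_target and x_pred arrays computed inside model_evaluation. Note that the sketch assumes there is at least one sample on each side.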

Moving to a new location

The results look good! We would expect the car to perform well and it did. That is until we moved it to a new location:

Gif of the AI car making incorrect turns
Figure 3: model going wrong in a new location (source: author)

We collect some data from new locations (room 2 and room 3). Running the evaluation on these images, you can see that our model does not perform as well. This is strange! The car is on the exact same track so why does the room matter?

Two scatter plots of predicted vs actual x. The first is for data from room 2 and the second is for data from room 3. The MSE for each set is given at the top.
Figure 3: model evaluation on room 2 and room 3 (source: author)

We look to SHAP for the answer. It can be used to understand which pixels are important for a given prediction. We start by loading our saved model (line 2). SHAP has not been implemented for GPU so we set the device to CPU (lines 5–6).

# Load saved model 
model = torch.load('../models/direction_model_1.pth')

# Use CPU
device = torch.device('cpu')
model = model.to(device)

To calculate SHAP values, we need to get some background images. SHAP will integrate over these images when calculating values. We are using a batch_size of 100 images. This should give us reasonable approximations. Increasing the number of images will improve the approximation but it will also increase the computation time.

#Load 100 images for background
shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)
background, _ = next(iter(shap_loader))
background = background.to(device)

We create an explainer object by passing our model and background images into the DeepExplainer function. This function approximates SHAP values efficiently for neural networks. As an alternative, you could replace it with the GradientExplainer function.

#Create SHAP explainer 
explainer = shap.DeepExplainer(model, background)
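To make the role of the background images more concrete, consider a toy linear model. In this special case (features treated as independent), SHAP values have a closed form: each feature's contribution is its weight times its distance from the background mean. The sketch below uses made-up weights and random data. It is not what DeepExplainer computes internally for a ResNet, but it shows the same additivity property.

```python
import numpy as np

rng = np.random.default_rng(0)
w = np.array([0.5, -1.0, 2.0])  # toy model weights (made up)
b = 0.1


def f(x):
    """Toy linear model standing in for the network."""
    return x @ w + b


background = rng.normal(size=(100, 3))  # 100 background samples
x = np.array([1.0, 2.0, -1.0])          # sample to explain

# For a linear model: SHAP value of feature i = w_i * (x_i - mean of background_i)
phi = w * (x - background.mean(axis=0))
expected_value = f(background).mean()

# Additivity: contributions sum to the prediction minus the average background prediction
print(phi.sum(), f(x) - expected_value)
```

The two printed numbers agree up to floating-point error. This is why the choice of background matters: SHAP values explain a prediction relative to the average prediction over the background samples.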

We load 2 example images — a right and left turn (line 2) and transform them (line 6). This is important as the images should be in the same format as used to train the model. We then calculate the SHAP values for the predictions made using these images (line 10).

# Load test images of right and left turn
paths = glob.glob('../data/room_1/*')
test_images = [Image.open(paths[0]), Image.open(paths[3])]
test_images = np.array(test_images)

test_input = [TRANSFORMS(Image.fromarray(img)) for img in test_images]  # back to PIL for the transforms
test_input = torch.stack(test_input).to(device)

# Get SHAP values
shap_values = explainer.shap_values(test_input)

Finally, we can display the SHAP values using the image_plot function. But, we first need to restructure them. The SHAP values are returned with dimensions:

( #targets, #images, #channels, #height, #width)

We use the transpose function so we have dimensions:

( #targets, #images, #height, #width, #channels)

Note, we have also passed the original images into the image_plot function. The test_input images would look strange due to the transformations.

# Reshape shap values and images for plotting
shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))
test_numpy = np.array([np.array(img) for img in test_images])

shap.image_plot(shap_numpy, test_numpy,show=False)
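The effect of the transpose can be checked on dummy arrays. The sketch below assumes the explainer returns a list with one array per target, as described above. Other SHAP versions may return a different layout, so it is worth printing the shapes before plotting.

```python
import numpy as np

# Stand-in for the explainer output: 2 targets, each with 2 images of 3 x 224 x 224
shap_values = [np.zeros((2, 3, 224, 224)) for _ in range(2)]

stacked = np.array(shap_values)           # (targets, images, channels, height, width)
moved = stacked.transpose(0, 1, 3, 4, 2)  # (targets, images, height, width, channels)
shap_numpy = list(moved)                  # one (images, height, width, channels) array per target

print(stacked.shape, moved.shape, len(shap_numpy), shap_numpy[0].shape)
```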

You can see the result in Figure 4. The first column gives the original images. The second and third columns are the SHAP values for the x and y prediction respectively. Blue pixels have decreased the prediction. In comparison, red pixels have increased the prediction. In other words, for the x prediction, red pixels have resulted in a sharper right turn.

The first column is two images of a track for a mini-automated car. The first row is a right turn and the second row is a left turn. The other columns are the SHAP values from the prediction made using those images.
Figure 4: example shap values on a left and a right turn (source: author)

Now we are getting somewhere. The important result is that the model is using background pixels. You can see this in Figure 5 where we zoom in on the x prediction for the right turn. In other words, the background is important to the prediction. That explains the poor performance! When we moved to a new room, the background changed and our predictions became unreliable.

SHAP values for the prediction made using the image of the right turn. The pixels in the background are highlighted red and blue.
Figure 5: shap values for x prediction of right turn (source: author)

The model is overfitted to the data from room 1. The same objects and background are present in every image. As a result, the model associates these with left and right turns. We couldn’t identify this in our evaluation as we have the same background in both the training and validation images.

Illustration of an object present at a left turn. The car sees this object as it is coming up to the turn.
Figure 6: overfitting to training data (source: author)

We want our model to perform well under all conditions. To achieve this, we would expect it to only use pixels from the track. So, let’s discuss some ways of making the model more robust.

Collecting new data

The best solution is to simply collect more data. We already have some from rooms 2 and 3. Following the same process, we train a new model using data from all 3 rooms. Looking at Figure 7, it now performs better on images from the new rooms.

Figure 7: evaluation of the new model on rooms 2 and 3 (source: author)

The hope is that by training on data from multiple rooms we break the associations between turns and the background. Different objects are now present on left and right turns but the track remains the same. The model should learn that the track is what is important to the prediction.

We can confirm this by looking at the SHAP values for the new model. These are for the same turns we saw in Figure 4. There is now less weight put on the background pixels. Okay, it’s not perfect but we are getting somewhere.

The SHAP values for the model trained on data from all 3 rooms. The same right turn and left turn are given. Now less of the background pixels are highlighted.
Figure 8: shap values from model trained on all 3 rooms (source: author)
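Rather than judging the plots by eye, the weight on the background could be summarised with a single number. Here is a rough sketch that measures the share of absolute SHAP attribution falling inside a background mask. The background_share function and the top-half mask are assumptions for illustration. A real mask would have to be drawn or segmented for each image.

```python
import numpy as np


def background_share(shap_map, background_mask):
    """Fraction of total absolute SHAP attribution that falls on masked pixels.

    shap_map: array of shape (height, width, channels)
    background_mask: boolean array of shape (height, width), True = background
    """
    abs_attr = np.abs(shap_map).sum(axis=-1)  # sum over colour channels
    total = abs_attr.sum()
    if total == 0:
        return 0.0
    return float(abs_attr[background_mask].sum() / total)


# Hypothetical mask: treat the top half of the frame as background
mask = np.zeros((224, 224), dtype=bool)
mask[:112, :] = True

# Synthetic attribution map, for illustration only
shap_map = np.zeros((224, 224, 3))
shap_map[:112] = 0.25   # top half
shap_map[112:] = -0.75  # bottom half
share = background_share(shap_map, mask)
print(share)
```

Comparing this share for the same image under the old and new model would give a simple check on whether the background has become less important.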

We could continue to collect data. The more locations we collect data from, the more robust our model will be. However, data collection can be time-consuming (and boring!). Instead, we can look to data augmentation.

Data Augmentations

Data augmentation is when we systematically or randomly alter images using code. This allows us to artificially introduce noise and increase the size of our dataset.

For example, we could double the size of our dataset by flipping images on the vertical axis. We can do this because our track is symmetrical. As seen in Figure 9, deletion could also be a useful method. This involves including images where objects or the entire background have been removed.

An example of image augmentation using deletion. In the first image, you can see a chair in the background. In the second, the chair has been replaced with a black square.
Figure 9: example of image augmentation using deletion (source: author)
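Both ideas can be sketched with plain NumPy. The important detail for flipping is that the target must be mirrored along with the image. The helper names below are hypothetical and the image is synthetic. The sketch uses width - x for the new target so that, after scaling to [-1, 1], the x value simply changes sign.

```python
import numpy as np


def flip_horizontal(img, x, y, width=224):
    """Mirror an image left-to-right and move the target x to match.

    img: array of shape (height, width, channels); x, y: pixel coordinates.
    """
    flipped = img[:, ::-1, :].copy()
    return flipped, width - x, y


def delete_region(img, top, left, height, width):
    """Return a copy of img with a rectangle replaced by black pixels."""
    out = img.copy()
    out[top:top + height, left:left + width, :] = 0
    return out


# Synthetic white image with a dark stripe on the left edge
img = np.full((224, 224, 3), 255, dtype=np.uint8)
img[:, :10, :] = 0

flipped, x_new, y_new = flip_horizontal(img, x=32, y=50)
masked = delete_region(img, top=0, left=100, height=40, width=40)
print(x_new, y_new, flipped[0, -1, 0], masked[10, 110, 0])
```

After the flip, the dark stripe sits on the right edge and a left-turn target becomes a right-turn target.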

When building a robust model, you should also consider factors like lighting conditions and image quality. We can simulate these using color jitter or by adding noise. If you want to learn about all of these methods check out the article below.
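As a rough sketch of what these alterations do to the pixels, the functions below change brightness and add Gaussian noise with NumPy. They are hypothetical stand-ins for illustration. In the training code above, lighting variation is already handled by the ColorJitter transform.

```python
import numpy as np


def adjust_brightness(img, factor):
    """Scale pixel intensities by factor and clip to the valid uint8 range."""
    out = img.astype(np.float32) * factor
    return np.clip(out, 0, 255).astype(np.uint8)


def add_gaussian_noise(img, sigma, rng):
    """Add zero-mean Gaussian noise with standard deviation sigma (in pixel units)."""
    noise = rng.normal(0.0, sigma, size=img.shape)
    out = img.astype(np.float32) + noise
    return np.clip(out, 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
img = np.full((224, 224, 3), 128, dtype=np.uint8)  # flat grey stand-in image

darker = adjust_brightness(img, 0.5)
noisy = add_gaussian_noise(img, sigma=10.0, rng=rng)
print(darker[0, 0, 0], noisy.std())
```

Neither alteration moves the track within the frame, so the target coordinates stay the same.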

In the above article, we also discuss why it is difficult to tell if these augmentations have made the model more robust. We could deploy the model in many environments but this is time-consuming. Thankfully, SHAP can be used as an alternative. Like with data collection, it can give us insight into how the augmentations have changed the way the model makes predictions.

