Image Classification with PyTorch and SHAP: Can you Trust an Automated Car? | by Conor O’Sullivan | Mar, 2023


Build an object detection model, compare it to intensity thresholds, evaluate it and explain it using DeepSHAP

Image of a tin can with the pixels highlighted blue and red using the SHAP Python package
(source: author)

If the world were less chaotic, self-driving cars would be simple. But it's not. To avoid serious harm, AI has to consider many variables, such as speed limits, traffic and obstacles in the road (a distracted human, for example). AI needs to be able to detect these obstacles and take appropriate action when it encounters them.

Thankfully, our application is not as complicated. Even more thankfully, we will be using tin cans instead of humans. We will build a model to detect this obstacle in front of a mini automated car. The car should STOP if the obstacle gets too close and GO otherwise.

At the end of the day, this is a binary classification problem. To tackle it, we will:

  • Create a benchmark using an intensity threshold
  • Build a CNN using PyTorch
  • Evaluate the model using accuracy, precision and recall
  • Interpret the model using SHAP

We will see that the model not only performs well but also makes its predictions in a way that seems reasonable. Along the way, we will discuss the Python code. You can find the full project on GitHub.

```python
# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import glob
import random

from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

import shap
from sklearn import metrics
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.metrics import ConfusionMatrixDisplay as cmd
```

In Figure 1 you can see examples of the images in our dataset. All of them are 224 x 224 pixels. If there is no black can, or if the can is far away, the image is labelled GO. If the can gets too close, the image is labelled STOP. You can find the full dataset on Kaggle.

Example images from the training dataset. The first two are labelled as go and the can is far away. The third is labelled as stop and the can is close.
Figure 1: example images (source: author)

We display the above images using the code below. Notice the image file names: each one starts with a number, which is the target variable (0 for GO and 1 for STOP).

```python
# Paths of example images
ex_paths = ["../../data/object_detection/0_b812cd70-4eff-11ed-9b15-f602a686e36d.jpg",
            "../../data/object_detection/0_d1edcc80-4ef6-11ed-8ddf-a46bb6070c92.jpg",
            "../../data/object_detection/1_cb171726-4ef7-11ed-8ddf-a46bb6070c92.jpg"]

# Plot example images
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
fig.set_facecolor('white')

for i, path in enumerate(ex_paths):

    # Load image
    img = Image.open(path)

    # Get target from the file name prefix (0 = GO, 1 = STOP)
    name = path.split("/")[-1]
    target = int(name.split("_")[0])

    # Plot image
    ax[i].imshow(img)
    ax[i].axis("off")

    # Set title
    title = ["GO", "STOP"][target]
    ax[i].set_title(title, size=20)
```

Before we get to the modelling, it is worth creating a benchmark. This can provide some insight into our problem. More importantly, it gives us something to compare our model results to. Our more complicated deep learning model should outperform the simple benchmark.

In Figure 1 we can see that the tin can is darker than its surroundings. We're going to take advantage of this when creating our benchmark: we will classify an image as STOP if it has many dark pixels. Getting there requires a few steps. For each image, we will:

  1. Convert it to greyscale so each pixel has a value between 0 (black) and 255 (white)
  2. Using a cutoff, convert each pixel to a binary value: 1 for dark pixels and 0 for light
  3. Calculate the average intensity, i.e. the percentage of dark pixels
  4. Classify the image as STOP if the average intensity is above a certain percentage
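The four steps above can be sketched with NumPy alone. This is a minimal illustration on a synthetic image rather than the code used for the benchmark: `benchmark_predict` is a hypothetical helper, and the greyscale step uses the BT.601 weights (0.299, 0.587, 0.114) that OpenCV documents for its RGB-to-grey conversion, without OpenCV's rounding to integers.

```python
import numpy as np

def benchmark_predict(rgb, cutoff=60, min_dark=0.065):
    """Hypothetical helper: classify one RGB image as GO (0) or STOP (1)."""
    rgb = np.asarray(rgb, dtype=np.float64)

    # Step 1: greyscale with BT.601 luma weights (0 = black, 255 = white)
    grey = rgb @ np.array([0.299, 0.587, 0.114])

    # Step 2: binarise; dark pixels (<= cutoff) become 1, light pixels 0
    dark = (grey <= cutoff).astype(np.float64)

    # Step 3: average intensity = fraction of dark pixels
    intensity = dark.mean()

    # Step 4: STOP if the fraction of dark pixels exceeds the threshold
    return int(intensity > min_dark), intensity


# Synthetic 224 x 224 light-grey image containing a black 80 x 80 "can"
img = np.full((224, 224, 3), 200, dtype=np.uint8)
img[100:180, 70:150] = 0
label, intensity = benchmark_predict(img)
print(label, round(intensity, 4))
```

The can covers 6,400 of the 50,176 pixels (about 12.8%), which is above the 6.5% threshold, so the image is classified as STOP.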

Together, steps 1 and 2 form a feature engineering method for image data known as an intensity threshold. You can read more about this and other feature engineering methods in a separate article.

We apply the intensity threshold using the function below. After scaling, each pixel has a value of either 0 (black) or 1 (white). For our application, it makes sense to invert this so that pixels that were originally dark are given a value of 1.

```python
def threshold(img, cutoff, invert=False):
    """Apply intensity thresholding"""

    img = np.array(img)

    # Greyscale image (PIL loads images in RGB order)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # Apply cutoff
    img[img > cutoff] = 255  # white
    img[img <= cutoff] = 0  # black

    # Scale to 0-1
    img = img / 255

    # Invert image so black = 1
    if invert:
        img = 1 - img

    return img
```

Figure 2 shows some examples of the intensity threshold applied with different cutoffs. A smaller cutoff includes less background noise, but the downside is that it also captures less of the tin can. In this case, we'll go with a cutoff of 60.

Intensity threshold applied to images of tin cans.
Figure 2: feature engineering with an intensity threshold (source: author)

We load all our images and their target variables, then apply the intensity threshold to each image. Note that we have set invert=True. Finally, we calculate the average intensity of each processed image. In the end, each image is represented by a single number, its average intensity, which can be interpreted as the percentage of dark pixels.

```python
# Load paths
paths = glob.glob("../../data/object_detection/*.jpg")

# Load images and targets
images = [Image.open(path) for path in paths]
target = [int(path.split("/")[-1].split("_")[0]) for path in paths]

# Apply thresholding (invert=True so dark pixels = 1) and get intensity
thresh_img = [threshold(img, 60, True) for img in images]
intensity = [np.average(img) for img in thresh_img]
```

Figure 3 gives box plots of the average intensity for all images labelled as GO and STOP. In general, the values are higher for STOP. This makes sense: when the can is closer, there are more dark pixels. The red line, at 6.5%, seems to separate the two classes well.

Average intensity by target variable label. Images labelled as STOP tend to have a higher intensity as they have more dark pixels.
Figure 3: average intensity by target variable (source: author)
```python
# Split data into go and stop images
go_data = [intensity[i] for i in range(len(target)) if target[i] == 0]
stop_data = [intensity[i] for i in range(len(target)) if target[i] == 1]
data = [go_data, stop_data]

fig = plt.figure(figsize=(5, 5))

# Plot boxplot with the 6.5% cutoff as a red line
plt.boxplot(data)
plt.hlines(y=0.065, xmin=0.5, xmax=2.5, color='r')
plt.xticks([1, 2], ['GO', 'STOP'])
plt.ylabel("Average Intensity", size=15)
```
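The 6.5% line appears to have been chosen by inspecting the box plot. As an alternative (a hypothetical sketch, not what this article does), one could scan candidate cutoffs and keep the one with the highest accuracy. `best_cutoff` and the toy numbers below are illustrative assumptions. Because such a cutoff is chosen on the same images it is evaluated on, the resulting accuracy would be optimistic.

```python
def best_cutoff(intensity, target, candidates):
    """Hypothetical helper: return (cutoff, accuracy) with the highest accuracy.

    intensity: fraction of dark pixels per image
    target: 0 for GO, 1 for STOP
    candidates: cutoffs to try; an image is predicted STOP if intensity > cutoff
    """
    best = None
    for c in candidates:
        correct = sum(int(i > c) == t for i, t in zip(intensity, target))
        acc = correct / len(target)
        # Keep the first (smallest) cutoff that reaches the best accuracy so far
        if best is None or acc > best[1]:
            best = (c, acc)
    return best


# Toy data: most GO images have few dark pixels, STOP images have many
toy_intensity = [0.01, 0.03, 0.05, 0.08, 0.07, 0.12, 0.20]
toy_target = [0, 0, 0, 0, 1, 1, 1]
toy_candidates = [k / 1000 for k in range(0, 201, 5)]  # 0%, 0.5%, ..., 20%
print(best_cutoff(toy_intensity, toy_target, toy_candidates))
```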

We use a cutoff of 6.5% to make predictions. That is, if the percentage of dark pixels is above 6.5%, the image is predicted as STOP (1); otherwise, we predict GO (0). The remaining code evaluates these predictions.

```python
# Predict using average intensity
prediction = [1 if i > 0.065 else 0 for i in intensity]

# Evaluate (precision and recall are macro-averaged over GO and STOP)
acc = metrics.accuracy_score(target, prediction)
prec, rec, _, _ = score(target, prediction, average='macro')

print('Accuracy: {}'.format(round(acc, 4)))
print('Precision: {}'.format(round(prec, 4)))
print('Recall: {}'.format(round(rec, 4)))

# Plot confusion matrix
cm = metrics.confusion_matrix(target, prediction)
cm_display = cmd(cm, display_labels=['GO', 'STOP'])

cm_display.plot()
```

In the end, we have an accuracy of 82%, a precision of 77.1% and a recall of 82.96% (precision and recall are macro-averaged over the two classes, as set by average='macro'). Not bad! The confusion matrix shows that most of the errors are false positives: images predicted as STOP when the car should GO. This corresponds to the box plot in Figure 3, where the GO intensity values have a long tail above the red line. This could be caused by dark background pixels increasing the number of dark pixels in the image.

Confusion matrix of benchmark predictions
Figure 4: confusion matrix of benchmark predictions (source: author)
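Because the code passes average='macro', the reported precision and recall are unweighted means of the per-class values for GO and STOP, so both classes count equally regardless of how many images each has. A minimal pure-Python sketch of that calculation from a 2x2 confusion matrix (rows = true class, columns = predicted class, the layout scikit-learn's confusion_matrix uses) is below. The counts are made up for illustration.

```python
def macro_precision_recall(cm):
    """Macro-averaged precision and recall from a square confusion matrix.

    cm[i][j] = number of images with true class i predicted as class j.
    """
    n = len(cm)
    precisions, recalls = [], []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k][j] for j in range(n))     # row sum
        precisions.append(tp / predicted_k if predicted_k else 0.0)
        recalls.append(tp / actual_k if actual_k else 0.0)
    return sum(precisions) / n, sum(recalls) / n


# Made-up counts; classes are [GO, STOP]
toy_cm = [[80, 20],   # 20 GO images wrongly predicted as STOP (false positives)
          [5, 45]]    # 5 STOP images wrongly predicted as GO (false negatives)
toy_prec, toy_rec = macro_precision_recall(toy_cm)
print(round(toy_prec, 4), round(toy_rec, 4))
```

Here GO has precision 80/85 and recall 80/100, while STOP has precision 45/65 and recall 45/50; the macro averages are the plain means of each pair.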

If an AI car were only 82% accurate, you'd probably be a bit concerned. So let's move on to the more complicated solution.

Loading dataset

We start by defining the ImageDataset class, which is used to load our images and target variables. As parameters, we pass in the list of image paths and the transformations to apply to the images. Our target variables will be tensors: [1,0] for GO and [0,1] for STOP.

```python
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, paths, transform):

        self.paths = paths
        self.transform = transform

    def __getitem__(self, idx):
        """Get image and one-hot target"""

        # Read image
        path = self.paths[idx]
        image = cv2.imread(path, cv2.IMREAD_COLOR)

        # OpenCV reads images in BGR order; convert to RGB so the channel order
        # matches the Normalize statistics and the PIL images used for SHAP
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)

        # Transform image
        image = self.transform(image)

        # Get target: [1,0] for GO and [0,1] for STOP
        target = path.split("/")[-1].split("_")[0]
        target = [[1, 0], [0, 1]][int(target)]

        target = torch.Tensor(target)

        return image, target

    def __len__(self):
        return len(self.paths)
```
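ImageDataset and the earlier snippets extract the label with path.split("/"), which assumes forward-slash separators. A small hypothetical helper, `label_from_path`, could use os.path.basename instead, which uses the platform's own separator rules (on Windows, where glob returns backslash-separated paths, it handles both kinds of slash). It produces the same one-hot encoding as ImageDataset, as a plain list.

```python
import os

def label_from_path(path):
    """Hypothetical helper: parse '<label>_<id>.jpg' into (label, one_hot)."""
    name = os.path.basename(path)        # drop the directories
    label = int(name.split("_", 1)[0])   # 0 = GO, 1 = STOP
    one_hot = [[1, 0], [0, 1]][label]    # same encoding as ImageDataset
    return label, one_hot


print(label_from_path("../../data/object_detection/1_cb171726-4ef7-11ed-8ddf-a46bb6070c92.jpg"))
```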

We will use common image transformations. To help create a more robust model, we randomly jitter the colour, varying the brightness, contrast, saturation and hue of each image. ToTensor then converts each image to a tensor with values between 0 and 1, and we normalise the pixel values using the standard ImageNet channel means and standard deviations. This will help the model converge.

```python
TRANSFORMS = transforms.Compose([
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),  # brightness, contrast, saturation, hue
    transforms.ToTensor(),                        # HxWxC in [0, 255] -> CxHxW in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
```
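For intuition, here is what ToTensor followed by Normalize does to an image, sketched in NumPy. This mirrors the arithmetic only, not torchvision's implementation, and `to_tensor_and_normalize` is a hypothetical name: the image is rearranged from HxWxC to CxHxW and divided by 255, then each channel has its mean subtracted and is divided by its standard deviation.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])  # ImageNet channel means (R, G, B)
STD = np.array([0.229, 0.224, 0.225])   # ImageNet channel standard deviations

def to_tensor_and_normalize(img):
    """NumPy sketch of transforms.ToTensor() + transforms.Normalize(MEAN, STD)."""
    x = np.asarray(img, dtype=np.float64) / 255.0  # scale to [0, 1]
    x = x.transpose(2, 0, 1)                        # HWC -> CHW
    return (x - MEAN[:, None, None]) / STD[:, None, None]


img = np.zeros((224, 224, 3), dtype=np.uint8)
img[..., 0] = 255  # a pure red image
out = to_tensor_and_normalize(img)
print(out.shape, out[:, 0, 0].round(3))
```

After normalisation, a typical ImageNet-like pixel ends up roughly centred on zero, which is what helps gradient-based training.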

We load all our image paths and randomly shuffle them. We then create ImageDataset objects for our training and validation data using an 80/20 split. In the end, we have 3,892 images in the training set and 974 images in the validation set.

```python
paths = glob.glob("../../data/object_detection/*.jpg")

# Shuffle the paths
random.shuffle(paths)

# Create datasets for training and validation (80/20 split)
split = int(0.8 * len(paths))
train_data = ImageDataset(paths[:split], TRANSFORMS)
valid_data = ImageDataset(paths[split:], TRANSFORMS)
```

At this point, no image data has actually been loaded into memory; ImageDataset only reads an image when it is indexed. Before we can use the data to train PyTorch models, we need to create DataLoader objects. For train_loader, we set batch_size=128, which lets us iterate over all the training images, loading 128 of them at a time. For the validation images, we set the batch size to the full length of the validation set, so all 974 images are loaded at once.

```python
# Prepare data for PyTorch model
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=len(valid_data))
```
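As a quick check of the batch arithmetic: with 3,892 training images and batch_size=128, DataLoader's default drop_last=False keeps the leftover images as a final, smaller batch.

```python
import math

n_train, batch_size = 3892, 128

n_batches = math.ceil(n_train / batch_size)          # batches per epoch
last_batch = n_train - (n_batches - 1) * batch_size  # size of the final, partial batch

print(n_batches, last_batch)
```

So each epoch runs 30 full batches of 128 images plus one batch of 52.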

Model architecture

Next, we define our CNN architecture. You can see a diagram of this in Figure 5. We start with 224x224x3 image tensors. These pass through 3 convolutional layers, each followed by max pooling, which leaves us with 28x28x64 tensors. This is followed by a dropout layer and two fully connected layers. We use ReLU activation functions for all hidden layers. For the output nodes, we use the sigmoid function so that our predictions are between 0 and 1.

CNN architecture diagram. It includes convolutional, max pooling and fully connected layers and a sigmoid activation function.
Figure 5: CNN architecture (source: author)
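The shape bookkeeping behind Figure 5 can be checked with the standard output-size formula floor((n + 2p - k) / s) + 1. A 3x3 convolution with padding 1 and stride 1 keeps the spatial size, and each 2x2 max pool (whose stride defaults to its kernel size) halves it. The sketch also counts parameters (weights plus biases), which shows that the first fully connected layer holds almost all of them.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (n + 2 * padding - kernel) // stride + 1

size, channels = 224, 3
conv_params = 0
for out_channels in (16, 32, 64):
    size = out_size(size, kernel=3, padding=1)  # Conv2d(k=3, padding=1): same size
    size = out_size(size, kernel=2, stride=2)   # MaxPool2d(2): halves the size
    conv_params += channels * out_channels * 3 * 3 + out_channels
    channels = out_channels

flat = channels * size * size                   # inputs to the first Linear layer
fc_params = (flat * 500 + 500) + (500 * 2 + 2)  # Linear(50176, 500) + Linear(500, 2)

print(size, flat, conv_params, fc_params)
```

The three convolutional layers have 23,584 parameters in total, while the two linear layers have 25,089,502, nearly all of them in Linear(64 * 28 * 28, 500).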

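We capture this architecture in the Net class below. One thing to point out is that every layer, including the ReLU and sigmoid activations, is defined as a module inside nn.Sequential() rather than called as a function (e.g. F.relu) in forward(). In my experience, the SHAP package will not work otherwise. This is likely because SHAP's DeepExplainer works by attaching hooks to the model's modules, so operations called as plain functions are not handled.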

```python
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Convolutional layers
        self.conv_layers = nn.Sequential(
            # Sees 224x224x3 image tensor
            nn.Conv2d(3,   # RGB channels
                      16,  # number of kernels
                      3,   # size of kernels
                      padding=1),
            nn.MaxPool2d(2),
            nn.ReLU(),

            # Sees 112x112x16 tensor
            nn.Conv2d(16, 32, 3, padding=1),
            nn.MaxPool2d(2),
            nn.ReLU(),

            # Sees 56x56x32 tensor
            nn.Conv2d(32, 64, 3, padding=1),
            nn.MaxPool2d(2),
            nn.ReLU()
        )

        # Fully connected layers
        self.fc_layers = nn.Sequential(
            # Sees flattened 28 * 28 * 64 tensor
            nn.Dropout(0.25),
            nn.Linear(64 * 28 * 28, 500),
            nn.ReLU(),
            nn.Linear(500, 2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 64 * 28 * 28)  # flatten to (batch, 50176)
        x = self.fc_layers(x)
        return x
```

We create a model object and move it to a GPU. I am using an Apple M1 laptop, which uses the 'mps' device. You will have to set the device that is appropriate for your machine (for example, 'cuda' for an NVIDIA GPU or 'cpu').

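```python
# Create a complete CNN
model = Net()
print(model)

# Move the model to a GPU if one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')  # NVIDIA GPU
elif torch.backends.mps.is_available():
    device = torch.device('mps')   # Apple silicon GPU (e.g. M1)
else:
    device = torch.device('cpu')
model.to(device)
```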

We define our loss function. As our target variable is binary, we use binary cross-entropy loss. Lastly, we use Adam as our optimizer, the algorithm used to minimize the loss.

```python
# Specify loss function (binary cross-entropy)
criterion = nn.BCELoss()

# Specify optimizer
optimizer = torch.optim.Adam(model.parameters())
```
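Because the network has two sigmoid outputs and the targets are one-hot vectors, nn.BCELoss (with its default reduction='mean') averages the binary cross-entropy -[y log p + (1 - y) log(1 - p)] over every element of the output, i.e. over both outputs of every image. A pure-Python sketch of that calculation follows; `bce_mean` is a hypothetical name, and unlike PyTorch it does not clamp the log terms at -100.

```python
import math

def bce_mean(outputs, targets):
    """Mean binary cross-entropy over all elements, like nn.BCELoss(reduction='mean').

    Unlike PyTorch, no clamping is applied, so p must be strictly between 0 and 1.
    """
    total, count = 0.0, 0
    for out_row, tgt_row in zip(outputs, targets):
        for p, y in zip(out_row, tgt_row):
            total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
            count += 1
    return total / count


# One GO image: target [1, 0]; model outputs 0.9 for GO and 0.2 for STOP
print(round(bce_mean([[0.9, 0.2]], [[1, 0]]), 4))
```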

Training model

Now for the fun part! We train our model for 20 epochs and keep the model from the epoch with the lowest validation loss. The trained model is available in the same GitHub repo.

```python
name = "object_detection_cnn"  # Change this to save a new model

# Train the model
min_loss = np.inf
for epoch in range(20):

    model = model.train()
    for images, target in train_loader:

        images = images.to(device)
        target = target.to(device)

        # Zero gradients of parameters
        optimizer.zero_grad()

        # Execute model to get outputs
        output = model(images)

        # Calculate loss
        loss = criterion(output, target)

        # Run backpropagation to accumulate gradients
        loss.backward()

        # Update model parameters
        optimizer.step()

    # Calculate validation loss (no gradients needed)
    model = model.eval()

    images, target = next(iter(valid_loader))
    images = images.to(device)
    target = target.to(device)

    with torch.no_grad():
        output = model(images)
        valid_loss = criterion(output, target).item()

    print("Epoch: {}, Validation Loss: {}".format(epoch, valid_loss))

    # Save model with lowest loss
    if valid_loss < min_loss:
        print("Saving model")
        torch.save(model, '../../models/{}.pth'.format(name))

        min_loss = valid_loss
```

One thing to mention is the optimizer.zero_grad() line. This resets the gradients of all parameters (to zero, or to None in recent PyTorch versions). For each training iteration, we want to update the parameters using the gradients from only that batch. If we do not reset the gradients, they will accumulate, and we will update the parameters using a combination of the gradients from the new and old batches.

Model evaluation

Now let's see how well this model has done. We start by loading our saved model. It is important to switch it to evaluation mode. If we do not, some layers (e.g. dropout) will behave as they do during training rather than inference.

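```python
# Load saved model (the Net class must be defined in this session).
# weights_only=False is required by newer PyTorch versions to unpickle a full model.
model = torch.load('../../models/object_detection_cnn.pth',
                   map_location=device, weights_only=False)
model.eval()
model.to(device)
```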

We load the images and target variables from the validation set. Remember, the target variables are tensors of dimension 2. We take the second element of each tensor, which gives us a binary target variable: 1 for STOP and 0 for GO.

```python
# Get images and targets
images, target = next(iter(valid_loader))
images = images.to(device)
target = [int(t[1]) for t in target]  # second element: 1 = STOP, 0 = GO
```

We use our model to make predictions on the validation images. Again, the outputs will be tensors of dimension 2. We consider the second element: if this probability is above 0.5, we predict STOP; otherwise, we predict GO.

```python
# Get predictions (no gradients needed for inference)
with torch.no_grad():
    output = model(images)
prediction = [1 if o[1] > 0.5 else 0 for o in output]
```
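Because each output has its own sigmoid, the GO and STOP probabilities are not forced to sum to 1. Thresholding the STOP output at 0.5, as we do here, and simply picking the larger of the two outputs can therefore disagree on borderline cases. A small pure-Python comparison with made-up outputs (both helper names are hypothetical):

```python
def predict_threshold(output, cutoff=0.5):
    """Rule used in this article: STOP (1) if the STOP probability exceeds the cutoff."""
    return 1 if output[1] > cutoff else 0

def predict_argmax(output):
    """Alternative rule: pick whichever output is larger (ties go to GO)."""
    return 1 if output[1] > output[0] else 0


toy_outputs = [[0.9, 0.1], [0.2, 0.7], [0.7, 0.6], [0.3, 0.4]]
for o in toy_outputs:
    print(o, predict_threshold(o), predict_argmax(o))
```

The last two rows disagree: [0.7, 0.6] is STOP under the threshold rule but GO under argmax, and [0.3, 0.4] is the reverse.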

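Finally, we compare the target to the prediction using the same code we used to evaluate the benchmark. We now have an accuracy of 98.05%, a precision of 97.38% and a recall of 97.5%. A significant improvement over the benchmark! Keep in mind that the benchmark was evaluated on all images (and its 6.5% cutoff was chosen by looking at all of them), while the CNN was evaluated on the validation set only, so the two sets of numbers are not computed on exactly the same images. In the confusion matrix, you can see where the errors are coming from.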

Confusion matrix of the model on the validation set
Figure 6: confusion matrix of the model on the validation set (source: author)

In Figure 7, we take a closer look at some of these errors. The top row gives some false positives: images predicted as STOP when the car should GO. Similarly, the bottom row gives false negatives.

Examples of the model's false positives and false negatives on tin can object detection predictions
Figure 7: examples of prediction errors (source: author)

You may have noticed that all the obstacles are at a similar distance. When the images were labelled, we used a cutoff distance: once the obstacle was closer than this cutoff, it was labelled STOP. The obstacles above are all close to this cutoff. Some of them may have been labelled incorrectly, so the model may be “confused” when the obstacle is near the cutoff.

Our model seems to be doing well. We can be even more confident by understanding how it makes these predictions. To do this, we use SHAP. If you are new to SHAP, check out my SHAP course. You can get free access if you sign up for my Newsletter 🙂

The code below calculates and displays the SHAP values for the 3 example images we saw in Figure 1. If you want more detail on how this code works, then check out the article mentioned at the end.

```python
# Load saved model onto the CPU
model = torch.load('../../models/object_detection_cnn.pth',
                   map_location='cpu', weights_only=False)

# Use CPU
device = torch.device('cpu')
model = model.to(device)
model.eval()

# Load 100 images for background
shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)
background, _ = next(iter(shap_loader))
background = background.to(device)

# Create SHAP explainer
explainer = shap.DeepExplainer(model, background)

# Load test images as RGB PIL images
test_images = [Image.open(path).convert("RGB") for path in ex_paths]

# Transform the PIL images (ColorJitter expects PIL images or tensors, not NumPy arrays)
test_input = [TRANSFORMS(img) for img in test_images]
test_input = torch.stack(test_input).to(device)

# Get SHAP values. Assumes a SHAP version that returns one array per output
# (GO, STOP), each of shape (N, C, H, W); newer releases may instead return
# a single array with the outputs on the last axis.
shap_values = explainer.shap_values(test_input)

# Reshape SHAP values to channels-last (outputs, N, H, W, C) and images for plotting
shap_numpy = list(np.array(shap_values).transpose(0, 1, 3, 4, 2))
test_numpy = np.array([np.array(img) for img in test_images])

shap.image_plot(shap_numpy, test_numpy, show=False)
```
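The transpose in the SHAP code only moves the channel axis to the end, because shap.image_plot expects channels-last images. Assuming SHAP returns one array per output, as that code expects, the shapes for three example images and two outputs work out as follows (zeros stand in for real SHAP values):

```python
import numpy as np

# Stand-in for explainer.shap_values(...): one array per output (GO, STOP),
# each of shape (N images, C channels, H, W)
toy_shap_values = [np.zeros((3, 3, 224, 224)) for _ in range(2)]

stacked = np.array(toy_shap_values)               # (outputs, N, C, H, W)
channels_last = stacked.transpose(0, 1, 3, 4, 2)  # (outputs, N, H, W, C)
toy_shap_numpy = list(channels_last)              # one (N, H, W, C) array per output

print(stacked.shape, toy_shap_numpy[0].shape, len(toy_shap_numpy))
```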

You can see the output in Figure 8. The first two rows are images labelled as GO and the third is labelled as STOP. We have SHAP values for each of the elements in the target tensors. After the original image in each row, the first column of SHAP values is for the GO prediction and the second is for the STOP prediction.

The colours are important. Blue SHAP values tell us that those pixels have decreased the predicted value. In other words, they have made it less likely that the model predicts the given label. Similarly, red SHAP values have increased the likelihood.

SHAP values for a PyTorch image classification model
Figure 8: SHAP values for example images (source: author)

To understand this, let's focus on the first row of Figure 8. In Figure 9, we have the image labelled as GO and the SHAP values for the GO prediction. You can see that the majority of the pixels are red. These have increased the value of this prediction, leading to a correct GO prediction. You can also see that the pixels are clustered around the obstacle cutoff, i.e. the tin can location where the label changes from GO to STOP.

Close up on the SHAP values for the GO label and GO prediction
Figure 9: SHAP values for GO prediction and GO label.

In Figure 10, we can see the SHAP values for the image labelled as STOP. The can is blue for the GO prediction and red for the STOP prediction. In other words, the model is using pixels from the can to decrease the GO value and increase the STOP value. That makes sense!

Close up on the SHAP values for the STOP label. The SHAP values are blue for the GO prediction and red for the STOP prediction.
Figure 10: SHAP values for STOP prediction
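A useful property for reading these red and blue plots is additivity: the SHAP values for an input sum to the model's output minus its average output over the background data, so red (positive) and blue (negative) contributions literally add up to how far a prediction moves from the baseline. DeepExplainer approximates this for deep networks. For a linear model with independent features it holds exactly, with each feature's SHAP value equal to w_i (x_i - E[x_i]). The NumPy sketch below shows only that linear special case, not DeepExplainer itself; the weights and data are made up.

```python
import numpy as np

rng = np.random.default_rng(0)

w = np.array([2.0, -1.0, 0.5])  # made-up weights
b = 0.3

def f(X):
    """A linear 'model': f(x) = w . x + b."""
    return X @ w + b

background = rng.normal(size=(100, 3))  # stand-in background data
x = np.array([1.0, 2.0, -1.0])          # the input to explain

# Exact SHAP values for a linear model with independent features
phi = w * (x - background.mean(axis=0))

# Positive values push the output up (red in SHAP plots), negative push it down (blue)
print(phi.round(3))
print(np.isclose(phi.sum(), f(x) - f(background).mean()))
```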

The model is not only making accurate predictions, but the way it makes them also seems logical. However, you may have noticed that some of the background pixels are highlighted. This doesn't make sense. Why would the background be important to a prediction? The background could change if we remove objects or move to a new location.

A likely reason is that the model has overfitted to the training data. These background objects are present in many of the images, so the model has come to associate them with the STOP/GO labels. In a related article, we do a similar analysis, discuss ways to prevent this type of overfitting and spend more time explaining the SHAP code.


Build an object detection model, compare it to intensity thresholds, evaluate it and explain it using DeepSHAP

Image of a tin can with the pixels highlighted blue and red using the SHAP Python package
(source: author)

If the world was less chaotic self-driving cars would be simple. But it’s not. To avoid serious harm, AI has to consider many variables — speed limits, traffic and obstacles in the road (such as a distracted human). AI needs to be able to detect these obstacles and take appropriate actions when encountered.

Thankfully, our application is not as complicated. Even more, thankfully, we will be using tin cans instead of humans. We will build a model used to detect this obstacle in front of a mini-automated car. The car should STOP if the obstacle gets too close or GO otherwise.

At the end of the day, this is a binary classification problem. To tackle it, we will:

  • Create a benchmark using an intensity threshold
  • Build a CNN using PyTorch
  • Evaluate the model using accuracy, precision and recall
  • Interpret the model using SHAP

We will see that the model not only performs well but the way it makes predictions also seems reasonable. Along the way, we will discuss the Python code and you can find the full project on GitHub.

# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import glob
import random

from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

import shap
from sklearn import metrics
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.metrics import ConfusionMatrixDisplay as cmd

In Figure 1 you can see examples of the images in our dataset. These are all of dimension 224 x 224. If there is no black can or if the can is far away the image is classified as GO. If the can gets too close, the image is classified as STOP. You can find the full dataset on Kaggle.

Example images from the training dataset. The first two are labelled as go and the can is far away. The third is labelled as stop and the can is close.
Figure 1: example images (source: author)

We display the above images using the code below. Notice the names of the images. It will always start with a number. This is the target variable. We have 0 for GO and 1 for STOP.

# Paths of example images
ex_paths = ["../../data/object_detection/0_b812cd70-4eff-11ed-9b15-f602a686e36d.jpg",
"../../data/object_detection/0_d1edcc80-4ef6-11ed-8ddf-a46bb6070c92.jpg",
"../../data/object_detection/1_cb171726-4ef7-11ed-8ddf-a46bb6070c92.jpg"]

# Plot example images
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
fig.set_facecolor('white')

for i, path in enumerate(ex_paths):

# Load image
img = Image.open(path)

# Get target
name = path.split("/")[-1]
target = int(name.split("_")[0])

# Plot image
ax[i].imshow(img)
ax[i].axis("off")

# Set title
title = ["GO","STOP"][target]
ax[i].set_title(title,size=20)

Before we get to the modelling, it is worth creating a benchmark. This can provide some insight into our problem. More importantly, it gives us something to compare our model results to. Our more complicated deep learning model should outperform the simple benchmark.

In Figure 1 we can see that the tin can is darker than it’s surroundings. We’re going to take advantage of this when creating our benchmark. That is we will classify an image as STOP if it has many dark pixels. Getting to that point will require a few steps. For each image, we will:

  1. Greyscale so each pixel has a value between 0 (black) and 255 (white)
  2. Using a cutoff, convert each pixel to a binary value — 1 for dark pixels and 0 for light
  3. Calculate average intensity — the percentage of dark pixels
  4. If the average intensity is above a certain percentage, we classify the image as STOP

Combined steps 1 and 2 is a type of feature engineering method for image data. It is known as an intensity threshold. You can read more about this and other feature engineering methods in this article:

We apply the intensity threshold using the function below. After scaling, a pixel will have a value of either 0 (black) and 1 (white). For our application, it makes sense to invert this. That is so pixels that are originally dark will be given a value of 1.

def threshold(img,cutoff,invert=False):
"""Apply intesity thresholding"""

img = np.array(img)

# Greyscale image
img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)

#Apply cutoff
img[img>cutoff] = 255 #white
img[img<=cutoff] = 0 #black

# Scale to 0-1
img = img/255

# Invert image so black = 1
if invert:
img = 1 - img

return img

In Figure 2, you can see some examples of when we apply the intensity threshold. We are able to vary the cutoff. A smaller cutoff means we include less background noise. The downside is we capture less of the tin can. In this case, we’ll go with a cutoff of 60.

Intesity threshold applied to images of tin cans.
Figure 2: feature engineering with an intensity threshold (source: author)

We load all our images (line 5) and target variables (line 6). We then apply the intensity threshold to each of these images (line 9). Note that we have set invert=True. Finally, we calculate the average intensity of each of the processed images (line 10). In the end, each of the images is represented by a single number — average intensity. This can be interpreted as the percentage of dark pixels.

# Load paths
paths = glob.glob("../../data/object_detection/*.jpg")

# Load images and targets
images = [Image.open(path) for path in paths]
target = [int(path.split("/")[-1].split("_")[0]) for path in paths]

# Apply thresholding and get intensity
thresh_img = [threshold(img,60,True) for img in images]
intensity = [np.average(img) for img in thresh_img]

Figure 3 gives the box plots of the average intensity for all the images labelled as GO and STOP. In general, we can see the values are higher for STOP. This makes sense — the can is closer and so we will have more dark pixels. The red line is at 6.5%. This seems to separate the image classes well.

Average intensity for target variable label. Images labelled as go tend to have a higher intensity as they have more dark pixels.
Figure 3: average intensity by target variable (source: author)
# Split data into go and stop images
go_data = [intensity[i] for i in range(len(target)) if target[i]==0]
stop_data = [intensity[i] for i in range(len(target)) if target[i]==1]
data= [go_data,stop_data]

fig = plt.figure(figsize=(5,5))

# Plot boxplot
plt.boxplot(data)
plt.hlines(y=0.065,xmin=0.5,xmax=2.5,color='r')
plt.xticks([1,2],['GO','STOP'])
plt.ylabel("Average Intensity",size=15)

We use a cutoff of 6.5% to make predictions (line 2). That is if the percentage of dark pixels is above 6.5% it is predicted as STOP (1) otherwise we predict GO (0). The remaining code is used to evaluate these predictions.

# Predict using average intensity
prediction = [1 if i>0.065 else 0 for i in intensity]

# Evaluate
acc = metrics.accuracy_score(target,prediction)
prec,rec,_,_ = score(target, prediction,average='macro')

print('Accuracy: {}'.format(round(acc,4)))
print('Precision: {}'.format(round(prec,4)))
print('Recall: {}'.format(round(rec,4)))

# Plot confusion matrix
cm = metrics.confusion_matrix(target, prediction)
cm_display = cmd(cm, display_labels = ['GO', 'STOP'])

cm_display.plot()

In the end, we have an accuracy of 82%, a precision of 77.1% and a recall of 82.96%. Not bad! In the confusion matrix, we can see that most of the errors are due to false positives. These are images predicted as STOP when we should GO. This corresponds to the box plot in Figure 3. See the long tail for the GO intensity values that are above the red line. This could potentially be caused by background pixels increasing the number of dark pixels in the image.

Confusion matrix of benchmark predictions
Figure 4: confusion matrix of benchmark predictions (source: author)

If an AI car was only 82% accurate you’d probably be a bit concerned. So let’s move on to the more complicated solution.

Loading dataset

We start by defining the ImageDataset class. This is used to load our images and target variables. As parameters, we need to pass in the list of all image paths and methods used to transform the images. Our target variables will be tensors — [1,0] for GO and [0,1] for STOP.

class ImageDataset(torch.utils.data.Dataset):
def __init__(self, paths, transform):

self.paths = paths
self.transform = transform

def __getitem__(self, idx):
"""Get image and target (x, y) coordinates"""

# Read image
path = self.paths[idx]
image = cv2.imread(path, cv2.IMREAD_COLOR)
image = Image.fromarray(image)

# Transform image
image = self.transform(image)

# Get target
target = path.split("/")[-1].split("_")[0]
target = [[1,0],[0,1]][int(target)]

target = torch.Tensor(target)

return image, target

def __len__(self):
return len(self.paths)

We will use common image transformations. To help create a more robust model, we will Jitter the color (line 2). This will randomly vary the brightness, contrast, saturation and hue of the image. We also normalise the pixel values (line 4). This will help the model converge.

TRANSFORMS = transforms.Compose([
transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

We load all our image paths (line 1) and randomly shuffle them (line 4). We then create ImageDataset objects for our training (line 8) and validation data (line 9). To do this we’ve used an 80/20 split (line 7). In the end, we will have 3,892 images in the training set and 974 images in the validation set.

paths = glob.glob("../../data/object_detection/*.jpg")

# Shuffle the paths
random.shuffle(paths)

# Create a datasets for training and validation
split = int(0.8 * len(paths))
train_data = ImageDataset(paths[:split], TRANSFORMS)
valid_data = ImageDataset(paths[split:], TRANSFORMS)

At this point, there is actually no data loaded to memory. Before we can use the data to train PyTorch models, we need to create DataLoader objects. For the train_loader, we have set batch_size=128. This allows us to iterate over all the training images loading 128 of them at a time. For the validation images, we set the batch size to the full length of the validation set. This allows us to load all 974 images at once.

# Prepare data for Pytorch model
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=valid_data.__len__())

Model architecture

Next we define our CNN architecture. You can see a diagram of this in Figure 5. We start with 224x224x3 image tensors. We have 3 convolutional and max pooling layers. This leaves us with 28x28x64 tensors. This is followed by a drop-out layer and two fully connected layers. We use ReLu activation functions for all hidden layers. For the output nodes, we use the sigmoid function. This is so our predictions will be between 0 and 1.

CNN architecture diagram. It includes convolutional, max pooling, fully connected and a sigmoid activation function layers.
Figure 5: CNN architecture (source: author)

We capture this architecture in the Net class below. One thing to point out is the use of nn.Sequential() functions. You must use this method of defining PyTorch models otherwise the SHAP package will not work.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Convolutional layers
        self.conv_layers = nn.Sequential(
            # Sees 224x224x3 image tensor
            nn.Conv2d(3,   # RGB channels
                      16,  # number of kernels
                      3,   # size of kernels
                      padding=1),
            nn.MaxPool2d(2),
            nn.ReLU(),

            # Sees 112x112x16 tensor
            nn.Conv2d(16, 32, 3, padding=1),
            nn.MaxPool2d(2),
            nn.ReLU(),

            # Sees 56x56x32 tensor
            nn.Conv2d(32, 64, 3, padding=1),
            nn.MaxPool2d(2),
            nn.ReLU()
        )

        # Fully connected layers
        self.fc_layers = nn.Sequential(
            # Sees flattened 28 * 28 * 64 tensor
            nn.Dropout(0.25),
            nn.Linear(64 * 28 * 28, 500),
            nn.ReLU(),
            nn.Linear(500, 2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 64 * 28 * 28)
        x = self.fc_layers(x)
        return x
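A quick way to confirm the 28x28x64 figure is to push a dummy image through the convolutional block. PyTorch stores images channels-first, so this appears as shape (64, 28, 28). The sketch below rebuilds the same layer stack as `Net.conv_layers` so it runs on its own. It also counts parameters. The first fully connected layer holds about 25.1 million of them, because each of the 50,176 flattened features connects to all 500 hidden nodes.

```python
import torch
import torch.nn as nn

# Same layer stack as Net.conv_layers, rebuilt here so the check is self-contained
conv_layers = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.MaxPool2d(2), nn.ReLU(),   # 224 -> 112
    nn.Conv2d(16, 32, 3, padding=1), nn.MaxPool2d(2), nn.ReLU(),  # 112 -> 56
    nn.Conv2d(32, 64, 3, padding=1), nn.MaxPool2d(2), nn.ReLU(),  # 56 -> 28
)
fc1 = nn.Linear(64 * 28 * 28, 500)
fc2 = nn.Linear(500, 2)

with torch.no_grad():
    out = conv_layers(torch.zeros(1, 3, 224, 224))  # one dummy RGB image (N, C, H, W)
print(tuple(out.shape))  # (1, 64, 28, 28)

def n_params(module):
    """Total number of trainable values (weights and biases) in a module."""
    return sum(p.numel() for p in module.parameters())

print(n_params(conv_layers), n_params(fc1), n_params(fc2))  # 23584 25088500 1002
```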

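We create a model object (line 2) and move it to a GPU (lines 6–7). I am using an Apple M1 laptop, so I use the 'mps' (Metal Performance Shaders) device. You will have to set the device that is appropriate for your machine, for example 'cuda' for an NVIDIA GPU or 'cpu'.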

# create a complete CNN
model = Net()
print(model)

# move the model to the GPU ('mps' on Apple silicon; use 'cuda' or 'cpu' as appropriate)
device = torch.device('mps')
model.to(device)
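If you want the same code to run on different machines, one option is to pick the device at runtime. This is an addition, not part of the original project. The helper below prefers CUDA, then Apple's MPS backend, and otherwise falls back to the CPU. Using `torch.backends.mps.is_available()` assumes PyTorch 1.12 or later.

```python
import torch

def pick_device() -> torch.device:
    """Return the best available device: CUDA GPU, Apple MPS, or CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
print(device)
```

With this helper, `model.to(pick_device())` would replace the hard-coded `torch.device('mps')`.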

We define our loss function (line 2). As our target variable is binary, we will use binary cross-entropy loss. Lastly, we use Adam as our optimizer (line 5). This is the algorithm used to minimize the loss.

# specify loss function (binary cross-entropy)
criterion = nn.BCELoss()

# specify optimizer
optimizer = torch.optim.Adam(model.parameters())
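Here is a worked example of what `nn.BCELoss` computes for a single image whose target is GO. The target is encoded as `[1., 0.]`, with the first element for GO and the second for STOP, consistent with how the targets are used later. BCE averages -[y log p + (1 - y) log(1 - p)] over all output elements. It expects probabilities (hence the sigmoid output layer) and float targets.

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()

output = torch.tensor([[0.9, 0.1]])  # sigmoid outputs for [GO, STOP]
target = torch.tensor([[1.0, 0.0]])  # float one-hot target: this image is GO

loss = criterion(output, target)

# Manual BCE: mean over both elements of -(y*log(p) + (1-y)*log(1-p))
manual = -((1.0 * math.log(0.9)) + ((1 - 0.0) * math.log(1 - 0.1))) / 2
print(round(loss.item(), 4), round(manual, 4))  # 0.1054 0.1054
```

A common alternative design is to drop the final `nn.Sigmoid()` and use `nn.BCEWithLogitsLoss`, which combines both steps in a more numerically stable way. The article keeps the sigmoid inside the model, so the outputs used later are already probabilities.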

Training model

Now for the fun part! We train our model for 20 epochs and keep the version with the lowest validation loss. The trained model is available in the same GitHub repo.

name = "object_detection_cnn" # Change this to save a new model

# Train the model
min_loss = np.inf
for epoch in range(20):

    model = model.train()
    for images, target in train_loader:

        images = images.to(device)
        target = target.to(device)

        # Zero gradients of parameters
        optimizer.zero_grad()

        # Execute model to get outputs
        output = model(images)

        # Calculate loss
        loss = criterion(output, target)

        # Run backpropagation to accumulate gradients
        loss.backward()

        # Update model parameters
        optimizer.step()

    # Calculate validation loss (no gradients are needed here)
    model = model.eval()

    with torch.no_grad():
        images, target = next(iter(valid_loader))
        images = images.to(device)
        target = target.to(device)

        output = model(images)
        valid_loss = criterion(output, target).item()

    print("Epoch: {}, Validation Loss: {}".format(epoch, valid_loss))

    # Save model with lowest loss
    if valid_loss < min_loss:
        print("Saving model")
        torch.save(model, '../../models/{}.pth'.format(name))

        min_loss = valid_loss
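The loop above pickles the whole model object with `torch.save(model, ...)`. A common alternative, and the approach recommended in the PyTorch serialization docs, is to save only the `state_dict` (the parameter tensors). You then rebuild the model from its class before loading the parameters back in. The sketch uses a small `nn.Linear` as a stand-in for `Net` and a temporary file path; both are illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for Net(); any nn.Module works the same way
path = os.path.join(tempfile.mkdtemp(), "best_model.pt")

# Save only the parameters, not the pickled model object
torch.save(model.state_dict(), path)

# Later: rebuild the architecture, then load the saved parameters into it
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path, weights_only=True))
restored.eval()  # switch to inference mode before predicting

print(torch.equal(model.weight, restored.weight))  # True
```

Loading with `weights_only=True` (available since PyTorch 1.13) avoids unpickling arbitrary objects. The trade-off is that the `Net` class must be available in the code that loads the file.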

One thing to mention is the optimizer.zero_grad() line. This resets the gradients of all parameters (in PyTorch 2.x it sets them to None by default, which has the same effect as zeroing them). For each training iteration, we want to update the parameters using the gradients from only that batch. If we do not reset the gradients, they will accumulate. This means we would update the parameters using the sum of the gradients from the new and old batches.
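The accumulation is easy to see with a single parameter. In the sketch below, the loss w * x has gradient x = 3 with respect to w. Calling `backward()` three times without resetting adds the gradients up (3, 6, 9). Calling `zero_grad()` first keeps each step's gradient at 3. The optimizer is never stepped, so w itself does not change.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)  # a single trainable parameter
x = torch.tensor(3.0)                      # fixed input, so d(w * x)/dw = 3
opt = torch.optim.SGD([w], lr=0.1)         # only used for zero_grad(); step() is never called

# Without zero_grad(): each backward() adds to the existing w.grad
grads_without = []
for _ in range(3):
    (w * x).backward()
    grads_without.append(w.grad.item())

opt.zero_grad()  # clear the accumulated gradient before the second experiment

# With zero_grad(): each iteration sees only its own gradient
grads_with = []
for _ in range(3):
    opt.zero_grad()
    (w * x).backward()
    grads_with.append(w.grad.item())

print(grads_without)  # [3.0, 6.0, 9.0]
print(grads_with)     # [3.0, 3.0, 3.0]
```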

Model Evaluation

Now let’s see how well this model has done. We start by loading our saved model (line 2). It is important to switch it to evaluation mode (line 3). If we do not, some layers (e.g. dropout) will behave as they do during training, which is incorrect for inference.

# Load saved model (a fully pickled model needs weights_only=False on PyTorch >= 2.6)
model = torch.load('../../models/object_detection_cnn.pth', weights_only=False)
model.eval()
model.to(device)
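Dropout is the clearest example of a layer that behaves differently in the two modes. In training mode, `nn.Dropout(0.25)` zeroes each element with probability 0.25. It then scales the surviving elements by 1 / (1 - 0.25) so the expected value is unchanged. In evaluation mode, it passes the input through untouched. The input of ones below is just for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)   # make the random dropout mask repeatable
drop = nn.Dropout(0.25)
x = torch.ones(1000)

drop.train()           # training mode: random zeroing plus rescaling
y_train = drop(x)
# Each element is either 0.0 (dropped) or about 1.3333 (= 1 / 0.75, kept and rescaled)
print(sorted(set(round(v, 4) for v in y_train.tolist())))

drop.eval()            # evaluation mode: identity
y_eval = drop(x)
print(torch.equal(y_eval, x))  # True
```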

We load the images and target variables from the validation set (line 2). Remember, each target is a tensor with two elements, one for GO and one for STOP. We take the second element of each tensor (line 4). This gives us a binary target variable: 1 for STOP and 0 for GO.

# Get images and targets
images, target = next(iter(valid_loader))
images = images.to(device)
target = [int(t[1]) for t in target]
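Each target is a one-hot tensor ordered [GO, STOP], so its second element is 1 for STOP and 0 for GO. The sketch below checks this on three made-up targets. It also shows two equivalent vectorised forms that use tensor operations instead of a Python loop.

```python
import torch

# Three hypothetical one-hot targets in [GO, STOP] order
target = torch.tensor([[1.0, 0.0],   # GO
                       [0.0, 1.0],   # STOP
                       [1.0, 0.0]])  # GO

loop_version = [int(t[1]) for t in target]      # as in the article
vectorised = target[:, 1].int().tolist()        # second column for every row
argmax_version = target.argmax(dim=1).tolist()  # index of the 1 in each row

print(loop_version, vectorised, argmax_version)  # [0, 1, 0] [0, 1, 0] [0, 1, 0]
```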

We use our model to make predictions on the validation images (line 2). Again, each output is a tensor with two elements. We consider the second element, the predicted probability of STOP: if it is above 0.5 we predict STOP, otherwise we predict GO.

# Get predictions
output = model(images)
prediction = [1 if o[1] > 0.5 else 0 for o in output]
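The thresholding step can also be written with tensor operations. In this sketch the `output` values are made-up sigmoid outputs, not real model predictions. In practice, you would run the forward pass inside `torch.no_grad()` so that PyTorch does not track gradients during inference. The two sigmoid outputs are independent, so they need not sum to 1; the rule only looks at the STOP probability.

```python
import torch

# Hypothetical sigmoid outputs in [GO, STOP] order for four images
output = torch.tensor([[0.95, 0.03],
                       [0.10, 0.88],
                       [0.40, 0.55],
                       [0.70, 0.50]])

with torch.no_grad():  # in practice: output = model(images) inside this block
    prediction = (output[:, 1] > 0.5).int().tolist()  # 1 = STOP, 0 = GO

print(prediction)  # [0, 1, 1, 0]
```

A STOP probability of exactly 0.5 is not above the threshold, so the last image is predicted as GO, matching the list comprehension in the article.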

Finally, we compare the target to the prediction using the same code we used to evaluate the benchmark. The model has an accuracy of 98.05%, a precision of 97.38% and a recall of 97.5%, a significant improvement over the benchmark! The confusion matrix in Figure 6 shows where the errors come from.
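The benchmark evaluation code is not repeated in this section. Here is a minimal sketch of how the three metrics and the confusion matrix can be computed with scikit-learn, using made-up labels (1 = STOP, 0 = GO). With STOP as the positive class, precision is the share of STOP predictions that were correct. Recall is the share of true STOP images that the model caught.

```python
from sklearn import metrics

# Hypothetical labels: 1 = STOP, 0 = GO
target     = [1, 0, 1, 1, 0, 0, 1, 0]
prediction = [1, 0, 0, 1, 0, 1, 1, 1]

accuracy = metrics.accuracy_score(target, prediction)    # (TP + TN) / all = 5/8
precision = metrics.precision_score(target, prediction)  # TP / (TP + FP) = 3/5
recall = metrics.recall_score(target, prediction)        # TP / (TP + FN) = 3/4
cm = metrics.confusion_matrix(target, prediction)        # rows: true GO, STOP; cols: predicted

print(accuracy, precision, recall)  # 0.625 0.6 0.75
print(cm.tolist())                  # [[2, 2], [1, 3]]
```

For the car, the bottom-left cell (true STOP, predicted GO) is the costly one, since those are the false negatives that recall penalises.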

Figure 6: confusion matrix of the model on the validation set (source: author)

In Figure 7, we take a closer look at some of these errors. The top row shows some false positives: images predicted as STOP when the car should GO. Similarly, the bottom row shows false negatives: images predicted as GO when the car should STOP.

Examples of the model’s false positives and false negatives on tin can object detection predictions
Figure 7: examples of prediction errors (source: author)

You may have noticed that all the obstacles are at a similar distance. When the images were labelled, we used a cutoff distance: once the obstacle was closer than this cutoff, the image was labelled STOP. The obstacles above are all close to this cutoff. Some of these images could have been labelled incorrectly, and the model may be “confused” when the obstacle is this close to the cutoff.

Our model seems to be doing well. We can be even more confident in it by understanding how it makes these predictions. To do this, we use SHAP. If you are new to SHAP, check out my SHAP course. You can get free access if you sign up for my newsletter 🙂

The code below calculates and displays the SHAP values for the 3 example images we saw in Figure 1. If you want more detail on how this code works, then check out the article mentioned at the end.

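# Load saved model onto the CPU (a fully pickled model needs weights_only=False on PyTorch >= 2.6)
model = torch.load('../../models/object_detection_cnn.pth', map_location='cpu', weights_only=False)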

# Use CPU
device = torch.device('cpu')
model = model.to(device)

# Load 100 random training images to use as the SHAP background dataset
shap_loader = DataLoader(train_data, batch_size=100, shuffle=True)
background, _ = next(iter(shap_loader))
background = background.to(device)

# Create SHAP explainer
explainer = shap.DeepExplainer(model, background)

# Load test images
test_images = [Image.open(path) for path in ex_paths]
test_images = np.array(test_images)

test_input = [TRANSFORMS(img) for img in test_images]
test_input = torch.stack(test_input).to(device)

# Get SHAP values
shap_values = explainer.shap_values(test_input)

# Reshape SHAP values and images for plotting.
# Assumes the older SHAP output format: a list with one (N, C, H, W) array per model output
shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2))
test_numpy = np.array([np.array(img) for img in test_images])

shap.image_plot(shap_numpy, test_numpy,show=False)
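The transpose on the reshaping line assumes that `shap_values` is a list with one (N, C, H, W) array per model output. This is what older SHAP releases returned for multi-output PyTorch models. More recent releases may instead return a single array with the outputs on the last axis, (N, C, H, W, outputs). Treat that as an assumption and check the shape your version returns. The hypothetical helper below handles both layouts and produces the list of (N, H, W, C) arrays that `shap.image_plot` expects.

```python
import numpy as np

def to_image_plot_format(shap_values):
    """Hypothetical helper: convert DeepExplainer output into a list of (N, H, W, C)
    arrays, one per model output, as expected by shap.image_plot."""
    if isinstance(shap_values, list):
        # Older layout: one (N, C, H, W) array per output
        arr = np.stack(shap_values)            # -> (outputs, N, C, H, W)
    else:
        # Assumed newer layout: a single (N, C, H, W, outputs) array
        arr = np.moveaxis(shap_values, -1, 0)  # -> (outputs, N, C, H, W)
    return list(arr.transpose(0, 1, 3, 4, 2))  # -> (outputs, N, H, W, C)

# Dummy SHAP values for 3 small 8x8 RGB images and 2 outputs (GO, STOP)
old_style = [np.zeros((3, 3, 8, 8)), np.ones((3, 3, 8, 8))]
new_style = np.stack(old_style, axis=-1)       # (3, 3, 8, 8, 2)

for values in (old_style, new_style):
    converted = to_image_plot_format(values)
    print(len(converted), converted[0].shape)  # 2 (3, 8, 8, 3)
```

With the helper, the reshaping line could become `shap_numpy = to_image_plot_format(shap_values)`.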

You can see the output in Figure 8. The first two rows are images labelled as GO and the third is labelled as STOP. We have SHAP values for each element of the model's output tensors. Next to each original image, the first column of SHAP values is for the GO prediction and the second is for the STOP prediction.

The colours are important. Blue SHAP values tell us that those pixels have decreased the predicted value. In other words, they have made it less likely that the model predicts the given label. Similarly, red SHAP values tell us that those pixels have increased the likelihood.
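To make the sign convention concrete, the sketch below computes exact Shapley values by brute force for a made-up three-feature function, relative to a single all-zeros baseline. DeepSHAP approximates the same quantities for deep networks and uses the background images as the reference. The function, input and baseline here are purely illustrative. Positive values (red in the SHAP plots) push the output up, and negative values (blue) push it down. Together, they add up to the output for the input minus the output for the baseline.

```python
from itertools import combinations
from math import factorial

def shapley_values(f, x, baseline):
    """Exact Shapley values: features outside a coalition are set to their baseline value."""
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for subset in combinations(others, size):
                # Shapley weight for a coalition of this size that excludes feature i
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                with_i = [x[j] if (j in subset or j == i) else baseline[j] for j in range(n)]
                without_i = [x[j] if j in subset else baseline[j] for j in range(n)]
                phi[i] += weight * (f(with_i) - f(without_i))
    return phi

def f(v):
    """Hypothetical 'model': feature 0 acts alone, features 1 and 2 interact."""
    return 2 * v[0] + v[1] * v[2]

x, baseline = [1, 2, -3], [0, 0, 0]

phi = shapley_values(f, x, baseline)
print([round(p, 6) for p in phi])              # [2.0, -3.0, -3.0]
print(round(sum(phi), 6), f(x) - f(baseline))  # -4.0 -4
```

Feature 0 gets a positive ("red") value. Features 1 and 2 share their negative interaction equally and get negative ("blue") values.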

SHAP values for a PyTorch image classification model
Figure 8: SHAP values for example images (source: author)

To understand this, let’s focus on the top right corner of Figure 8. In Figure 9, we have the image labelled as GO and the SHAP values from the GO prediction. You can see that the majority of the pixels are red. These have increased the value for this prediction, leading to a correct GO prediction. You can also see that these pixels are clustered around the obstacle cutoff, i.e. the tin can location where the label changes from GO to STOP.

Close up on the SHAP values for the GO label and GO prediction
Figure 9: SHAP values for GO prediction and GO label.

In Figure 10, we can see the SHAP values for the image labelled as STOP. The can is blue for the GO prediction and red for the STOP prediction. In other words, the model is using pixels from the can to decrease the GO value and increase the STOP value. That makes sense!

Close up on the SHAP values for the STOP label. The SHAP values are blue for the GO prediction and red for the STOP prediction.
Figure 10: SHAP values for STOP prediction

The model is not only making accurate predictions, but the way it makes them also seems logical. However, you may have noticed that some of the background pixels are highlighted. This doesn’t make sense. Why would the background be important to a prediction? The background could change if we remove objects or move the car to a new location.

The reason is that the model has overfitted to the training data. These background objects are present in many of the images, so the model has come to associate them with the STOP/GO labels. In the article below, we do a similar analysis. We discuss ways to prevent this type of overfitting and spend more time explaining the SHAP code.



Denial of responsibility! Techno Blender is an automatic aggregator of all the world’s media. In each content, the hyperlink to the primary source is specified. All trademarks belong to their rightful owners, all materials to their authors. If you are the owner of the content and do not want us to publish your materials, please contact us by email – [email protected]. The content will be deleted within 24 hours.
