Attention Maps

Attention maps are visual representations that highlight the areas of an input image that a neural network focuses on when making predictions. They are particularly useful in understanding and interpreting the decision-making process of deep learning models in both classification and segmentation tasks.

In image classification, an attention map shows which regions of the image most influence the predicted class label. In segmentation, it shows which regions drive the prediction of a given class.

In this notebook, we compute attention maps for a trained UNET model and use them to interpret the network’s decision-making process.

[ ]:
# Some imports

import monai
import torch
from torch.utils.data import DataLoader
from monai.transforms import (
    EnsureChannelFirstd,
    AsDiscreted,
    Compose,
    LoadImaged,
    Orientationd,
    Randomizable,
    Resized,
    ScaleIntensityd,
    Spacingd,
    EnsureTyped,
    Lambda
)
import os
import tempfile
from utils.decathlon_dataset import get_decathlon_dataloader
from utils.unet import UNET
from utils.train import train

import matplotlib.pyplot as plt

[ ]:
# Load the data
root_dir = './utils/datasets'
task = "Task04_Hippocampus"

train_loader = get_decathlon_dataloader(root_dir, task, "training", batch_size=4, num_workers=2, shuffle=True)
val_loader = get_decathlon_dataloader(root_dir, task, "validation", batch_size=4, num_workers=2)

# Load the model

model = UNET(in_channels=1, out_channels=3)

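# Load the model weights (map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine)

model.load_state_dict(torch.load("trained_unet.pth", map_location="cpu"))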

# Set the model to evaluation mode

model.eval()

Understanding Attention Maps and GradCAM

GradCAM (Gradient-weighted Class Activation Mapping) is a technique for visualizing the regions of an input image that a convolutional neural network focuses on when making predictions. Let’s break down how it works:

  1. Target Layer Selection: We choose a target layer in the network, typically one of the later convolutional layers that captures high-level features.

  2. Forward Pass: The input image is passed through the network to obtain predictions.

  3. Backward Pass: We backpropagate the score of the target class to compute its gradients with respect to the target layer’s activations.

  4. Importance Weighting: The gradients are globally average pooled to obtain weights for each feature map in the target layer.

  5. Feature Map Weighting: These weights are used to scale the corresponding feature maps, emphasizing the important features for the target class.

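  6. Heatmap Creation: The weighted feature maps are summed across channels and passed through a ReLU, which keeps only the regions that contribute positively to the target class. The resulting heatmap is then normalized for display.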
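
The six steps above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the implementation in utils/attention_maps.py: the function name `grad_cam_sketch`, the toy network, and the choice of summing the class logits over all pixels as the target score are assumptions made for this example, and the code assumes 2D inputs of shape (N, C, H, W).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam_sketch(model, target_layer, image, class_idx):
    """Illustrative Grad-CAM for a 2D segmentation-style model (hypothetical helper)."""
    stored = {}

    def save_activation(module, inputs, output):
        # Step 2: keep the target layer's activations from the forward pass...
        stored["activation"] = output
        # ...and ask autograd to hand us their gradient during the backward pass.
        output.register_hook(lambda grad: stored.update(gradient=grad))

    handle = target_layer.register_forward_hook(save_activation)
    try:
        model.zero_grad()
        logits = model(image)                # (N, num_classes, H, W)
        # Step 3: assumed target score = sum of the class logits over all pixels.
        score = logits[:, class_idx].sum()
        score.backward()
    finally:
        handle.remove()

    activation = stored["activation"]        # (N, K, h, w)
    gradient = stored["gradient"]            # (N, K, h, w)

    # Step 4: global average pooling of the gradients gives one weight per feature map.
    weights = gradient.mean(dim=(2, 3), keepdim=True)
    # Steps 5 and 6: weight the feature maps, sum over channels, then apply ReLU.
    cam = F.relu((weights * activation).sum(dim=1, keepdim=True))
    # Upsample to the input resolution and rescale to [0, 1] for display.
    cam = F.interpolate(cam, size=image.shape[2:], mode="bilinear", align_corners=False)
    cam = cam - cam.amin(dim=(2, 3), keepdim=True)
    cam = cam / (cam.amax(dim=(2, 3), keepdim=True) + 1e-8)
    return cam.detach()


# Toy stand-in for a trained network (3 output classes, like the UNET above).
torch.manual_seed(0)
toy_model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
    nn.ReLU(),                               # Step 1: target layer (index 4)
    nn.Upsample(scale_factor=2),
    nn.Conv2d(8, 3, kernel_size=1),
).eval()

toy_image = torch.rand(1, 1, 16, 16)
toy_cam = grad_cam_sketch(toy_model, toy_model[4], toy_image, class_idx=1)
print(toy_cam.shape)
```

The heatmap comes back at the input resolution with values in [0, 1], so it can be overlaid directly on the image. Note that the forward pass must not run inside `torch.no_grad()`, since the gradients are what the method is built on.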

These visualizations give insight into the model’s decision-making process and can reveal biases or unexpected behaviors. This is particularly useful for understanding how a model behaves on medical images.

Let’s use the model we trained in the previous notebook to make predictions on the validation set and compute the attention maps. We will use the GradCAM class to compute them; read through the utils/attention_maps.py file to understand how it works.

[ ]:
from utils.attention_maps import GradCAM, apply_cam
# Use the model to generate attention maps
target_layer = model.downs[-1].conv[-3]  # Using the last ReLU in the last downsampling block

grad_cam = GradCAM(model, target_layer)

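# Find a sample image with labels
# NOTE: [20:21] yields an empty tensor if the leading dimension has 20 or fewer entries
for sample_data in val_loader:
    sample_label = sample_data['label'][20:21]  # Slicing (not indexing) keeps the leading dimension

    if sample_label.sum() > 10:  # Skip samples whose label is empty or nearly empty
        sample_image = sample_data['image'][20:21]  # Same slice of the image, leading dimension kept
        break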

# Generate prediction
with torch.no_grad():
    sample_prediction = model(sample_image).argmax(dim=1)

# Generate CAM for each class
for class_idx in range(3):  # Assuming 3 classes (background, anterior, posterior)
    cam = grad_cam.generate_cam(sample_image, class_idx)
    print(f"Attention map for class {class_idx}")
    apply_cam(sample_image[0], sample_label, cam, class_idx, sample_prediction)
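
The cell above picks `target_layer` by index, so it helps to check which module an index actually refers to. The sketch below lists the convolution and ReLU layers of a network with `named_modules()`. The helper `list_conv_and_relu_layers` and the toy block are hypothetical; the layout of the real UNET is defined in utils/unet.py and may differ.

```python
import torch.nn as nn


def list_conv_and_relu_layers(model):
    """Return (name, module) pairs for conv and ReLU layers, in registration order."""
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ReLU))
    ]


# Toy double-convolution block, used here only as a stand-in for one UNET stage.
toy_block = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Conv2d(4, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
)

for name, module in list_conv_and_relu_layers(toy_block):
    print(name, type(module).__name__)

# Negative indices count from the end of the block.
print(type(toy_block[-3]).__name__, type(toy_block[-1]).__name__)
```

In this toy block, index -3 is the second convolution and index -1 is the last ReLU. Running `list_conv_and_relu_layers(model.downs[-1])` on the notebook's model is one way to confirm which layer `conv[-3]` selects there.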

Question

What do you observe from the attention maps in this example?

Try plotting the attention maps for other samples in the validation set. If possible, find a sample on which the model made a suboptimal or wrong prediction, and check whether the attention maps highlight the regions the model relied on for that prediction.
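
One possible way to look for poorly predicted samples is to score each prediction against its label and inspect the lowest-scoring ones. The sketch below computes a per-class Dice score; the helper `per_class_dice` is hypothetical and not part of the utils package, and it assumes the prediction and label are integer class maps of identical shape.

```python
import torch


def per_class_dice(prediction, label, num_classes):
    """Dice score per class for two integer class maps of identical shape."""
    scores = []
    for class_idx in range(num_classes):
        pred_mask = prediction == class_idx
        label_mask = label == class_idx
        intersection = (pred_mask & label_mask).sum().item()
        total = pred_mask.sum().item() + label_mask.sum().item()
        # A class absent from both maps counts as a perfect match.
        scores.append(1.0 if total == 0 else 2.0 * intersection / total)
    return scores


# Tiny worked example with 3 classes (0 = background).
toy_label = torch.tensor([[0, 1, 1],
                          [0, 2, 2]])
toy_prediction = torch.tensor([[0, 1, 0],
                               [0, 2, 1]])

toy_scores = per_class_dice(toy_prediction, toy_label, num_classes=3)
print(toy_scores)
```

With the notebook's tensors, the label may carry an extra channel dimension that the argmax prediction lacks, so the two may need to be reshaped to match before calling the helper. Samples with a low Dice score on the foreground classes are good candidates for inspecting attention maps.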

[ ]:
# Your code here