We've already covered convolutional neural networks at a deeper technical level. In this tutorial (006a) we explore various techniques that were developed to improve image classification performance but are also used in other tasks. We introduce them on a task that you are likely already familiar with, so we can really focus on the techniques themselves.
In the next tutorial (006b), we will talk about uncertainty estimation and explain the necessary background.
This is the first tutorial where the focus will be on PyTorch, instead of NumPy. So we will no longer have to worry about implementing gradients (one never truly stops worrying about the gradients themselves though). Should you feel uncomfortable with PyTorch, don't worry, operations will be introduced as we go!
Important notes
- It is fine to consult with colleagues to solve the problems, in fact it is encouraged.
- Please turn off AI tools; we want you to memorize concepts and not just quickly breeze through problems. To turn off AI, click on the gear in the top right corner, go to AI assistance, untick "Show AI powered inline completions", untick "consented to use generative AI features", and tick "Hide Generative AI features".
6.1 Classification with (Convolutional) Neural Networks in PyTorch
If you recall lectures 003 and 004, where we implemented neural networks in NumPy, this was a rather verbose way of doing it: you had to define the forward function and derive and implement the gradients yourself. Moreover, these operations all ran on the CPU, which made them very slow. Fortunately, PyTorch (and other libraries) abstract away many of the low-level functions for neural networks. PyTorch also allows you to move data to the GPU and take advantage of parallelized matrix operations. Let's start off by explaining PyTorch.
PyTorch
Remember how we implemented things in NumPy? We defined the data and then performed operations on it. This is very useful, but when we train neural network models, we basically want to define the operations first and then pump the data through them. This is what PyTorch allows: you define a model, i.e. a set of operations, and then pump data through it on a device of your choice (CPU or GPU). In addition, it allows you to scale to multiple GPUs and to multiple nodes within a cluster (if it is set up properly). Finally, PyTorch has a large ecosystem and community, with libraries specific to vision, text, audio problems and more.
Tensors
Let's start by looking at how to create and manipulate Tensors in PyTorch, and compare them with NumPy arrays. Make sure that your runtime has a GPU! You can change this setting in the top right corner: expand the menu next to RAM/Disk and connect to a runtime that is not CPU-only. If you do not, the tensors below will not be placed on 'cuda:0' and will look very similar to the NumPy arrays.
Tensors and Autograd in PyTorch
PyTorch is built around the concept of Tensors, which are multi-dimensional arrays similar to NumPy arrays. However, PyTorch Tensors have the added capability of being able to track computational graphs and automatically compute gradients, which is necessary for training neural networks.
Autograd is PyTorch's automatic differentiation engine. It records operations performed on Tensors to create a dynamic computational graph.
To enable gradient tracking for a Tensor, you set its requires_grad attribute to True; we will do this in the second example below. First, run the code below to create a tensor on the GPU. If you are running the notebook locally, this requires an NVIDIA GPU with drivers installed and the GPU (CUDA) build of PyTorch.
import torch
import numpy as np
# Creating a PyTorch Tensor
# get current device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
pt_tensor = torch.tensor([[1, 2], [3, 4]]).to(device)
print("PyTorch Tensor:\n", pt_tensor)
# Creating a NumPy array
np_array = np.array([[1, 2], [3, 4]])
print("NumPy Array:\n", np_array)
# Basic operations are similar
print("\nPyTorch Tensor + 1:\n", pt_tensor + 1)
print("NumPy Array + 1:\n", np_array + 1)
Device: cuda
PyTorch Tensor:
tensor([[1, 2],
[3, 4]], device='cuda:0')
NumPy Array:
[[1 2]
[3 4]]
PyTorch Tensor + 1:
tensor([[2, 3],
[4, 5]], device='cuda:0')
NumPy Array + 1:
[[2 3]
[4 5]]
As you can see, we explicitly moved the tensor to a specific device: 'cuda:0', i.e. the first GPU.
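A minimal sketch of moving data between NumPy, PyTorch and the GPU, using only standard torch and NumPy calls. It shows two details that often surprise newcomers: torch.from_numpy shares memory with the array it came from, and a tensor must be on the CPU before .numpy() can convert it back.

```python
import numpy as np
import torch

# Use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

np_array = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# torch.from_numpy does not copy: the tensor and the array share memory.
shared = torch.from_numpy(np_array)
np_array[0, 0] = 10.0           # modifying the array ...
print(shared[0, 0].item())      # ... is visible in the tensor: 10.0

# .to(device) copies the data to the GPU (on the CPU it returns the same tensor).
on_device = shared.to(device)
result = on_device * 2

# NumPy only understands CPU memory, so move the result back before converting.
back_to_numpy = result.cpu().numpy()
print(back_to_numpy)            # [[20.  4.] [ 6.  8.]]
```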
# Create a tensor with requires_grad=True
x = torch.tensor(2.0, requires_grad=True)
# Define a simple function
y = x**2 + 3*x + 1
# Compute gradients
y.backward()
# Access the gradient of y with respect to x
print("\nTensor x:", x)
print("Function y:", y)
print("Gradient of y with respect to x (dy/dx):", x.grad)
# Let's verify the gradient manually: dy/dx = 2x + 3
# At x = 2.0, dy/dx = 2*(2.0) + 3 = 4.0 + 3 = 7.0
Tensor x: tensor(2., requires_grad=True)
Function y: tensor(11., grad_fn=<AddBackward0>)
Gradient of y with respect to x (dy/dx): tensor(7.)
In this example, y.backward() computes the gradients of y with respect to all Tensors that have requires_grad=True and are part of the computational graph leading to y. The gradients are accumulated in the .grad attribute of these Tensors.
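Because gradients are accumulated rather than overwritten, calling backward() twice without resetting them doubles the stored gradient. The sketch below reuses the function from the example above; the list `grads` only records the values.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
grads = []

# First backward pass: dy/dx = 2x + 3 = 7 at x = 2
y = x**2 + 3*x + 1
y.backward()
grads.append(x.grad.item())

# Second backward pass on a freshly built graph: the new gradient is ADDED to .grad
y = x**2 + 3*x + 1
y.backward()
grads.append(x.grad.item())

# Reset the gradient (optimizer.zero_grad() does the equivalent for every parameter it manages)
x.grad = None
y = x**2 + 3*x + 1
y.backward()
grads.append(x.grad.item())

print(grads)  # [7.0, 14.0, 7.0]
```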
Autograd is fundamental to training neural networks as it automates the backpropagation process, allowing us to efficiently compute the gradients needed to update model parameters.
Pushing Data through a PyTorch Computational Graph
We also mentioned that in PyTorch you usually define the model (the operations) first. Then you read in your data, convert it to PyTorch Tensors, move it to your device (CPU or GPU), and let PyTorch do the calculations. A simple example of a feedforward layer (nn.Linear) without backpropagation is shown below.
import torch
import numpy as np
import torch.nn as nn
# Define a simple computational graph (a linear layer)
# This represents the operation: output = input @ weight.T + bias
model = nn.Linear(in_features=2, out_features=1)
# Create some NumPy arrays
np_data1 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
np_data2 = np.array([[5.0, 6.0]], dtype=np.float32)
print("NumPy Data 1:\n", np_data1)
print("NumPy Data 2:\n", np_data2)
# Convert NumPy arrays to PyTorch tensors
# PyTorch tensors are required to pass data through PyTorch models
pt_tensor1 = torch.from_numpy(np_data1)
pt_tensor2 = torch.from_numpy(np_data2)
print("\nPyTorch Tensor 1 (from NumPy):\n", pt_tensor1)
print("PyTorch Tensor 2 (from NumPy):\n", pt_tensor2)
# Pass the PyTorch tensors through the model
output1 = model(pt_tensor1)
output2 = model(pt_tensor2)
print("\nOutput from model for Tensor 1:\n", output1)
print("Output from model for Tensor 2:\n", output2)
NumPy Data 1:
[[1. 2.]
[3. 4.]]
NumPy Data 2:
[[5. 6.]]
PyTorch Tensor 1 (from NumPy):
tensor([[1., 2.],
[3., 4.]])
PyTorch Tensor 2 (from NumPy):
tensor([[5., 6.]])
Output from model for Tensor 1:
tensor([[-0.1596],
[-0.2752]], grad_fn=<AddmmBackward0>)
Output from model for Tensor 2:
tensor([[-0.3908]], grad_fn=<AddmmBackward0>)
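To see what nn.Linear actually computes, the sketch below recomputes its output by hand. PyTorch stores the weight with shape (out_features, in_features), so the operation is x @ W.T + b. The seed is fixed only to make the random initialization reproducible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # weights are initialized randomly; fix the seed for reproducibility
layer = nn.Linear(in_features=2, out_features=1)

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# weight has shape (out_features, in_features) = (1, 2), bias has shape (1,)
manual = x @ layer.weight.T + layer.bias
automatic = layer(x)

print(layer.weight.shape)                             # torch.Size([1, 2])
print(torch.allclose(manual, automatic, atol=1e-6))   # True
```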
Modularization in PyTorch: nn.Module
The modularization we did in NumPy in the previous lectures looks very similar to what PyTorch does. However, PyTorch takes it one step further with its nn.Module class. This class allows you to combine individual parts (such as convolutions, normalization, activation functions and much more!) into a module that you can reuse in other parts of your network. You can nest modules!
If you want to implement your own unique module, you have to write a class which inherits from nn.Module. Then in the constructor (__init__(self, ...)), you first call super().__init__() and then initialize the components that you want to use. Finally, you implement the forward(self, data) method, which takes all the data tensors that you use as argument(s).
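A minimal sketch of these rules, with two hypothetical modules (Inner and Outer) that exist only for illustration: any layer assigned as an attribute in __init__ is registered automatically, so its parameters show up with dotted names in named_parameters() of the outer module, and an optimizer receives all of them through model.parameters().

```python
import torch
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(2, 3)   # assigned as an attribute, so it is registered

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

class Outer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = Inner()            # modules can be nested
        self.head = nn.Linear(3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.inner(x))

model = Outer()
names = [name for name, _ in model.named_parameters()]
num_params = sum(p.numel() for p in model.parameters())
print(names)       # ['inner.linear.weight', 'inner.linear.bias', 'head.weight', 'head.bias']
print(num_params)  # (2*3 + 3) + (3*1 + 1) = 13
print(model(torch.randn(4, 2)).shape)  # torch.Size([4, 1])
```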
Here is a simple illustration using nn.Conv2d, nn.ReLU, and nn.BatchNorm2d. Their parameters closely mirror the layers we implemented in NumPy, and they behave very similarly too, except that our NumPy-based layers could not run on the GPU.
#-------------------------------------------#
## Example Module: we will reuse this later #
#-------------------------------------------#
class ExampleLayer(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        """
        Class constructor
        HERE you define your network layers
        """
        super(ExampleLayer, self).__init__()
        # define the layers and operations that you will use later
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        This is where you define how the different
        components (which you initialized) are connected.
        :param images: input images,
            shape: [batch_size, in_channels, height, width]
        :return: activated features,
            shape: [batch_size, out_channels, height - kernel_size + 1, width - kernel_size + 1]
            (the convolution has no padding, so it shrinks the image)
        """
        features = self.conv(images)
        normalized_features = self.bn(features)
        activated_features = self.relu(normalized_features)
        return activated_features
#---------------------------------------------------------#
## Small Neural Network: in which example module is used ##
#---------------------------------------------------------#
class SmallNetwork(nn.Module):
    def __init__(self, num_classes: int) -> None:
        """
        Class constructor
        HERE you define your network layers
        """
        super(SmallNetwork, self).__init__()
        # initialize the layers
        # As you can see, we can nest nn.Modules
        self.layer1 = ExampleLayer(in_channels=1, out_channels=16, kernel_size=3)  # 1 input channel: MNIST is grayscale
        self.layer2 = ExampleLayer(in_channels=16, out_channels=32, kernel_size=3)
        # flatten the features
        self.flatten = nn.Flatten()
        # fully connected layer
        # 28x28 input -> 26x26 after layer1 -> 24x24 after layer2 (3x3 kernels, no padding)
        flattened_size = 32 * 24 * 24
        self.fc = nn.Linear(in_features=flattened_size, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        forward function which does the forward pass
        This is where you define how the different
        components (which you initialized) are connected.
        :param x: input images, shape: [batch_size, 1, 28, 28]
        :return: class logits, shape: [batch_size, num_classes]
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
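A quick way to check the hard-coded flattened_size = 32 * 24 * 24 is to push a dummy MNIST-sized image through convolutions with the same settings. The sketch below leaves out BatchNorm and ReLU, since they do not change the shape; 28 shrinks to 26 and then to 24 because each unpadded 3×3 convolution removes one pixel on every side.

```python
import torch
import torch.nn as nn

# Two unpadded 3x3 convolutions with the same channel counts as layer1 and layer2
convs = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3),
    nn.Conv2d(16, 32, kernel_size=3),
)
dummy = torch.zeros(1, 1, 28, 28)   # one grayscale 28x28 image
with torch.no_grad():
    features = convs(dummy)

print(features.shape)               # torch.Size([1, 32, 24, 24])
flattened_size = features[0].numel()
print(flattened_size)               # 18432 == 32 * 24 * 24
```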
Training our first Convolutional Neural Network
Now that we have defined our first neural network, we can show what the rest of the training loop looks like. It looks remarkably similar to what we implemented in the earlier lectures, with a few key differences:
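- Earlier we implemented backpropagation with stochastic gradient descent (SGD), but SGD is not the only way to optimize a network. In PyTorch, a separate component called the optimizer decides how the gradients are used to update the parameters, so swapping the optimizer changes the optimization method. Often picked choices are:
- PyTorch accumulates gradients (every call to backward() adds to the stored .grad values), so we need to zero the gradients before each new backward pass.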
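A minimal sketch of the resulting train step on a toy problem: fitting y = 2x + 1 with a single nn.Linear (the data, learning rate and number of steps are arbitrary illustration choices). It uses plain SGD via torch.optim.SGD; swapping in torch.optim.AdamW, as in the training code below, only changes the optimizer line.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy regression problem: learn y = 2x + 1 from noiseless data
x = torch.linspace(-1, 1, 32).unsqueeze(1)   # shape (32, 1)
y = 2 * x + 1

model = nn.Linear(1, 1)
loss_function = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for step in range(200):
    optimizer.zero_grad()              # clear gradients accumulated in the previous step
    loss = loss_function(model(x), y)
    loss.backward()                    # autograd fills p.grad for every parameter
    optimizer.step()                   # plain SGD update: p <- p - lr * p.grad
    losses.append(loss.item())

print(losses[0], losses[-1])
print(model.weight.item(), model.bias.item())  # close to 2 and 1
```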
That's about it. Let's show an example of how to train a convolutional neural network in PyTorch.
!pip install mnist_datasets
Collecting mnist_datasets Downloading mnist_datasets-0.14-py3-none-any.whl.metadata (5.3 kB) Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from mnist_datasets) (2.0.2) Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from mnist_datasets) (4.67.1) Downloading mnist_datasets-0.14-py3-none-any.whl (7.0 kB) Installing collected packages: mnist_datasets Successfully installed mnist_datasets-0.14
from typing import Callable, List

import torch
import torch.nn as nn
from torch.optim import Optimizer
from mnist_datasets import MNISTLoader

# Check if we're in Colab
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # Download the components file from GitHub
    !wget https://raw.githubusercontent.com/RiaanZoetmulder/deep_learning_course/main/lecture_4/cnn_components.py
    print("Downloaded cnn_components.py from GitHub")
else:
    # Assume local environment has the file
    print("Running locally")

# Now you can import the classes from your file
try:
    from cnn_components import DataLoader, MNISTDataLoaderFactory
    print("Classes imported successfully!")
except ImportError as e:
    print('Cannot import cnn_components.py, ensure that it is in the SAME folder as the Jupyter notebook!')
--2026-02-03 09:39:08-- https://raw.githubusercontent.com/RiaanZoetmulder/deep_learning_course/main/lecture_4/cnn_components.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 14523 (14K) [text/plain] Saving to: ‘cnn_components.py’ cnn_components.py 0%[ ] 0 --.-KB/s cnn_components.py 100%[===================>] 14.18K --.-KB/s in 0s 2026-02-03 09:39:08 (152 MB/s) - ‘cnn_components.py’ saved [14523/14523] Downloaded cnn_components.py from GitHub Classes imported successfully!
def train_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimizer: Optimizer, loader: MNISTLoader, device: torch.device) -> tuple[List[float], List[float]]:
    """
    Train the model for a single EPOCH on the data.
    Remember:
    An EPOCH is defined as a single pass over the entire dataset.
    An iteration is defined as a single update using a batch of data.
    An epoch is therefore composed of about number_of_datapoints / batch_size
    iterations (rounded up if the last, smaller batch is kept, down if it is dropped).
    Parameters
    ----------
    model : nn.Module
        The model to train.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    optimizer : Optimizer
        The optimizer to use.
    loader : MNISTLoader
        Iterable that yields (images, labels) batches as NumPy arrays.
    device : torch.device
        The device (CPU or GPU) to train on.
    Returns
    -------
    tuple[List[float], List[float]]
        The losses and the accuracies for each iteration in this training EPOCH.
    """
    # IMPORTANT: Set the model to training mode (BatchNorm uses batch statistics).
    model.train()
    train_losses = []
    train_accuracy = []
    for images, labels in loader:
        # convert images and labels to PyTorch Tensors and move them to the device
        images = torch.from_numpy(images).to(device).float()  # cast to float32
        labels = torch.from_numpy(labels).to(device)
        # IMPORTANT: zero the gradients; PyTorch accumulates them across
        # backward() calls, so the old ones must be removed first.
        optimizer.zero_grad()
        # TRAIN STEP: This you should be familiar with ;)
        predictions = model(images)
        loss = loss_function(predictions, labels)
        loss.backward()
        iteration_accuracy = calculate_accuracy(labels, predictions)
        train_accuracy.append(iteration_accuracy)
        optimizer.step()
        # .item() moves the scalar loss to the CPU as a Python float
        train_losses.append(loss.item())
    return train_losses, train_accuracy
def validate_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], loader: MNISTLoader, device: torch.device) -> tuple[List[float], List[float]]:
    """
    Validate the model: iterate over the validation set and calculate the loss and accuracy.
    Parameters
    ----------
    model : nn.Module
        The model to validate.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    loader : MNISTLoader
        Iterable that yields (images, labels) batches as NumPy arrays.
    device : torch.device
        The device (CPU or GPU) to run the model on.
    Returns
    -------
    tuple[List[float], List[float]]
        The losses and the accuracies for each iteration.
    """
    # IMPORTANT: Set the model to evaluation mode! BatchNorm then uses its running
    # statistics (and dropout, if present, is disabled). Note that eval() does NOT
    # stop gradient tracking; the torch.no_grad() context below does that.
    model.eval()
    val_losses = []
    val_accuracy = []
    with torch.no_grad():
        for images, labels in loader:
            # convert images and labels to PyTorch Tensors and move them to the device
            images = torch.from_numpy(images).to(device).float()  # cast to float32
            labels = torch.from_numpy(labels).to(device)
            # As you have seen before, here you just predict the labels,
            # calculate the losses and store them
            predictions = model(images)
            # calculate the accuracy
            iteration_accuracy = calculate_accuracy(labels, predictions)
            loss = loss_function(predictions, labels)
            val_losses.append(loss.item())
            val_accuracy.append(iteration_accuracy)
    return val_losses, val_accuracy
def calculate_accuracy(labels: torch.Tensor, predictions: torch.Tensor) -> float:
    """
    Calculate the accuracy of the predictions.
    Parameters
    ----------
    labels : torch.Tensor
        One-hot encoded labels, shape [batch_size, num_classes].
    predictions : torch.Tensor
        The predictions (logits) of the model, shape [batch_size, num_classes].
    Returns
    -------
    float
        The fraction of correctly classified samples in the batch.
    """
    # the index of the largest logit is the predicted class
    _, predicted = torch.max(predictions, 1)
    # the labels are one-hot encoded, so the argmax recovers the class index
    _, labels = torch.max(labels, 1)
    correct = (predicted == labels).sum().item()
    total = labels.size(0)
    accuracy = correct / total
    return accuracy
def run_training(
        learning_rate: float,
        batch_size: int,
        num_epochs: int) -> tuple[List[float], List[float], List[float], List[float]]:
    # get device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)
    # Get the dataloaders, we will reuse the old ones for now.
    # Later, when tasks become more complicated, we will switch to more advanced
    # dataloaders.
    loader_factory = MNISTDataLoaderFactory(batch_size=batch_size, flatten=False)
    train_loader = loader_factory.get_train_dataset()
    validation_loader = loader_factory.get_validation_dataset()
    # define the model and move it to the device
    model = SmallNetwork(num_classes=10).to(device)
    # define the loss function
    loss_function = nn.CrossEntropyLoss()
    # define the optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # train the model
    train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = [], [], [], []
    for epoch in range(num_epochs):
        print(f'#---------------- START EPOCH {epoch+1} ----------------#')
        train_losses, train_accuracies = train_epoch(model, loss_function, optimizer, train_loader, device)
        val_losses, val_accuracies = validate_epoch(model, loss_function, validation_loader, device)
        # average over the iterations of this epoch
        train_accuracy = sum(train_accuracies) / len(train_accuracies)
        train_loss = sum(train_losses) / len(train_losses)
        validation_accuracy = sum(val_accuracies) / len(val_accuracies)
        validation_loss = sum(val_losses) / len(val_losses)
        # save them for a plot later
        train_epoch_accuracies.append(train_accuracy)
        train_epoch_losses.append(train_loss)
        val_epoch_accuracies.append(validation_accuracy)
        val_epoch_losses.append(validation_loss)
        # round the print statements to three decimal places
        train_accuracy = round(train_accuracy, 3)
        train_loss = round(train_loss, 3)
        validation_accuracy = round(validation_accuracy, 3)
        validation_loss = round(validation_loss, 3)
        # report it out
        print(f"Train Accuracy: {train_accuracy}, Train Loss: {train_loss}")
        print(f"Validation Accuracy: {validation_accuracy}, Validation Loss: {validation_loss}")
        print(f'#---------------- END EPOCH {epoch+1} -----------------#')
        print('\n')
    print('========================== FINISHED TRAINING ==========================')
    print('\n')
    return train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses
learning_rate = 0.001
batch_size = 64
num_epochs = 10
train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = run_training(learning_rate, batch_size, num_epochs)
Device: cuda
Downloading MNIST files: 100%|██████████| 4/4 [00:00<00:00, 17604.63it/s] Downloading MNIST files: 100%|██████████| 4/4 [00:00<00:00, 32326.04it/s]
#---------------- START EPOCH 1 ----------------#
Train Accuracy: 0.958, Train Loss: 0.153
Validation Accuracy: 0.976, Validation Loss: 0.081
#---------------- END EPOCH 1 -----------------#
#---------------- START EPOCH 2 ----------------#
Train Accuracy: 0.983, Train Loss: 0.054
Validation Accuracy: 0.982, Validation Loss: 0.059
#---------------- END EPOCH 2 -----------------#
#---------------- START EPOCH 3 ----------------#
Train Accuracy: 0.989, Train Loss: 0.033
Validation Accuracy: 0.982, Validation Loss: 0.06
#---------------- END EPOCH 3 -----------------#
#---------------- START EPOCH 4 ----------------#
Train Accuracy: 0.992, Train Loss: 0.024
Validation Accuracy: 0.986, Validation Loss: 0.045
#---------------- END EPOCH 4 -----------------#
#---------------- START EPOCH 5 ----------------#
Train Accuracy: 0.994, Train Loss: 0.018
Validation Accuracy: 0.982, Validation Loss: 0.065
#---------------- END EPOCH 5 -----------------#
#---------------- START EPOCH 6 ----------------#
Train Accuracy: 0.995, Train Loss: 0.015
Validation Accuracy: 0.984, Validation Loss: 0.065
#---------------- END EPOCH 6 -----------------#
#---------------- START EPOCH 7 ----------------#
Train Accuracy: 0.995, Train Loss: 0.014
Validation Accuracy: 0.982, Validation Loss: 0.064
#---------------- END EPOCH 7 -----------------#
#---------------- START EPOCH 8 ----------------#
Train Accuracy: 0.997, Train Loss: 0.01
Validation Accuracy: 0.985, Validation Loss: 0.06
#---------------- END EPOCH 8 -----------------#
#---------------- START EPOCH 9 ----------------#
Train Accuracy: 0.997, Train Loss: 0.009
Validation Accuracy: 0.983, Validation Loss: 0.075
#---------------- END EPOCH 9 -----------------#
#---------------- START EPOCH 10 ----------------#
Train Accuracy: 0.997, Train Loss: 0.009
Validation Accuracy: 0.985, Validation Loss: 0.063
#---------------- END EPOCH 10 -----------------#
========================== FINISHED TRAINING ==========================
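A note on the validation loop: model.eval() and torch.no_grad() do different things, and they are easy to mix up. The sketch below uses a small hypothetical model with BatchNorm and Dropout to show that eval() only changes how those layers behave, while torch.no_grad() is what stops autograd from recording the graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(p=0.5))
x = torch.randn(8, 4)

model.eval()                        # BatchNorm uses running statistics, Dropout is disabled ...
out_eval = model(x)
print(out_eval.requires_grad)       # True: ... but autograd still records the graph
deterministic = torch.equal(model(x), model(x))
print(deterministic)                # True: no random dropout in eval mode

with torch.no_grad():               # this is what actually disables gradient tracking
    out_no_grad = model(x)
print(out_no_grad.requires_grad)    # False

print(model.training)               # False; model.train() sets it back to True
```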
# @title
import matplotlib.pyplot as plt
import seaborn as sns
epochs = range(1, num_epochs + 1) # Create a list of epoch numbers
plt.figure(figsize=(12, 5))
# Plotting Loss
plt.subplot(1, 2, 1) # 1 row, 2 columns, 1st plot
sns.lineplot(x=epochs, y=train_epoch_losses, label='Train Loss')
sns.lineplot(x=epochs, y=val_epoch_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss per Epoch')
plt.legend()
# Plotting Accuracy
plt.subplot(1, 2, 2) # 1 row, 2 columns, 2nd plot
sns.lineplot(x=epochs, y=train_epoch_accuracies, label='Train Accuracy')
sns.lineplot(x=epochs, y=val_epoch_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy per Epoch')
plt.legend()
plt.tight_layout() # Adjust layout to prevent overlapping titles/labels
plt.show()
As you can see, it is easy to create a simple convolutional neural network in PyTorch. Much more is possible, but these basics will already get you quite far. Next, we take a tour of various neural network architectures that were developed to improve performance.
6.2 Neural Network Architectures
You can already do quite a lot with regular convolutional neural networks. Since they were first parallelized on GPUs and entered computer vision competitions (like the ImageNet Large Scale Visual Recognition Challenge), new architectures have been developed every year. In this section we mention some of the historically interesting ones and discuss some of the technologically important architectures. The goal of the exercises in this section is not to implement every architecture out there, but to give you experience implementing convolutional neural networks that are a bit more sophisticated.
Early Convolutional Neural Networks: LeNet, AlexNet, and VGG
The success of deep learning in image recognition can be significantly attributed to the development and increasing complexity of Convolutional Neural Networks (CNNs). Here are some foundational architectures that paved the way:
LeNet
One of the earliest successful CNNs, developed by Yann LeCun and colleagues in the late 1980s and 1990s for recognizing handwritten digits (like those in the MNIST dataset). LeNet-5 (1998), its best-known version, introduced several key concepts that are still fundamental to CNNs today:
- Convolutional Layers: Extract features by applying learnable filters to the input.
- Pooling Layers (Subsampling): Reduce the spatial dimensions of the feature maps, helping to achieve spatial invariance.
- Fully Connected Layers: Perform classification based on the extracted features.
- Activation Functions: Used non-linear activation functions (like sigmoid or tanh) after convolutional layers.
LeNet's architecture was relatively simple but demonstrated the effectiveness of combining these layers for image recognition tasks.
AlexNet
A breakthrough architecture that dramatically improved image classification performance and won the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) in 2012. Developed by Alex Krizhevsky, Ilya Sutskever, and Geoffrey Hinton, AlexNet showcased the power of deeper CNNs trained on a large-scale dataset (ImageNet) with the help of GPUs for acceleration. Key innovations and features included:
- Deeper Architecture: Significantly deeper than previous CNNs.
- ReLU Activation Functions: Used Rectified Linear Units (ReLU) instead of sigmoid or tanh, which helped with training deeper networks by mitigating the vanishing gradient problem.
- Dropout: A regularization technique where random neurons are ignored during training, preventing overfitting.
- Overlapping Pooling: Used pooling regions that overlapped, which was found to improve accuracy.
- GPU Implementation: Trained on two GPUs, because the network did not fit in the memory of a single GPU of that time, highlighting the importance of hardware acceleration.
AlexNet's success marked a turning point, demonstrating that deep CNNs were a viable and powerful approach for complex visual recognition tasks.
VGG
Introduced by the Visual Geometry Group at the University of Oxford in 2014, VGG networks focused on simplifying the architecture by using small (3x3) convolutional filters throughout the network. The emphasis was on demonstrating that increasing the depth of the network, with smaller filters, could lead to significant improvements in performance. Notable aspects of VGG include:
- Uniform Architecture: Consistent use of 3x3 convolutional layers followed by 2x2 max pooling layers.
- Increased Depth: Explored the impact of very deep networks (e.g., VGG16 and VGG19, referring to the number of weight layers).
- Stacking Small Filters: Showed that a stack of small filters covers the same receptive field as one larger filter (two 3x3 layers see a 5x5 region, three see a 7x7 region) while having fewer parameters and more non-linearities.
VGG networks, while computationally expensive due to their depth and number of parameters, were influential in highlighting the importance of network depth and the effectiveness of using small convolutional kernels.
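The parameter argument behind small filters can be checked directly. The sketch below uses an arbitrary channel count C = 64 and disables biases so that the counts are just kernel weights: three stacked 3×3 convolutions need 27C² weights and a single 7×7 convolution needs 49C², while both see a 7×7 region (for stride-1 layers the receptive field grows by k - 1 per layer). `count_params` and `receptive_field` are illustrative helpers.

```python
import torch.nn as nn

C = 64  # channels in and out; an arbitrary example value

def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# Three stacked 3x3 convolutions (ReLUs in between would add no parameters)
stacked = nn.Sequential(*[nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False) for _ in range(3)])
# One 7x7 convolution with the same receptive field
single = nn.Conv2d(C, C, kernel_size=7, padding=3, bias=False)

print(count_params(stacked))  # 3 * (3*3*C*C) = 27 * C^2 = 110592
print(count_params(single))   # 7*7*C*C = 49 * C^2 = 200704

def receptive_field(n_layers: int, kernel_size: int) -> int:
    """Receptive field of n stacked stride-1 convolutions with a k x k kernel."""
    return 1 + n_layers * (kernel_size - 1)

print(receptive_field(3, 3), receptive_field(1, 7))  # 7 7
```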
Residual Networks (ResNet)
Residual Networks (ResNets) revolutionized deep learning by making it possible to train very deep networks (over 100 layers) without the drop in accuracy that plain networks showed as they got deeper. ResNet-50, for example, is a 50-layer convolutional network that achieved state-of-the-art image classification performance when it was introduced in 2015.
Why is this significant? Before ResNet, increasing network depth often led to worse performance, not better. ResNet solved this by introducing a simple yet powerful idea: skip connections.
The Vanishing Gradient Problem
When training deep networks, gradients are propagated backward through many layers. If the gradients become very small (close to zero), the earlier layers learn extremely slowly. As you may remember, this is called the vanishing gradient problem. It occurs because, by the chain rule, the gradient with respect to an early layer is a product of the local gradients of all the layers after it. If these factors are small (the derivative of a sigmoid, for example, is at most 0.25), multiplying many of them together makes the gradient shrink exponentially as it propagates back to the earlier layers. As a result, the weights in the initial layers update very slowly, which hinders the learning process.
Skip Connections
In the convolutional neural networks we implemented before, each layer transforms its entire input. ResNets instead learn small edits to the input, which in the ideal case are as small as possible: the data propagates through the network and is modified slightly along the way. This is called learning the residual mapping. Because the data can also propagate straight through the network, the gradient now consists of two terms: the normal gradient through the layers and an identity gradient path. The identity path does not shrink, even when the gradient through the layers does.
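A toy illustration of the identity gradient path, assuming weightless "layers" that are just a sigmoid (chosen because its derivative is at most 0.25). Stacking 20 of them multiplies 20 small factors, whereas a skip connection turns each factor into 1 + sigmoid'(h), which is never below 1. `gradient_through_chain` is an illustrative helper, not part of the exercises.

```python
import torch

def gradient_through_chain(depth: int, residual: bool) -> float:
    """d(output)/d(input) through `depth` stacked sigmoid 'layers' without weights."""
    x = torch.tensor(0.5, requires_grad=True)
    h = x
    for _ in range(depth):
        if residual:
            h = h + torch.sigmoid(h)   # skip connection: H(h) = F(h) + h
        else:
            h = torch.sigmoid(h)       # plain layer: H(h) = F(h)
    h.backward()
    return x.grad.item()

plain = gradient_through_chain(20, residual=False)
res = gradient_through_chain(20, residual=True)
print(f"plain:    {plain:.3e}")  # each factor <= 0.25, so at most 0.25**20 (about 9e-13)
print(f"residual: {res:.3e}")    # each factor >= 1, so at least 1
```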
Figure: a residual block with a skip connection. By LunarLullaby, own work, CC BY-SA 4.0.
Technical explanation
Going a bit deeper into technical terminology, ResNet introduces skip connections, which propagate the data straight through the network. At each block, the learned residual mapping is added to the block's input. The image above illustrates this (note that the skip connection in the figure spans two layers). Formally, this is described in the following way:
$$ H(x) = F(x) + x $$

where $H(x)$ is the mapping we actually want, i.e. the output of the block, and $F(x) = H(x) - x$ is the residual we learn. This means the block only needs to learn the difference between its input and the desired output, which is often easier: if the best thing a block can do is pass its input through unchanged, it only has to drive $F(x)$ to zero.
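One way to see why learning a residual is convenient: if the residual function outputs zero, the block is exactly the identity. The sketch below wraps a small MLP in a hypothetical ResidualWrapper module (not part of the exercises) and zero-initializes its last layer to show this.

```python
import torch
import torch.nn as nn

class ResidualWrapper(nn.Module):
    """Hypothetical helper computing H(x) = F(x) + x for a shape-preserving F."""
    def __init__(self, residual_fn: nn.Module) -> None:
        super().__init__()
        self.residual_fn = residual_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.residual_fn(x) + x

# F: a two-layer MLP whose last layer starts at zero, so F(x) = 0 initially
f = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
nn.init.zeros_(f[2].weight)
nn.init.zeros_(f[2].bias)

block = ResidualWrapper(f)
x = torch.randn(4, 8)
print(torch.equal(block(x), x))  # True: with F(x) = 0 the block is exactly the identity
```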
In an actual neural network with parameters, we have the input $x$, the output $y$ and a learnable transformation $F(x, W)$ with weights $W$:

$$ y = F(x, W) + x $$

In the first version of ResNet, the residual function consists of the following operations: Conv → BatchNorm → ReLU → Conv → BatchNorm, and a final ReLU is applied after the addition.
Here:
- $F(x, W)$ is the residual function (e.g., Conv → BatchNorm → ReLU → Conv → BatchNorm).
- $x$ is added element-wise to the output of $F(x, W)$.
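If the dimensions differ (e.g., because of a stride of 2 or a change in the number of channels), $x$ cannot be added directly, so we use a projection shortcut: a 1x1 convolution with the same stride is applied to the input, giving $y = F(x, W) + W_s x$, where $W_s$ denotes the projection. In the implementation below, the projection is followed by batch normalization.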
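The shape mismatch that motivates the projection shortcut can be reproduced with a single stride-2 convolution standing in for the residual branch. This is a simplification: the real BasicBlock has two convolutions with BatchNorm and ReLU, and the 3×3 kernel with padding 1 is an assumption consistent with the ResNet-18 table below.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 32, 32)   # (N, C, H, W) entering a downsampling block

# Stand-in residual branch: 128 channels at half resolution
residual = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)(x)
print(residual.shape)            # torch.Size([2, 128, 16, 16])

# x itself cannot be added: (2, 64, 32, 32) vs (2, 128, 16, 16).
# Projection shortcut: 1x1 convolution with the same stride, followed by BatchNorm
projection = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
out = residual + projection(x)
print(out.shape)                 # torch.Size([2, 128, 16, 16])
```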
Exercises 6.1: Implementing Residual Blocks
Earlier in the notebook, you have already seen how to implement a normal convolutional neural network. In these exercises you will implement ResNet-18, a smaller version of ResNet-50. Follow the hints and instructions below.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List


class BasicBlock(nn.Module):
    # The expansion parameter is useful for consistency with later versions
    # of ResNet blocks. For this version of the BasicBlock it is not relevant
    # yet; we keep it here for consistency!
    # It refers to the expansion of the channels in the Bottleneck block
    # specifically, where the number of channels is expanded
    # by a certain "expansion" factor.
    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        """
        ResNet Basic Block (V1)
        Source: https://www.geeksforgeeks.org/deep-learning/resnet18-from-scratch-using-pytorch/
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride for the convolution.
        """
        super(BasicBlock, self).__init__()
        # EXERCISE 6.1
        ### Implement the residual function:
        ### Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm
        ### The convolutions have a kernel size of 3x3 in this case.
        ### The stride should be variable: when a stage downsamples, only its first
        ### block gets a stride of 2; the later blocks do not downsample.
        self.conv1 = None
        self.bn1 = None
        self.relu = None
        self.conv2 = None
        self.bn2 = None
        ############## END ##############

        ## If the stride is not 1 or the number of channels changes, project the
        ## input x with a 1x1 convolution (projection shortcut); otherwise the
        ## shortcut is the identity (an empty nn.Sequential).
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels)
            )

    def forward(self, x):
        identity = x
        # EXERCISE 6.1
        ## Do the same for the forward implementation!
        out = None
        ############## END ##############
        # see how we transform the identity if we need to expand
        # the dimensions
        out += self.shortcut(identity)
        out = self.relu(out)
        return out
Exercise 6.2: Implementing ResNet-18
Next, you will implement ResNet-18, a smaller version of ResNet-50. The architecture is shown in the table below. Why is it called ResNet-18? It has 18 layers with weights: the input convolution, 16 convolutional layers in the residual blocks (4 stages × 2 BasicBlocks × 2 convolutions), and the final fully connected (FC) layer.
| Layer | Input Size (N, C, H, W) | Output Size (N, C, H, W) | Details |
|---|---|---|---|
| Input Layer | (N, 3, 32, 32) | (N, 3, 32, 32) | Raw CIFAR-100 image |
| Input Conv | (N, 3, 32, 32) | (N, 64, 32, 32) | 3×3 conv, 64 filters, stride 1 + BatchNorm + ReLU |
| Block 1 | (N, 64, 32, 32) | (N, 64, 32, 32) | 2 × BasicBlock (each: 2× 3×3 conv, 64 filters) |
| Block 2 | (N, 64, 32, 32) | (N, 128, 16, 16) | 2 × BasicBlock (first block uses stride 2 for downsampling) |
| Block 3 | (N, 128, 16, 16) | (N, 256, 8, 8) | 2 × BasicBlock (first block uses stride 2 for downsampling) |
| Block 4 | (N, 256, 8, 8) | (N, 512, 4, 4) | 2 × BasicBlock (first block uses stride 2 for downsampling) |
| Global AvgPool | (N, 512, 4, 4) | (N, 512, 1, 1) | Average pooling over 4×4 |
| FC Layer | (N, 512, 1, 1) | (N, 100) | Fully connected layer for CIFAR-100 |
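Where do the spatial sizes in the table come from? A convolution produces an output of size floor((H + 2p - k) / s) + 1 for input size H, padding p, kernel size k and stride s. With a 3x3 kernel and padding 1, stride 1 keeps the size and stride 2 halves it. The sketch below checks this with plain, untrained `nn.Conv2d` layers that stand in for the first convolution of each stage; it is not the full ResNet-18, and `conv_out_size` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def conv_out_size(h: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution (dilation 1)."""
    return (h + 2 * padding - kernel) // stride + 1

# (in_channels, out_channels, stride) of the first 3x3 conv in each stage,
# taken from the ResNet-18 table; stand-ins only, not BasicBlocks.
stages = [(64, 64, 1), (64, 128, 2), (128, 256, 2), (256, 512, 2)]

x = torch.randn(2, 64, 32, 32)  # output of the input conv for a batch of 2
sizes = []
for c_in, c_out, stride in stages:
    conv = nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False)
    expected = conv_out_size(x.shape[-1], kernel=3, stride=stride, padding=1)
    x = conv(x)
    assert x.shape[-1] == expected
    sizes.append(tuple(x.shape[1:]))

print(sizes)  # [(64, 32, 32), (128, 16, 16), (256, 8, 8), (512, 4, 4)]
```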
The dataset, CIFAR-100, consists of 32 by 32 color images (so 3 channels) spread over 100 classes.
NOTE: Global average pooling simply takes the mean of each feature map over its spatial dimensions, i.e. the last two dimensions (H, W) in PyTorch.
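As a quick check of that note, the sketch below compares `nn.AdaptiveAvgPool2d((1, 1))`, which the network uses, with an explicit mean over the last two dimensions. The input shape is an assumption matching the output of the last ResNet-18 stage.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 512, 4, 4)  # e.g. the output of the last ResNet-18 stage

pooled = nn.AdaptiveAvgPool2d((1, 1))(x)   # shape (4, 512, 1, 1)
manual = x.mean(dim=(2, 3), keepdim=True)  # mean over H and W

print(pooled.shape)                                # torch.Size([4, 512, 1, 1])
print(torch.allclose(pooled, manual, atol=1e-6))   # True

features = torch.flatten(pooled, 1)  # shape (4, 512), ready for the FC layer
```

The `torch.flatten(out, 1)` call in the network's `forward` performs exactly this last step before the FC layer.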
NOTE: An observant reader will notice that our input convolution differs from the one in the original ResNet paper, which uses a 7x7 convolution with stride 2 followed by a 3x3 max pool with stride 2, designed for large ImageNet images. If you want, you can implement that version instead. But unless you want to enter the 2015 ImageNet competition, it won't matter all that much for this tutorial.
class ResNetInputLayer(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        """
        ResNet Input Layer
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride for the convolution.
        """
        super(ResNetInputLayer, self).__init__()
        ## EXERCISE 6.2
        ## implement the input layer of the ResNet-18
        self.conv1 = None
        self.bn1 = None
        self.relu = None

    def forward(self, x):
        ## EXERCISE 6.2
        ## implement the forward function
        ## HINT: apply the operations one after another, each time assigning the result back to "out"!
        out = None
        return out
class ResidualNetwork(nn.Module):
    def __init__(self, block: nn.Module, num_blocks: List[int], in_channels: int = 3, num_classes: int = 10):
        """
        Generic ResNet implementation; it should be adaptable to build any ResNet architecture,
        but we only use it to build a ResNet-18.
        NOTE: Exercises live here!
        Args:
            in_channels (int): Number of input channels (3, because we are working with RGB data).
            block (nn.Module): Which type of ResNet block to use; we just use the regular one defined before!
            num_blocks (List[int]): Number of residual blocks in each stage.
            num_classes (int): Number of output classes.
        """
        super(ResidualNetwork, self).__init__()
        self.in_channels = 64  # channels entering the first residual stage (output of the input layer)
        ## Exercise 6.2: Define the input layer
        self.model = nn.ModuleList()
        first_layer = None  # HINT: Use the ResNetInputLayer here!
        self.model.append(first_layer)
        ## Exercise 6.2: Define the residual stages
        # HINT: by what factor should the number of output channels grow
        # from one stage to the next?
        out_channels = 64
        for i in range(len(num_blocks)):
            stride = 1 if i == 0 else 2
            ## HINT: See and make use of the _make_layer helper function!
            ## Do your main implementation there!
            ## This is tricky, don't hesitate to ask for help if you get stuck!
            residual_block = None
            self.model.append(residual_block)
        ## Average pooling and classification are done for you!
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block: nn.Module, out_channels: int, num_blocks: int, stride: int) -> nn.Sequential:
        ## Exercise 6.2: Use your ResNet block!
        ## Create num_blocks blocks (two per stage for ResNet-18)
        strides = None  # Some list of strides!
        layers = []
        # HINT 1: If a stage downsamples its input, the first (Basic)Block gets
        # the given stride (2); all remaining blocks use stride 1.
        # HINT 2: The first block's output channels become the second block's in_channels.
        # To stay consistent, also multiply them by the block's expansion.
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = None  # Hint: correctly update in_channels for the next block
        return nn.Sequential(*layers)

    def forward(self, x):
        out = x
        for layer in self.model:
            out = layer(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


def ResNet18(num_classes=10):
    return ResidualNetwork(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_channels=3)
net = ResNet18(num_classes=100) # Changed to 100 for CIFAR-100
print(net)
ResidualNetwork(
(model): ModuleList(
(0): ResNetInputLayer(
(conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
)
(2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
)
(3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(shortcut): Sequential()
)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=100, bias=True)
)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import Callable, List
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
# Data loading for CIFAR-100
# These transforms only convert and normalize the images; data augmentation
# will be explored later, in case you are wondering!
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # maps each channel from [0, 1] to [-1, 1] (not the exact CIFAR-100 statistics)
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # same normalization as for training
])
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
def calculate_accuracy(labels: torch.Tensor, predictions: torch.Tensor) -> float:
    """
    Calculate the accuracy of the predictions.
    Parameters
    ----------
    labels : torch.Tensor
        The true labels.
    predictions : torch.Tensor
        The predicted logits from the model.
    Returns
    -------
    float
        The accuracy of the predictions as a fraction between 0 and 1.
    """
    # The index of the largest logit is the predicted class
    _, predicted = torch.max(predictions.data, 1)
    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy
def train_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimizer: optim.Optimizer, loader: DataLoader, device: torch.device) -> tuple[List[float], List[float]]:
    """
    Train the model for a single EPOCH on the data.
    Parameters
    ----------
    model : nn.Module
        The model to train.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    optimizer : optim.Optimizer
        The optimizer to use.
    loader : DataLoader
        The data loader for the training set.
    device : torch.device
        The device to train on ('cuda' or 'cpu').
    Returns
    -------
    tuple[List[float], List[float]]
        A tuple containing a list of losses and a list of accuracies, one entry per batch.
    """
    model.train()
    train_losses = []
    train_accuracies = []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        predictions = model(images)
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_accuracies.append(calculate_accuracy(labels, predictions))
    return train_losses, train_accuracies
def validate_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], loader: DataLoader, device: torch.device) -> tuple[List[float], List[float]]:
    """
    Validate the model for a single EPOCH on the data.
    Parameters
    ----------
    model : nn.Module
        The model to validate.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    loader : DataLoader
        The data loader for the validation set.
    device : torch.device
        The device to validate on ('cuda' or 'cpu').
    Returns
    -------
    tuple[List[float], List[float]]
        A tuple containing a list of losses and a list of accuracies, one entry per batch.
    """
    model.eval()
    val_losses = []
    val_accuracies = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images)
            loss = loss_function(predictions, labels)
            val_losses.append(loss.item())
            val_accuracies.append(calculate_accuracy(labels, predictions))
    return val_losses, val_accuracies
def run_training(
        model: nn.Module,
        learning_rate: float,
        batch_size: int,
        num_epochs: int,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        device: torch.device) -> tuple[List[float], List[float], List[float], List[float]]:
    # NOTE: batch_size is not used inside this function; the batch size is set
    # when the DataLoaders are created.
    # define the loss function
    loss_function = nn.CrossEntropyLoss()
    # define the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    # train the model
    train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = [], [], [], []
    for epoch in range(num_epochs):
        print(f'#---------------- START EPOCH {epoch+1} ----------------#')
        train_losses, train_accuracies = train_epoch(model, loss_function, optimizer, train_loader, device)
        val_losses, val_accuracies = validate_epoch(model, loss_function, validation_loader, device)
        # average epoch accuracy and loss
        train_accuracy = sum(train_accuracies)/len(train_accuracies)
        train_loss = sum(train_losses)/len(train_losses)
        validation_accuracy = sum(val_accuracies)/len(val_accuracies)
        validation_loss = sum(val_losses)/len(val_losses)
        # save them for a plot later
        train_epoch_accuracies.append(train_accuracy)
        train_epoch_losses.append(train_loss)
        val_epoch_accuracies.append(validation_accuracy)
        val_epoch_losses.append(validation_loss)
        # round the print statements to three decimal places
        train_accuracy = round(train_accuracy, 3)
        train_loss = round(train_loss, 3)
        validation_accuracy = round(validation_accuracy, 3)
        validation_loss = round(validation_loss, 3)
        # report it out
        print(f"Train Accuracy: {train_accuracy}, Train Loss: {train_loss}")
        print(f"Validation Accuracy: {validation_accuracy}, Validation Loss: {validation_loss}")
        print(f'#---------------- END EPOCH {epoch+1} -----------------#')
        print('\n')
    print('========================== FINISHED TRAINING ==========================')
    print('\n')
    return train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses
# Instantiate the ResNet18 model (assuming 100 classes for CIFAR-100)
resnet18_model = ResNet18(num_classes=100).to(device)
# Set hyperparameters
learning_rate = 0.001
batch_size = 128 # Using batch size from DataLoader instantiation
num_epochs = 10
# Run the training
train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = run_training(
resnet18_model,
learning_rate,
batch_size,
num_epochs,
trainloader,
testloader,
device
)
Device: cuda
#---------------- START EPOCH 1 ----------------#
Train Accuracy: 0.134, Train Loss: 3.686
Validation Accuracy: 0.213, Validation Loss: 3.189
#---------------- END EPOCH 1 -----------------#
#---------------- START EPOCH 2 ----------------#
Train Accuracy: 0.293, Train Loss: 2.767
Validation Accuracy: 0.342, Validation Loss: 2.556
#---------------- END EPOCH 2 -----------------#
#---------------- START EPOCH 3 ----------------#
Train Accuracy: 0.419, Train Loss: 2.158
Validation Accuracy: 0.398, Validation Loss: 2.329
#---------------- END EPOCH 3 -----------------#
#---------------- START EPOCH 4 ----------------#
Train Accuracy: 0.514, Train Loss: 1.751
Validation Accuracy: 0.468, Validation Loss: 1.957
#---------------- END EPOCH 4 -----------------#
#---------------- START EPOCH 5 ----------------#
Train Accuracy: 0.595, Train Loss: 1.415
Validation Accuracy: 0.5, Validation Loss: 1.869
#---------------- END EPOCH 5 -----------------#
#---------------- START EPOCH 6 ----------------#
Train Accuracy: 0.674, Train Loss: 1.105
Validation Accuracy: 0.525, Validation Loss: 1.77
#---------------- END EPOCH 6 -----------------#
#---------------- START EPOCH 7 ----------------#
Train Accuracy: 0.756, Train Loss: 0.805
Validation Accuracy: 0.542, Validation Loss: 1.764
#---------------- END EPOCH 7 -----------------#
#---------------- START EPOCH 8 ----------------#
Train Accuracy: 0.838, Train Loss: 0.525
Validation Accuracy: 0.537, Validation Loss: 1.919
#---------------- END EPOCH 8 -----------------#
#---------------- START EPOCH 9 ----------------#
Train Accuracy: 0.911, Train Loss: 0.296
Validation Accuracy: 0.528, Validation Loss: 2.059
#---------------- END EPOCH 9 -----------------#
#---------------- START EPOCH 10 ----------------#
Train Accuracy: 0.949, Train Loss: 0.176
Validation Accuracy: 0.523, Validation Loss: 2.333
#---------------- END EPOCH 10 -----------------#
========================== FINISHED TRAINING ==========================
import matplotlib.pyplot as plt
import seaborn as sns
epochs = range(1, num_epochs + 1) # Create a list of epoch numbers
plt.figure(figsize=(12, 5))
# Plotting Loss
plt.subplot(1, 2, 1) # 1 row, 2 columns, 1st plot
sns.lineplot(x=epochs, y=train_epoch_losses, label='Train Loss')
sns.lineplot(x=epochs, y=val_epoch_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss per Epoch')
plt.legend()
# Plotting Accuracy
plt.subplot(1, 2, 2) # 1 row, 2 columns, 2nd plot
sns.lineplot(x=epochs, y=train_epoch_accuracies, label='Train Accuracy')
sns.lineplot(x=epochs, y=val_epoch_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy per Epoch')
plt.legend()
plt.tight_layout() # Adjust layout to prevent overlapping titles/labels
plt.show()
Inception (GoogLeNet)
The Inception architecture (don't fall asleep, it will invade your dreams if you do), introduced in GoogLeNet (Google's network, named as a tribute to Yann LeCun's LeNet), aimed to improve the efficiency and performance of CNNs by letting the network itself decide which kernel size is most useful at each layer. Instead of just stacking convolutional layers with a single kernel size, Inception modules perform multiple convolutions with different kernel sizes in parallel and concatenate their results.
Technical Explanation
The core idea is to address the question: what is the best receptive field size (kernel size) at a particular layer? Smaller kernels (like 1x1 or 3x3) capture fine-grained features, while larger kernels (like 5x5) capture more global features. Inception modules utilize this by having parallel branches with 1x1, 3x3, and 5x5 convolutional layers, as well as a max pooling layer.
Here's a breakdown of a basic Inception module:
- 1x1 Convolution: Used for dimensionality reduction (fewer channels) and, followed by a ReLU, for adding non-linearity. This is a key ingredient, as it controls the number of input channels to the 3x3 and 5x5 convolutions, making the module computationally cheaper.
- 3x3 Convolution: Captures local features.
- 5x5 Convolution: Captures larger-scale features.
- Max Pooling: Captures the most prominent features in a region.
The outputs of these parallel branches are then concatenated along the channel dimension.
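Concatenation only works if every branch preserves the batch size and the spatial resolution; the channel counts may differ and simply add up. A minimal sketch with random tensors standing in for the four branch outputs (the channel counts 64, 128, 32 and 32 are taken from the example configuration given with the Inception exercise):

```python
import torch

n, h, w = 8, 8, 8
# Stand-ins for the four branch outputs, shape (N, C_branch, H, W)
branch1x1 = torch.randn(n, 64, h, w)
branch3x3 = torch.randn(n, 128, h, w)
branch5x5 = torch.randn(n, 32, h, w)
branch_pool = torch.randn(n, 32, h, w)

out = torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], dim=1)
print(out.shape)  # torch.Size([8, 256, 8, 8])

# A branch with a different spatial size cannot be concatenated
try:
    torch.cat([branch1x1, torch.randn(n, 32, h - 1, w - 1)], dim=1)
    mismatch_ok = True
except RuntimeError:
    mismatch_ok = False
print(mismatch_ok)  # False
```

This is why each branch needs the right padding: any branch that shrinks the feature map breaks the concatenation.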
This parallel structure allows the network to capture features at different scales within the same layer, leading to better performance and a more efficient use of computational resources than simply stacking many layers of the same type. For example, VGG-16 has more than an order of magnitude more parameters than GoogLeNet (roughly 138 million versus under 7 million).
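To make the efficiency argument concrete, here is a sketch that counts the weights of a 5x5 branch with and without a 1x1 reduction. The channel numbers (192 input channels, reduced to 16, with 32 outputs) are borrowed from the example configuration given with the Inception exercise; biases are disabled to keep the arithmetic simple, and `n_params` is a hypothetical helper.

```python
import torch.nn as nn

def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

in_ch, reduced, out_ch = 192, 16, 32

# 5x5 convolution applied directly to all 192 input channels
direct = nn.Conv2d(in_ch, out_ch, kernel_size=5, padding=2, bias=False)

# 1x1 reduction to 16 channels first, then the 5x5 convolution
reduced_branch = nn.Sequential(
    nn.Conv2d(in_ch, reduced, kernel_size=1, bias=False),
    nn.ReLU(inplace=True),
    nn.Conv2d(reduced, out_ch, kernel_size=5, padding=2, bias=False),
)

print(n_params(direct))          # 192 * 32 * 5 * 5 = 153600
print(n_params(reduced_branch))  # 192 * 16 + 16 * 32 * 5 * 5 = 3072 + 12800 = 15872
```

The reduced branch needs almost ten times fewer weights, and because all layers here run at the same resolution with stride 1, the multiply-adds per output pixel shrink by the same factor.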
Exercise 6.3: Implementing an Inception module
We are going to implement a basic Inception module (with dimensionality reduction). This is the Inception v1 block from the original GoogLeNet; later versions of Inception changed the design. Here is the architecture:
| Step | Description |
|---|---|
| 1. 1x1 Convolution Path | Create a path that applies a 1x1 convolution to the input. |
| 2. 3x3 Convolution Path | Create a path that applies a 1x1 convolution (for dimensionality reduction) followed by a 3x3 convolution. |
| 3. 5x5 Convolution Path | Create a path that applies a 1x1 convolution (for dimensionality reduction) followed by a 5x5 convolution. |
| 4. Pooling Path | Create a path that applies a pooling operation (like max pooling) followed by a 1x1 convolution (for dimensionality adjustment). |
| 5. Apply Activation | Apply a non-linear activation function (like ReLU) after each convolutional operation within the paths. |
| 6. Combine Outputs | Pass the input through each defined path and combine the results. This is done by concatenating the outputs along the channel dimension. |
If you prefer pictures, no problem. Here is a diagram of the inception module that we are going to implement:
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicInception(nn.Module):
    def __init__(self, in_channels: int, ch1x1: int, ch3x3red: int, ch3x3: int, ch5x5red: int, ch5x5: int, pool_proj: int) -> None:
        """
        Basic Inception module for a CNN with dimensionality reduction.
        After each branch, the batch size and the resolution of the feature maps are the same
        for every branch; only the number of channels may differ!
        This allows concatenation along the channel dimension.
        Parameters:
        - in_channels: Number of input channels.
        - ch1x1: Number of 1x1 convolutional filters for the 1x1 branch.
        - ch3x3red: Number of 1x1 convolutional filters for dimensionality reduction before the 3x3 convolution.
          This is the output dimension of that 1x1 layer in the 3x3 branch.
        - ch3x3: Number of 3x3 convolutional filters.
        - ch5x5red: Number of 1x1 convolutional filters for dimensionality reduction before the 5x5 convolution.
          This is the output dimension of that 1x1 layer in the 5x5 branch.
        - ch5x5: Number of 5x5 convolutional filters.
        - pool_proj: Number of 1x1 convolutional filters after the pooling layer.
        """
        super(BasicInception, self).__init__()
        # General HINT: make sure you add an activation function after each convolutional layer,
        # but not after any operation without learnable parameters (such as max pooling)!
        # EXERCISE 6.3: 1x1 convolution branch
        # Hint: Do not add padding here
        self.branch1x1 = None
        # EXERCISE 6.3: 1x1 convolution followed by 3x3 convolution branch
        # Hint: you will need a padding of 1 in the 3x3 convolution to maintain the spatial dimensions
        self.branch3x3 = None
        # EXERCISE 6.3: 1x1 convolution followed by 5x5 convolution branch
        # Hint: you will need a padding of 2 in the 5x5 convolution
        self.branch5x5 = None
        # EXERCISE 6.3: Max pooling followed by 1x1 convolution branch
        # Hint: the max pool has a kernel size of 3, a stride of 1 and a padding of 1
        self.branch_pool = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass the input through each branch
        branch1x1_out = self.branch1x1(x)
        branch3x3_out = self.branch3x3(x)
        branch5x5_out = self.branch5x5(x)
        branch_pool_out = self.branch_pool(x)
        # Concatenate the outputs along the channel dimension
        outputs = [branch1x1_out, branch3x3_out, branch5x5_out, branch_pool_out]
        return torch.cat(outputs, 1)


# Example usage:
# Create an instance of the Inception module
# Example parameters are from a simplified Inception module, e.g., for an early layer in GoogLeNet
# inception_module = BasicInception(in_channels=192, ch1x1=64, ch3x3red=96, ch3x3=128, ch5x5red=16, ch5x5=32, pool_proj=32)
Exercise 6.4: Implementing Simple GoogLeNet
Once you are done implementing the Inception block, we will build a SimpleGoogLeNet. This architecture does not exist in the literature; it was made up for teaching purposes. The original GoogLeNet also attaches auxiliary classifiers to intermediate layers, adding extra supervision during training; we will not do that.
Here is a table representing the architecture of the SimpleGoogLeNet model:
| Layer/Block | Input Size (N, C, H, W) | Output Size (N, C, H, W) | Details |
|---|---|---|---|
| Input Image | (N, 3, 32, 32) | (N, 3, 32, 32) | CIFAR-100 image |
| Conv1 + ReLU + MaxPool1 | (N, 3, 32, 32) | (N, 64, 8, 8) | 7x7 conv, 64 filters, stride 2, padding 3; ReLU; 3x3 MaxPool, stride 2, padding 1 |
| Inception1 | (N, 64, 8, 8) | (N, 256, 8, 8) | BasicInception module (as defined previously) |
| Inception2 | (N, 256, 8, 8) | (N, 480, 8, 8) | BasicInception module (as defined previously) |
| Global AvgPool | (N, 480, 8, 8) | (N, 480, 1, 1) | Adaptive Average Pooling to 1x1 spatial dimensions |
| FC Layer | (N, 480, 1, 1) | (N, 100) | Fully connected layer for 100 classes |
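The sizes in this table follow from simple bookkeeping, sketched below in plain Python (`out_size` is a hypothetical helper using the same floor((H + 2p - k) / s) + 1 rule that `nn.Conv2d` and `nn.MaxPool2d` follow by default). The stem shrinks 32x32 to 8x8, and an Inception block outputs ch1x1 + ch3x3 + ch5x5 + pool_proj channels. One configuration that yields the 256 channels of Inception1 is the example from the Inception exercise; the 480 channels of Inception2 are the sum used for `fc_in_features`.

```python
def out_size(h: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a conv/pool layer (dilation 1, floor mode)."""
    return (h + 2 * padding - kernel) // stride + 1

h = 32
h = out_size(h, kernel=7, stride=2, padding=3)  # Conv1: 32 -> 16
h = out_size(h, kernel=3, stride=2, padding=1)  # MaxPool1: 16 -> 8
print(h)  # 8

# An Inception block outputs ch1x1 + ch3x3 + ch5x5 + pool_proj channels
inception1_out = 64 + 128 + 32 + 32   # example configuration from the Inception exercise
inception2_out = 128 + 192 + 96 + 64  # the sum used for fc_in_features
print(inception1_out, inception2_out)  # 256 480
```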
If you are interested in seeing what the full GoogLeNet looks like, have a look at the paper, page 7. It's a rather big picture, so I have omitted it from this tutorial.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleGoogLeNet(nn.Module):
    def __init__(self, num_classes=100):
        super(SimpleGoogLeNet, self).__init__()
        # EXERCISE 6.4: Conv1 + ReLU + MaxPool1. See the table above
        self.conv1 = None
        self.relu1 = None
        self.maxpool1 = None
        # EXERCISE 6.4: First Inception block
        # Choose the branch sizes so that the output matches the table; typically these would be
        # determined through experimentation or taken from the original GoogLeNet architecture
        self.inception1 = None
        # EXERCISE 6.4: Second Inception block
        # The number of input channels of the second Inception block is the sum of the output channels of the first
        in_channels_inception2 = None  # as can be seen in the table above
        self.inception2 = None
        # Global Average Pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Fully connected layer
        # The number of input features of the FC layer is the sum of the output channels of the second Inception block
        fc_in_features = 128 + 192 + 96 + 64  # Sum of output channels from inception2
        self.fc = nn.Linear(fc_in_features, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


# Example usage:
# Create an instance of the SimpleGoogLeNet model
# simple_googlenet = SimpleGoogLeNet(num_classes=100)
# print(simple_googlenet)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import Callable, List
import matplotlib.pyplot as plt
import seaborn as sns
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
# Data loading for CIFAR-100, now with data augmentation for the training set:
# random 32x32 crops from the image padded by 4 pixels, and random horizontal flips
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # maps each channel from [0, 1] to [-1, 1] (not the exact CIFAR-100 statistics)
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # same normalization as for training, no augmentation
])
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
def calculate_accuracy(labels: torch.Tensor, predictions: torch.Tensor) -> float:
    """
    Calculate the accuracy of the predictions.
    Parameters
    ----------
    labels : torch.Tensor
        The true labels.
    predictions : torch.Tensor
        The predicted logits from the model.
    Returns
    -------
    float
        The accuracy of the predictions as a fraction between 0 and 1.
    """
    # The index of the largest logit is the predicted class
    _, predicted = torch.max(predictions.data, 1)
    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy
def initialize_weights(model):
    """Initialize model weights: He (Kaiming) initialization for convolutions (suited to ReLU),
    BatchNorm set to the identity transform (weight 1, bias 0), and small Gaussian weights for linear layers."""
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.bias, 0)
def train_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimizer: optim.Optimizer, loader: DataLoader, device: torch.device) -> tuple[List[float], List[float]]:
    """
    Train the model for a single EPOCH on the data.
    Parameters
    ----------
    model : nn.Module
        The model to train.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    optimizer : optim.Optimizer
        The optimizer to use.
    loader : DataLoader
        The data loader for the training set.
    device : torch.device
        The device to train on ('cuda' or 'cpu').
    Returns
    -------
    tuple[List[float], List[float]]
        A tuple containing a list of losses and a list of accuracies, one entry per batch.
    """
    model.train()
    train_losses = []
    train_accuracies = []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        predictions = model(images)
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_accuracies.append(calculate_accuracy(labels, predictions))
    return train_losses, train_accuracies
def validate_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], loader: DataLoader, device: torch.device) -> tuple[List[float], List[float]]:
    """
    Validate the model for a single EPOCH on the data.
    Parameters
    ----------
    model : nn.Module
        The model to validate.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    loader : DataLoader
        The data loader for the validation set.
    device : torch.device
        The device to validate on ('cuda' or 'cpu').
    Returns
    -------
    tuple[List[float], List[float]]
        A tuple containing a list of losses and a list of accuracies, one entry per batch.
    """
    model.eval()
    val_losses = []
    val_accuracies = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images)
            loss = loss_function(predictions, labels)
            val_losses.append(loss.item())
            val_accuracies.append(calculate_accuracy(labels, predictions))
    return val_losses, val_accuracies
def run_training(
        model: nn.Module,
        learning_rate: float,
        batch_size: int,
        num_epochs: int,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        device: torch.device) -> tuple[List[float], List[float], List[float], List[float]]:
    # NOTE: batch_size is not used inside this function; the batch size is set
    # when the DataLoaders are created.
    # define the loss function
    loss_function = nn.CrossEntropyLoss()
    # define the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    # train the model
    train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = [], [], [], []
    for epoch in range(num_epochs):
        print(f'#---------------- START EPOCH {epoch+1} ----------------#')
        train_losses, train_accuracies = train_epoch(model, loss_function, optimizer, train_loader, device)
        val_losses, val_accuracies = validate_epoch(model, loss_function, validation_loader, device)
        # average epoch accuracy and loss
        train_accuracy = sum(train_accuracies)/len(train_accuracies)
        train_loss = sum(train_losses)/len(train_losses)
        validation_accuracy = sum(val_accuracies)/len(val_accuracies)
        validation_loss = sum(val_losses)/len(val_losses)
        # save them for a plot later
        train_epoch_accuracies.append(train_accuracy)
        train_epoch_losses.append(train_loss)
        val_epoch_accuracies.append(validation_accuracy)
        val_epoch_losses.append(validation_loss)
        # round the print statements to three decimal places
        train_accuracy = round(train_accuracy, 3)
        train_loss = round(train_loss, 3)
        validation_accuracy = round(validation_accuracy, 3)
        validation_loss = round(validation_loss, 3)
        # report it out
        print(f"Train Accuracy: {train_accuracy}, Train Loss: {train_loss}")
        print(f"Validation Accuracy: {validation_accuracy}, Validation Loss: {validation_loss}")
        print(f'#---------------- END EPOCH {epoch+1} -----------------#')
        print('\n')
    print('========================== FINISHED TRAINING ==========================')
    print('\n')
    return train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses
# Instantiate the SimpleGoogLeNet model
simple_googlenet_model = SimpleGoogLeNet(num_classes=100).to(device)
# Initialize weights (optional but recommended)
initialize_weights(simple_googlenet_model)
# Set hyperparameters
learning_rate = 0.001
batch_size = 128 # Using batch size from DataLoader instantiation
num_epochs = 10
# Run the training
train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = run_training(
simple_googlenet_model,
learning_rate,
batch_size,
num_epochs,
trainloader,
testloader,
device
)
# Plotting the results
epochs = range(1, num_epochs + 1)
Device: cuda
#---------------- START EPOCH 1 ----------------#
Train Accuracy: 0.093, Train Loss: 3.908
Validation Accuracy: 0.153, Validation Loss: 3.557
#---------------- END EPOCH 1 -----------------#
#---------------- START EPOCH 2 ----------------#
Train Accuracy: 0.199, Train Loss: 3.295
Validation Accuracy: 0.234, Validation Loss: 3.114
#---------------- END EPOCH 2 -----------------#
#---------------- START EPOCH 3 ----------------#
Train Accuracy: 0.264, Train Loss: 2.955
Validation Accuracy: 0.287, Validation Loss: 2.907
#---------------- END EPOCH 3 -----------------#
#---------------- START EPOCH 4 ----------------#
Train Accuracy: 0.309, Train Loss: 2.729
Validation Accuracy: 0.33, Validation Loss: 2.656
#---------------- END EPOCH 4 -----------------#
#---------------- START EPOCH 5 ----------------#
Train Accuracy: 0.345, Train Loss: 2.554
Validation Accuracy: 0.35, Validation Loss: 2.571
#---------------- END EPOCH 5 -----------------#
#---------------- START EPOCH 6 ----------------#
Train Accuracy: 0.371, Train Loss: 2.421
Validation Accuracy: 0.369, Validation Loss: 2.485
#---------------- END EPOCH 6 -----------------#
#---------------- START EPOCH 7 ----------------#
Train Accuracy: 0.397, Train Loss: 2.311
Validation Accuracy: 0.397, Validation Loss: 2.349
#---------------- END EPOCH 7 -----------------#
#---------------- START EPOCH 8 ----------------#
Train Accuracy: 0.416, Train Loss: 2.22
Validation Accuracy: 0.41, Validation Loss: 2.309
#---------------- END EPOCH 8 -----------------#
#---------------- START EPOCH 9 ----------------#
Train Accuracy: 0.437, Train Loss: 2.118
Validation Accuracy: 0.429, Validation Loss: 2.246
#---------------- END EPOCH 9 -----------------#
#---------------- START EPOCH 10 ----------------#
Train Accuracy: 0.455, Train Loss: 2.039
Validation Accuracy: 0.429, Validation Loss: 2.229
#---------------- END EPOCH 10 -----------------#
========================== FINISHED TRAINING ==========================
plt.figure(figsize=(12, 5))
# Plotting Loss
plt.subplot(1, 2, 1)
sns.lineplot(x=epochs, y=train_epoch_losses, label='Train Loss')
sns.lineplot(x=epochs, y=val_epoch_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss per Epoch')
plt.legend()
# Plotting Accuracy
plt.subplot(1, 2, 2)
sns.lineplot(x=epochs, y=train_epoch_accuracies, label='Train Accuracy')
sns.lineplot(x=epochs, y=val_epoch_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy per Epoch')
plt.legend()
plt.tight_layout()
plt.show()
MobileNet/EfficientNet
Memory and compute are often constraints when deploying deep learning algorithms on devices in the field. MobileNets are a family of efficient convolutional neural networks designed specifically for mobile and embedded vision applications where computational resources are limited (EfficientNets are a related family that builds on MobileNet-style blocks). MobileNets were introduced by Google, and their core idea is to significantly reduce the number of parameters and computations compared to traditional CNNs while maintaining reasonable accuracy. They achieve this primarily through the use of depthwise separable convolutions.
Conceptual Understanding: Thinking about Filtering and Combining
Imagine a standard convolutional layer as doing two things at once:
- Filtering: Applying learned patterns (filters) to the input image to detect features like edges, corners, or textures.
- Combining: Mixing the information from different input channels (like Red, Green, and Blue in a color image) to create new, higher-level features in the output channels.
A standard convolution does this filtering and combining in one go.
MobileNets separate these two operations to make the layer more efficient. Instead of a single, complex step, they break it down into two:
- Depthwise Filtering: First, they apply a separate filter to each input channel individually. Think of it like applying a red-edge-detector only to the red channel, a green-edge-detector only to the green channel, and a blue-edge-detector only to the blue channel. This is much cheaper because each filter only looks at one channel at a time.
- Pointwise Combining: After filtering each channel independently, they use a simple 1x1 convolution to combine the outputs of these filtered channels. This step uses $C_{out}$ filters, each of size $C_{in} \times 1 \times 1$. Each filter looks at all $C_{in}$ filtered channels at a single spatial location to produce one of the $C_{out}$ output features. This step is also relatively cheap because the kernel size is just 1x1.
By separating the filtering (depthwise) from the combining (pointwise), MobileNets drastically reduce the amount of computation needed compared to a standard convolution that does both simultaneously.
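To make the filter/combine split concrete, here is a small NumPy sketch for a single image (stride 1, no padding, plain loops for readability rather than speed). The function names are our own, not PyTorch's. It also checks a useful fact: a depthwise step followed by a pointwise step gives exactly the same output as a standard convolution whose kernel is factorized as $w[o, c, :, :] = p[o, c] \cdot d[c, :, :]$. In other words, a depthwise separable convolution is a cheaper, more restricted version of a standard convolution.

```python
import numpy as np

def conv2d_valid(x, w):
    """Standard 2D convolution (cross-correlation), stride 1, no padding.
    x: (C_in, H, W), w: (C_out, C_in, K, K) -> (C_out, H-K+1, W-K+1)"""
    c_out, c_in, k, _ = w.shape
    h_out, w_out = x.shape[1] - k + 1, x.shape[2] - k + 1
    out = np.zeros((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[:, i:i + k, j:j + k]  # (C_in, K, K)
            # every output channel filters AND mixes all input channels at once
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out

def depthwise_valid(x, d):
    """Depthwise convolution: one K x K filter per channel, no channel mixing.
    x: (C_in, H, W), d: (C_in, K, K) -> (C_in, H-K+1, W-K+1)"""
    c_in, k, _ = d.shape
    h_out, w_out = x.shape[1] - k + 1, x.shape[2] - k + 1
    out = np.zeros((c_in, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[:, i:i + k, j:j + k]  # (C_in, K, K)
            out[:, i, j] = (patch * d).sum(axis=(1, 2))  # sum within each channel only
    return out

def pointwise(x, p):
    """1x1 convolution: mixes channels independently at every location.
    x: (C_in, H, W), p: (C_out, C_in) -> (C_out, H, W)"""
    return np.einsum('oc,chw->ohw', p, x)

rng = np.random.default_rng(0)
c_in, c_out, k = 3, 4, 3
x = rng.standard_normal((c_in, 6, 6))
d = rng.standard_normal((c_in, k, k))   # depthwise filters (filtering)
p = rng.standard_normal((c_out, c_in))  # pointwise 1x1 filters (combining)

separable = pointwise(depthwise_valid(x, d), p)

# the same result from a standard conv whose kernel factorizes as p[o, c] * d[c]
w_factorized = p[:, :, None, None] * d[None, :, :, :]
standard = conv2d_valid(x, w_factorized)

print(separable.shape)                   # (4, 4, 4)
print(np.allclose(separable, standard))  # True
```

The converse does not hold: a general $(C_{out}, C_{in}, K, K)$ kernel usually cannot be written in this factorized form, which is where the small accuracy cost of MobileNets comes from.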
Technical Details: Depthwise Separable Convolutions
A standard 2D convolution takes an input feature map of size $(N, C_{in}, H_{in}, W_{in})$ and applies $C_{out}$ kernels, each of size $C_{in} \times K \times K$ (together a weight tensor of shape $(C_{out}, C_{in}, K, K)$), to produce an output feature map of size $(N, C_{out}, H_{out}, W_{out})$. We have seen this in the previous lecture. The computational cost (number of multiply-add operations) is approximately:
$N \times H_{out} \times W_{out} \times K \times K \times C_{in} \times C_{out}$
Brief explanation of the formula above: we do the same computation for every image in the batch ($N$). The kernels (together $K \times K \times C_{in} \times C_{out}$ weights) are evaluated at $H_{out} \times W_{out}$ positions of the input, where $H_{out}$ and $W_{out}$ depend on the size of the input and on the stride and padding of the operation.
So what does a depthwise separable convolution replace this with? As mentioned above, two layers:
Depthwise Convolution: This operation is a convolution with `groups` set to $C_{in}$. It applies $C_{in}$ filters, each of size $1 \times K \times K$ (a weight tensor of shape $(C_{in}, 1, K, K)$), to the input feature map $(N, C_{in}, H_{in}, W_{in})$. Each filter is applied to a single input channel. The output feature map is of size $(N, C_{in}, H_{out}, W_{out})$. The computational cost is approximately:
$N \times H_{out} \times W_{out} \times K \times K \times C_{in}$
Pointwise Convolution: This is a standard 1x1 convolution that takes the output of the depthwise convolution $(N, C_{in}, H_{out}, W_{out})$ and applies $C_{out}$ filters of size $(C_{out}, C_{in}, 1, 1)$ to produce an output feature map of size $(N, C_{out}, H_{out}, W_{out})$. The computational cost is approximately:
$N \times H_{out} \times W_{out} \times 1 \times 1 \times C_{in} \times C_{out}$
The total cost of a depthwise separable convolution is the sum of the costs of the depthwise and pointwise convolutions:
$(N \times H_{out} \times W_{out} \times K \times K \times C_{in}) + (N \times H_{out} \times W_{out} \times C_{in} \times C_{out})$
Comparing this to the cost of a standard convolution, the reduction in computation is significant. The ratio of the cost of depthwise separable convolution to standard convolution is approximately:
$\frac{(N \times H_{out} \times W_{out} \times K \times K \times C_{in}) + (N \times H_{out} \times W_{out} \times C_{in} \times C_{out})}{N \times H_{out} \times W_{out} \times K \times K \times C_{in} \times C_{out}} = \frac{K \times K + C_{out}}{K \times K \times C_{out}} = \frac{1}{C_{out}} + \frac{1}{K \times K}$
For a 3x3 kernel and a large number of output channels, this ratio approaches $1/9$, meaning a depthwise separable convolution needs roughly 8 to 9 times fewer operations than a standard convolution.
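The cost formulas above are easy to check numerically. This sketch counts multiplications for a hypothetical 3x3 layer that maps 32 to 64 channels on a 12x12 output with a batch of one (numbers chosen only for illustration). With just 64 output channels the saving is a little under 8x; it approaches 9x as $C_{out}$ grows.

```python
def standard_conv_mults(n, h_out, w_out, k, c_in, c_out):
    # every output position of every output channel looks at K*K*C_in inputs
    return n * h_out * w_out * k * k * c_in * c_out

def separable_conv_mults(n, h_out, w_out, k, c_in, c_out):
    depthwise = n * h_out * w_out * k * k * c_in  # one K x K filter per input channel
    pointwise = n * h_out * w_out * c_in * c_out  # 1x1 filters that mix channels
    return depthwise + pointwise

args = dict(n=1, h_out=12, w_out=12, k=3, c_in=32, c_out=64)
std = standard_conv_mults(**args)
sep = separable_conv_mults(**args)
print(std, sep, round(std / sep, 2))  # 2654208 336384 7.89

# matches the closed-form ratio 1/C_out + 1/K^2
print(abs(sep / std - (1 / 64 + 1 / 9)) < 1e-12)  # True
```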
Slimming MobileNet Down
MobileNet architectures are built by stacking these depthwise separable convolutional layers, typically followed by batch normalization and ReLU activation. They also include standard layers like pooling and a final fully connected layer for classification.
MobileNets offer flexibility through two hyperparameters that allow engineers to easily tune the model for specific performance requirements:
- Width Multiplier ($\alpha$): A factor between 0 and 1 that scales the number of channels in every layer. A lower $\alpha$ results in a "thinner" network with fewer computations and parameters.
- Resolution Multiplier ($\rho$): A factor between 0 and 1 that scales the input image resolution. Using a lower resolution reduces the computational cost throughout the network.
By adjusting these multipliers, you can easily create a spectrum of MobileNet models with different trade-offs between accuracy and efficiency, allowing you to find the best fit for your target hardware and application.
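To see how the two multipliers trade off, the sketch below applies them to the depthwise separable cost from above for one hypothetical layer (24x24 output, 3x3 kernel, 64 to 128 channels, one image). The helper names are ours. Because the pointwise term scales with $\alpha^2$ and usually dominates, the cost drops by roughly $\alpha^2$; shrinking the resolution scales every term by $\rho^2$.

```python
import math

def separable_cost(h_out, w_out, k, c_in, c_out):
    # depthwise + pointwise multiplications for a single image (N = 1)
    return h_out * w_out * k * k * c_in + h_out * w_out * c_in * c_out

def scaled_cost(h_out, w_out, k, c_in, c_out, alpha=1.0, rho=1.0):
    # width multiplier thins every layer; resolution multiplier shrinks the feature maps
    return separable_cost(
        math.floor(rho * h_out), math.floor(rho * w_out), k,
        math.floor(alpha * c_in), math.floor(alpha * c_out),
    )

base = scaled_cost(24, 24, 3, 64, 128)
for alpha, rho in [(1.0, 1.0), (0.5, 1.0), (1.0, 0.5), (0.5, 0.5)]:
    ratio = scaled_cost(24, 24, 3, 64, 128, alpha, rho) / base
    print(f"alpha={alpha}, rho={rho}: {ratio:.3f} of the original cost")
```

Halving the resolution gives exactly a quarter of the cost here, while halving the width gives about 0.266 (slightly more than a quarter, because the depthwise term only shrinks linearly in $\alpha$).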
In summary, MobileNets provide a practical and efficient approach to building deep learning models for resource-constrained environments by intelligently separating the core operations of convolution.
Exercise 6.5: Implementing Depth-wise Separable Convolutions
Now that we know what a depthwise separable convolution is, we are going to implement one.
import torch.nn as nn
import torch
class DepthWiseSeparableConv2D(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 1, alpha: float = 1.0, bias: bool = False) -> None:
        """
        Depthwise separable convolution from MobileNet V1
        Hint: The resolution multiplier (rho) is implemented during data loading,
        so there is no need to worry about it here.
        Args:
            in_channels (int): Number of input channels
            out_channels (int): Number of output channels
            kernel_size (int): Size of the convolutional kernel (default 3)
            stride (int): Stride of the convolution (default 1)
            padding (int): Padding of the convolution (default 1 for 3x3 kernel with stride 1)
            alpha (float): Width multiplier (default 1.0)
            bias (bool): Whether to use bias (default False)
        """
        super(DepthWiseSeparableConv2D, self).__init__()
        # These are the convolutions that operate on each channel separately.
        # In contrast to normal convolutions, this does not mix channels:
        # a normal convolution moving over a color image mixes the values
        # from R, G, and B. This one keeps them separate!
        # Exercise 6.5: Implement the convolution which operates on each channel individually (the depthwise convolution)
        # Hint: Setting the groups parameter equal to in_channels makes the convolution depthwise
        self.depthwise_convolution = None
        # Now we want to combine the different channels before we move on to the next layer
        # Exercise 6.5: Now implement the pointwise convolution.
        # HINT: Here you are reducing your channels (the width multiplier alpha scales out_channels)
        reduced_channels = None
        self.pointwise_convolution = None
        # other relevant operations such as batch norm and ReLU:
        self.model = nn.Sequential(
            self.depthwise_convolution,
            nn.BatchNorm2d(in_channels),  # Batch norm after depthwise
            nn.ReLU(),
            self.pointwise_convolution,
            nn.BatchNorm2d(reduced_channels),  # Batch norm after pointwise
            nn.ReLU()
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
# Example usage (assuming input channels = 3, output channels = 64, alpha = 1.0)
# dws_conv = DepthWiseSeparableConv2D(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, alpha=1.0)
# We will use this later, when we implement a TinyMobileNet!
Exercise 6.6: Implementing TinyMobileNet
You will be implementing a small version of MobileNet called "TinyMobileNet" with 3 convolutional layers (one standard convolution followed by two depthwise separable convolutions) and 1 fully connected layer, incorporating the Resolution Multiplier ($\rho$) and Width Multiplier ($\alpha$).
This architecture uses an initial standard convolutional layer, followed by a few Depthwise Separable Convolutional (DSC) layers, Global Average Pooling, and a final Fully Connected layer.
We will assume a resolution multiplier $\rho = 0.75$ and a width multiplier $\alpha = 0.8$ for demonstration purposes. This means the input image resolution will be scaled down from 32x32 to 24x24 ($\lfloor 32 \times 0.75 \rfloor = 24$), and the number of channels in each layer will be scaled by 0.8 and rounded down.
| Layer/Block | Input Size (N, C, H, W) | Output Size (N, C, H, W) | Details |
|---|---|---|---|
| Input Image (Scaled) | (N, 3, 24, 24) | (N, 3, 24, 24) | Assuming CIFAR-100 input scaled by $\rho = 0.75$ |
| Initial Conv + BN + ReLU | (N, 3, 24, 24) | (N, $\lfloor 16 \times \alpha \rfloor$, 24, 24) | Standard 3x3 conv, $\lfloor 16 \times \alpha \rfloor$ filters, stride 1, padding 1, followed by BatchNorm and ReLU |
| DSC Layer 1 | (N, $\lfloor 16 \times \alpha \rfloor$, 24, 24) | (N, $\lfloor 32 \times \alpha \rfloor$, 12, 12) | Depthwise Separable Conv (using DepthWiseSeparableConv2D): depthwise 3x3, stride 2, padding 1; pointwise 1x1, stride 1, padding 0; output channels: $\lfloor 32 \times \alpha \rfloor$ |
| DSC Layer 2 | (N, $\lfloor 32 \times \alpha \rfloor$, 12, 12) | (N, $\lfloor 64 \times \alpha \rfloor$, 6, 6) | Depthwise Separable Conv (using DepthWiseSeparableConv2D): depthwise 3x3, stride 2, padding 1; pointwise 1x1, stride 1, padding 0; output channels: $\lfloor 64 \times \alpha \rfloor$ |
| Global AvgPool | (N, $\lfloor 64 \times \alpha \rfloor$, 6, 6) | (N, $\lfloor 64 \times \alpha \rfloor$, 1, 1) | Adaptive Average Pooling to 1x1 spatial dimensions |
| FC Layer | (N, $\lfloor 64 \times \alpha \rfloor$, 1, 1) | (N, 100) | Fully connected layer for 100 classes |
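The channel counts and spatial sizes in the table can be checked with a little arithmetic. This sketch uses the usual convolution output-size formula $\lfloor (H + 2P - K) / S \rfloor + 1$; the helper name `conv_out_size` is our own. With $\rho = 0.75$ and $\alpha = 0.8$ it reproduces the 24x24 input and the 12, 25 and 51 channels that appear in the printed model further down.

```python
import math

def conv_out_size(size, kernel_size, stride, padding):
    # spatial output size of a convolution (same formula for height and width)
    return (size + 2 * padding - kernel_size) // stride + 1

rho, alpha = 0.75, 0.8
h = math.floor(32 * rho)  # scaled CIFAR-100 input resolution
channels = [math.floor(c * alpha) for c in (16, 32, 64)]
print("input resolution:", h)
print("channels after each block:", channels)

h1 = conv_out_size(h, 3, 1, 1)   # initial 3x3 conv, stride 1, padding 1
h2 = conv_out_size(h1, 3, 2, 1)  # DSC 1: depthwise 3x3, stride 2 (the 1x1 pointwise keeps the size)
h3 = conv_out_size(h2, 3, 2, 1)  # DSC 2
print("spatial sizes:", h1, h2, h3)
```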
Reuse the depthwise separable convolution that you have implemented above! Once you are done, test your network by training it.
import torch
import torch.nn as nn
import torch.nn.functional as F
class TinyMobileNet(nn.Module):
    def __init__(self, num_classes: int = 100, alpha: float = 0.8):
        super(TinyMobileNet, self).__init__()
        self.alpha = alpha
        # EXERCISE 6.6: Implement the layers of TinyMobileNet!
        # dsc = depthwise separable convolution
        # Hint: the input layer!
        # Initial Conv + Batch Norm + ReLU
        # Input channels: 3 (RGB)
        # Output channels: 16 (before applying alpha)
        base_out_channels = None
        self.initial_conv = nn.Sequential(
            nn.Conv2d(3, base_out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(base_out_channels),
            nn.ReLU(inplace=True)
        )
        # Hint: The first layer with depthwise separable convolutions
        # Input channels: 16 (before applying alpha)
        # Output channels: 32 (before applying alpha)
        # Stride 2 for spatial downsampling
        second_layer_out_channels = None  # Check how to calculate!
        self.dsc1 = None  # Use your depthwise separable convolution block here!
        # Hint: The second layer with depthwise separable convolutions
        # Input channels: 32 (before applying alpha)
        # Output channels: 64 (before applying alpha)
        # Stride 2 for spatial downsampling
        third_layer_out_channels = None  # Check how to calculate!
        self.dsc2 = None  # Use your depthwise separable convolution block here!
        # Global Average Pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Fully Connected Layer
        # Input features: 64 (before applying alpha)
        # Output features: num_classes
        self.fc = nn.Linear(third_layer_out_channels, num_classes)
    def forward(self, x):
        # Note: Resolution multiplier is typically applied to the input data loading,
        # not within the model's forward pass. Assuming input 'x' is already scaled.
        x = self.initial_conv(x)
        x = self.dsc1(x)
        x = self.dsc2(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
# Example usage (assuming alpha=0.8 and input resolution is already scaled to 24x24):
tiny_mobilenet = TinyMobileNet(num_classes=100, alpha=0.8)
print(tiny_mobilenet)
# Create a dummy input tensor (batch_size, channels, height, width) with scaled resolution
# dummy_input = torch.randn(1, 3, 24, 24)
# output = tiny_mobilenet(dummy_input)
# print(output.shape) # Expected output shape (1, 100)
TinyMobileNet(
  (initial_conv): Sequential(
    (0): Conv2d(3, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (dsc1): DepthWiseSeparableConv2D(
    (depthwise_convolution): Conv2d(12, 12, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=12, bias=False)
    (pointwise_convolution): Conv2d(12, 25, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (model): Sequential(
      (0): Conv2d(12, 12, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=12, bias=False)
      (1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(12, 25, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
  )
  (dsc2): DepthWiseSeparableConv2D(
    (depthwise_convolution): Conv2d(25, 25, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=25, bias=False)
    (pointwise_convolution): Conv2d(25, 51, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (model): Sequential(
      (0): Conv2d(25, 25, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=25, bias=False)
      (1): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(25, 51, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(51, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=51, out_features=100, bias=True)
)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import Callable, List
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
def calculate_accuracy(labels: torch.Tensor, predictions: torch.Tensor) -> float:
    """
    Calculate the accuracy of the predictions.
    Parameters
    ----------
    labels : torch.Tensor
        The true labels.
    predictions : torch.Tensor
        The predicted logits from the model.
    Returns
    -------
    float
        The accuracy of the predictions as a fraction between 0 and 1.
    """
    # Get the index of the largest logit (the predicted class) for each sample
    _, predicted = torch.max(predictions.data, 1)
    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy
def initialize_weights(model):
    """Initializes model weights using He initialization for ReLU activation."""
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.bias, 0)
def train_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimizer: optim.Optimizer, loader: DataLoader, device: str) -> tuple[List[float], List[float]]:
    """
    Train the model for a single EPOCH on the data.
    Parameters
    ----------
    model : nn.Module
        The model to train.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    optimizer : optim.Optimizer
        The optimizer to use.
    loader : DataLoader
        The data loader for the training set.
    device : str
        The device to train on ('cuda' or 'cpu').
    Returns
    -------
    Tuple[List[float], List[float]]
        A tuple containing a list of losses and a list of accuracies for each iteration.
    """
    model.train()
    train_losses = []
    train_accuracies = []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        predictions = model(images)
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_accuracies.append(calculate_accuracy(labels, predictions))
    return train_losses, train_accuracies
def validate_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], loader: DataLoader, device: str) -> tuple[List[float], List[float]]:
    """
    Validate the model for a single EPOCH on the data.
    Parameters
    ----------
    model : nn.Module
        The model to validate.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    loader : DataLoader
        The data loader for the validation set.
    device : str
        The device to validate on ('cuda' or 'cpu').
    Returns
    -------
    Tuple[List[float], List[float]]
        A tuple containing a list of losses and a list of accuracies for each iteration.
    """
    model.eval()
    val_losses = []
    val_accuracies = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images)
            loss = loss_function(predictions, labels)
            val_losses.append(loss.item())
            val_accuracies.append(calculate_accuracy(labels, predictions))
    return val_losses, val_accuracies
def run_training(
    model: nn.Module,
    learning_rate: float,
    batch_size: int,
    num_epochs: int,
    rho: float,  # resolution multiplier
    num_classes: int = 100  # currently unused: CIFAR-100 always has 100 classes
) -> tuple[List[float], List[float], List[float], List[float]]:
    # Data Loading for CIFAR-100 with Resolution Multiplier
    # Calculate the scaled image size
    scaled_size = max(1, int(32 * rho))
    transform_train = transforms.Compose([
        transforms.Resize((scaled_size, scaled_size)),  # Apply resolution scaling
        transforms.RandomCrop(scaled_size, padding=int(scaled_size * 0.125)),  # Adjust padding based on scaled size
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Maps [0, 1] pixels to [-1, 1] (not the exact CIFAR-100 statistics)
    ])
    transform_test = transforms.Compose([
        transforms.Resize((scaled_size, scaled_size)),  # Apply resolution scaling
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Maps [0, 1] pixels to [-1, 1] (not the exact CIFAR-100 statistics)
    ])
    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    # define the loss function
    loss_function = nn.CrossEntropyLoss()
    # define the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    # train the model
    train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = [], [], [], []
    for epoch in range(num_epochs):
        print(f'#---------------- START EPOCH {epoch+1} ----------------#')
        train_losses, train_accuracies = train_epoch(model, loss_function, optimizer, train_loader, device)
        val_losses, val_accuracies = validate_epoch(model, loss_function, test_loader, device)  # the test set doubles as the validation set here
        # average epoch accuracy and loss
        train_accuracy = sum(train_accuracies)/len(train_accuracies)
        train_loss = sum(train_losses)/len(train_losses)
        validation_accuracy = sum(val_accuracies)/len(val_accuracies)
        validation_loss = sum(val_losses)/len(val_losses)
        # save them for a plot later
        train_epoch_accuracies.append(train_accuracy)
        train_epoch_losses.append(train_loss)
        val_epoch_accuracies.append(validation_accuracy)
        val_epoch_losses.append(validation_loss)
        # round the print statements to three decimal places
        train_accuracy = round(train_accuracy, 3)
        train_loss = round(train_loss, 3)
        validation_accuracy = round(validation_accuracy, 3)
        validation_loss = round(validation_loss, 3)
        # report it out
        print(f"Train Accuracy: {train_accuracy}, Train Loss: {train_loss}")
        print(f"Validation Accuracy: {validation_accuracy}, Validation Loss: {validation_loss}")
        print(f'#---------------- END EPOCH {epoch+1} -----------------#')
        print('\n')
    print('========================== FINISHED TRAINING ==========================')
    print('\n')
    return train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses
# Example Usage (assuming TinyMobileNet is defined in a previous cell)
# Instantiate the TinyMobileNet model with a specific alpha
tiny_mobilenet_model = TinyMobileNet(num_classes=100, alpha=0.8).to(device)
# Initialize weights (optional but recommended)
initialize_weights(tiny_mobilenet_model)
# Set hyperparameters including rho
learning_rate = 0.001
batch_size = 128
num_epochs = 10
rho = 0.75 # Example resolution multiplier
# Run the training
train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = run_training(
tiny_mobilenet_model,
learning_rate,
batch_size,
num_epochs,
rho
)
Device: cuda
#---------------- START EPOCH 1 ----------------#
Train Accuracy: 0.054, Train Loss: 4.278
Validation Accuracy: 0.081, Validation Loss: 4.001
#---------------- END EPOCH 1 -----------------#
#---------------- START EPOCH 2 ----------------#
Train Accuracy: 0.098, Train Loss: 3.905
Validation Accuracy: 0.107, Validation Loss: 3.814
#---------------- END EPOCH 2 -----------------#
#---------------- START EPOCH 3 ----------------#
Train Accuracy: 0.118, Train Loss: 3.755
Validation Accuracy: 0.121, Validation Loss: 3.699
#---------------- END EPOCH 3 -----------------#
#---------------- START EPOCH 4 ----------------#
Train Accuracy: 0.133, Train Loss: 3.668
Validation Accuracy: 0.14, Validation Loss: 3.621
#---------------- END EPOCH 4 -----------------#
#---------------- START EPOCH 5 ----------------#
Train Accuracy: 0.146, Train Loss: 3.604
Validation Accuracy: 0.151, Validation Loss: 3.56
#---------------- END EPOCH 5 -----------------#
#---------------- START EPOCH 6 ----------------#
Train Accuracy: 0.156, Train Loss: 3.546
Validation Accuracy: 0.161, Validation Loss: 3.514
#---------------- END EPOCH 6 -----------------#
#---------------- START EPOCH 7 ----------------#
Train Accuracy: 0.167, Train Loss: 3.5
Validation Accuracy: 0.174, Validation Loss: 3.466
#---------------- END EPOCH 7 -----------------#
#---------------- START EPOCH 8 ----------------#
Train Accuracy: 0.173, Train Loss: 3.462
Validation Accuracy: 0.171, Validation Loss: 3.446
#---------------- END EPOCH 8 -----------------#
#---------------- START EPOCH 9 ----------------#
Train Accuracy: 0.178, Train Loss: 3.428
Validation Accuracy: 0.181, Validation Loss: 3.421
#---------------- END EPOCH 9 -----------------#
#---------------- START EPOCH 10 ----------------#
Train Accuracy: 0.184, Train Loss: 3.399
Validation Accuracy: 0.182, Validation Loss: 3.374
#---------------- END EPOCH 10 -----------------#
========================== FINISHED TRAINING ==========================
plt.figure(figsize=(12, 5))
# Plotting Loss
plt.subplot(1, 2, 1)
sns.lineplot(x=epochs, y=train_epoch_losses, label='Train Loss')
sns.lineplot(x=epochs, y=val_epoch_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss per Epoch')
plt.legend()
# Plotting Accuracy
plt.subplot(1, 2, 2)
sns.lineplot(x=epochs, y=train_epoch_accuracies, label='Train Accuracy')
sns.lineplot(x=epochs, y=val_epoch_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy per Epoch')
plt.legend()
plt.tight_layout()
plt.show()
6.3 Attention Mechanisms in CNNs
In the context of neural networks, attention mechanisms are designed to allow the network to focus on the most relevant parts of the input data when making a prediction. For image data processed by CNNs, this can mean focusing on specific spatial locations or, as is the case with Squeeze-and-Excitation and Gather-Excite blocks, focusing on the most important feature channels.
Think of it like this: when you look at a picture, your brain doesn't process every single pixel with equal importance. It pays more attention to the areas that are most relevant to understanding the scene. Similarly, attention mechanisms in CNNs try to mimic this by dynamically adjusting the importance of different features.
Channel Attention
Channel attention focuses on what features are most important. After a convolutional layer, a CNN produces a set of feature maps, where each map corresponds to a specific learned feature (e.g., detecting edges, corners, or more complex patterns). Channel attention mechanisms learn to weigh these feature maps differently, emphasizing the channels that are more informative for the task at hand and suppressing less useful ones.
For the purposes of this tutorial, we will focus only on channel attention as an introduction. However, there are many more attention methods that are used in CNNs and other architectures.
Squeeze-and-Excitation (SE) Blocks
Squeeze-and-Excitation blocks are a simple yet effective way to implement channel attention. They were introduced to improve the representational capacity of networks by enabling them to perform dynamic channel-wise feature recalibration. An SE block typically consists of two main steps:
Squeeze: This step aggregates the spatial information across each feature map. This is usually done using Global Average Pooling (GAP), which computes the average value for each channel over its entire spatial extent. This results in a channel descriptor that summarizes the information in each feature map. If the input feature map is of size (N, C, H, W), the output of the squeeze step is (N, C, 1, 1).
Excitation: This step takes the channel descriptor from the squeeze step and learns a set of weights, one for each channel. These weights represent the importance of each channel. This is typically done using two fully connected layers (a bottleneck structure with a reduction ratio to reduce complexity) and activation functions (like ReLU and Sigmoid). The sigmoid activation ensures the learned weights are between 0 and 1.
The output of the excitation step is a set of channel-wise weights (N, C, 1, 1). These weights are then multiplied element-wise with the original input feature map (N, C, H, W). This multiplication operation scales the importance of each channel: channels with higher weights are amplified, while those with lower weights are suppressed.
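The squeeze, excitation and scaling steps can be traced in a few lines of NumPy. This is only a forward-pass sketch with random, untrained weights and a hypothetical reduction ratio $r = 4$; the function and variable names are ours, and it is not a substitute for the PyTorch block in Exercise 6.7. It skips the intermediate (N, C, 1, 1) tensor and goes straight to (N, C), so the shapes follow (N, C, H, W) → (N, C) → (N, C/r) → (N, C) → (N, C, H, W).

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def squeeze_excite(x, w1, b1, w2, b2):
    """SE recalibration for x of shape (N, C, H, W).
    w1: (C // r, C) and w2: (C, C // r) are the two bottleneck layers."""
    squeezed = x.mean(axis=(2, 3))                  # squeeze (global average pool): (N, C)
    hidden = np.maximum(0.0, squeezed @ w1.T + b1)  # FC + ReLU: (N, C // r)
    weights = sigmoid(hidden @ w2.T + b2)           # FC + sigmoid: (N, C), values in (0, 1)
    return x * weights[:, :, None, None], weights   # scale every channel by its weight

rng = np.random.default_rng(0)
n, c, h, w, r = 2, 8, 5, 5, 4
x = rng.standard_normal((n, c, h, w))
w1, b1 = rng.standard_normal((c // r, c)), np.zeros(c // r)
w2, b2 = rng.standard_normal((c, c // r)), np.zeros(c)

out, weights = squeeze_excite(x, w1, b1, w2, b2)
print(out.shape, weights.shape)  # (2, 8, 5, 5) (2, 8)
```

Every spatial position within a channel is multiplied by the same weight, which is what makes this channel attention rather than spatial attention.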
SE blocks are lightweight and can be easily integrated into existing CNN architectures (like ResNet or Inception) to boost performance with minimal computational overhead. You just stick them in after a convolutional layer of your choosing.
Exercise 6.7: Implementing Squeeze Excite
We are going to implement a Squeeze-and-Excitation block. The process is quite straightforward: use adaptive average pooling to average over the last two dimensions (height and width) of the tensor, then remove those two dimensions by squeezing them away.
Pull the result through the excitation part of the network: nn.Linear -> nn.ReLU -> nn.Linear -> nn.Sigmoid. Then reshape the weights to (N, C, 1, 1) and multiply them element-wise with the incoming tensor!
import torch.nn as nn
import torch
class SqueezeExciteBlock(nn.Module):
    def __init__(self, in_channels: int, reduction_ratio: int = 16) -> None:
        super(SqueezeExciteBlock, self).__init__()
        # Exercise 6.7: Squeeze operation.
        # Remember adaptive average pooling?
        self.squeeze = None  # (N, C, 1, 1)
        # Exercise 6.7: Excitation (run it through linear layers)
        # See above for exact dimensions. You need to map it down and up, i.e. in_channels -> in_channels // reduction_ratio -> in_channels.
        self.excitation = None  # Hint: use nn.Sequential to keep everything tidy!
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.squeeze(x)
        feats = feats.view(feats.size(0), -1)  # (N, C)
        feats = self.excitation(feats)
        # channel-wise multiplication
        x = x * feats.unsqueeze(-1).unsqueeze(-1)
        return x
# Define parameters for the test
batch_size = 4
in_channels = 64
height = 32
width = 32
reduction_ratio = 16
# Create a dummy input tensor
# Shape: (batch_size, in_channels, height, width)
dummy_input = torch.randn(batch_size, in_channels, height, width)
print(f"Dummy input shape: {dummy_input.shape}")
# Instantiate the SqueezeExciteBlock
se_block = SqueezeExciteBlock(in_channels=in_channels, reduction_ratio=reduction_ratio)
print("SqueezeExciteBlock instantiated.")
# Pass the dummy input through the SE block
output = se_block(dummy_input)
# Print the output shape
print(f"Output shape: {output.shape}")
# Verify that the output shape is the same as the input shape
assert output.shape == dummy_input.shape
print("Test passed: Output shape matches input shape.")
Dummy input shape: torch.Size([4, 64, 32, 32])
SqueezeExciteBlock instantiated.
Output shape: torch.Size([4, 64, 32, 32])
Test passed: Output shape matches input shape.
Gather-Excite (GE) Blocks
Gather-Excite blocks are a more recent development in channel attention, aiming to address some limitations of SE blocks, particularly in capturing richer spatial information. GE blocks introduce a "Gather" step before the "Excitation":
Gather: Instead of relying only on global average pooling, the gather step aggregates information over a local spatial region, for example with a 3x3 or 5x5 convolution applied to the input feature map. Unlike the convolutions in the main network path, which are designed to extract features, here the convolution acts as a way to aggregate information from a local spatial region within each channel. This allows the block to capture local spatial context within each channel before recalibrating the channel weights. This step can use various kernel sizes and strides.
Excitation: Similar to SE blocks, the excitation step takes the output of the gather step and turns it into weights between 0 and 1, using a sigmoid (optionally after further learned layers). In the original GE design the gathered map is first resized back to the input's spatial size, so the weights can vary across spatial locations as well as across channels.
The weights learned in the excitation step are then used to re-scale the original input feature map, similar to SE blocks.
The key difference lies in the gather step, which, by using convolution instead of just global average pooling, can potentially capture more nuanced spatial relationships when determining channel importance.
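A minimal NumPy sketch of a gather-excite step, with one substitution stated up front: instead of a learned convolution, the gather here is a fixed, parameter-free average pooling over extent × extent windows of each channel. The excite step resizes the gathered map back to the input size with nearest-neighbour upsampling and applies a sigmoid. The function name and the extent values are illustrative choices, not the paper's code. With a global extent the gate collapses to a single weight per channel, i.e. an SE-style squeeze without the learned excitation layers.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gather_excite(x, extent):
    """Parameter-free gather-excite sketch for x of shape (N, C, H, W).
    Assumes H and W are divisible by `extent`."""
    n, c, h, w = x.shape
    # gather: average each non-overlapping extent x extent window, per channel
    pooled = x.reshape(n, c, h // extent, extent, w // extent, extent).mean(axis=(3, 5))
    # excite: nearest-neighbour upsampling back to (H, W), then a sigmoid gate
    gate = sigmoid(pooled.repeat(extent, axis=2).repeat(extent, axis=3))
    return x * gate, gate

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8, 8))
out, gate = gather_excite(x, extent=4)
print(out.shape, gate.shape)  # (2, 4, 8, 8) (2, 4, 8, 8)

# with a global extent the gate is constant within each channel
_, global_gate = gather_excite(x, extent=8)
print(np.allclose(global_gate, global_gate[:, :, :1, :1]))  # True
```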
6.4 Loss Functions for classification (with Class imbalance)
You've already seen the cross-entropy loss, which works great when classes are roughly balanced. In the previous tutorial you were introduced to the class imbalance problem. This problem will haunt us for a while, but there are things we can do to combat its deleterious effects. One of them is selecting a loss function that puts a little more weight on getting the rare examples right.
Many loss functions have been invented, but you should probably start with these.
Weighted Cross-Entropy Loss
Class imbalance often causes models to favor majority classes, because standard cross-entropy treats all classes equally. Weighted cross-entropy addresses this by assigning higher weights to minority classes, so that their errors contribute more to the loss. As a reminder, this is the unweighted multi-class cross-entropy for $N$ samples and $C$ classes:
$$ \mathcal{L} = - \sum_{i=1}^{N} \sum_{c=1}^{C} y_{i,c} \log(\hat{p}_{i,c}) \quad \quad \text{(1) regular cross-entropy} $$
For the multi-class weighted cross-entropy we change the $y_{i,c}$ term. In practice $y_{i,c}$ comes from a one-hot vector, in which a 1 indicates the class a data point belongs to. We replace that 1 with a weight $w_{i,c}$ that we choose:
$$ \mathcal{L} = - \sum_{i=1}^{N} \sum_{c=1}^{C} w_{i,c} \log(\hat{p}_{i,c}) \quad \quad \text{(2) weighted cross-entropy} $$
where:
- $\hat{p}_{i,c}$ = predicted probability that sample $i$ belongs to class $c$
- $w_{i,c} = w_c\, y_{i,c}$, i.e. the weight $w_c$ of class $c$ if sample $i$ belongs to class $c$, and 0 otherwise
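A small NumPy sketch of the weighted cross-entropy, written directly from the formula with $w_{i,c} = w_c\, y_{i,c}$. The probabilities, labels and class weights are made-up illustration values. In PyTorch you would normally pass the weights to `nn.CrossEntropyLoss(weight=...)` instead; note that it takes logits rather than probabilities and, with the default `reduction='mean'`, divides by the sum of the target weights rather than returning the plain sum used here.

```python
import numpy as np

def weighted_cross_entropy(probs, labels, class_weights):
    """probs: (N, C) predicted probabilities, labels: (N,) integer labels, class_weights: (C,)."""
    num_classes = probs.shape[1]
    y = np.eye(num_classes)[labels]           # one-hot y_{i,c}
    w = class_weights[None, :] * y            # w_{i,c} = w_c * y_{i,c}
    return -np.sum(w * np.log(probs))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4]])
labels = np.array([0, 1, 2])                  # class 2 plays the role of the rare class
print(weighted_cross_entropy(probs, labels, np.ones(3)))                 # all weights 1: regular cross-entropy
print(weighted_cross_entropy(probs, labels, np.array([1.0, 1.0, 5.0])))  # rare-class error counts 5x
```

With all weights equal to 1 the result is exactly the regular cross-entropy; raising the weight of class 2 makes the poorly predicted third sample dominate the loss.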
But how do you determine the weights? There are several strategies. One of them is the inverse frequency strategy: you count the number of samples in each class, $n_c$, and take the inverse, $w_c = \frac{1}{n_c}$, as the weight for that class.
Another strategy is the normalized inverse frequency. It keeps the weights on a more interpretable scale by dividing the total number of samples $N$ by the number of classes $C$ multiplied by $n_c$: $w_c = \frac{N}{C\, n_c}$. With this choice, a perfectly balanced dataset gets a weight of 1 for every class. A final option is to tune the weights manually, which is recommended when, for domain-specific reasons, one class really needs to be weighted more heavily.
Here is an example in NumPy:
## Different methods of calculating the weights for weighted cross entropy.
import numpy as np
# Simulate an imbalanced dataset by sampling class counts from a multinomial distribution
def sample_multinomial(num_data: int, class_probs: list[float]) -> np.ndarray:
    assert np.isclose(sum(class_probs), 1.0), "Class probabilities must sum to 1"
    # Without a `size` argument this returns a 1-D array with one count per class
    return np.random.multinomial(num_data, class_probs)
probabilities = [0.2, 0.5, 0.25, 0.05]
num_data = 5000
class_counts = sample_multinomial(num_data, probabilities)
num_classes = len(class_counts)
# Inverse frequency: w_c = 1 / n_c
inv_frequency = 1. / class_counts
# Normalized inverse frequency: w_c = N / (C * n_c)
norm_inv_frequency = num_data / (num_classes * class_counts)
print('Example where class 4 is the minority class')
print('INVERSE FREQUENCY: As you can see below, the class with the fewest examples has the largest weight')
print(inv_frequency)
print('\nNORMALIZED INVERSE FREQUENCY: As you can see below, the class with the fewest examples has the largest weight')
print(norm_inv_frequency)
Example where class 4 is the minority class
INVERSE FREQUENCY: As you can see below, the class with the fewest examples has the largest weight
[0.001 0.00040323 0.00078431 0.00408163]
NORMALIZED INVERSE FREQUENCY: As you can see below, the class with the fewest examples has the largest weight
[1.25 0.50403226 0.98039216 5.10204082]
Focal Loss
Standard Cross-Entropy treats all samples equally, so easy examples (often from majority classes) dominate the loss. Focal Loss down-weights well-classified examples and focuses training on hard, misclassified examples, which are often from minority classes.
Focal Loss modifies the standard cross-entropy loss (see above) so that the model pays less attention to examples it already classifies correctly and focuses more on those it struggles with. In practice, this means that easy examples, often from majority classes, contribute very little to the loss, while hard examples, which are frequently from minority classes, keep most of their influence. This dynamic weighting helps the model learn better decision boundaries for underrepresented classes without being overwhelmed by the abundance of easy, majority-class samples.
How to tune it
Focal Loss introduces two key parameters: the focusing parameter $\gamma$ and the class weight $\alpha$. The focusing parameter controls how aggressively the loss down-weights easy examples: a value of zero makes it behave like (weighted) cross-entropy, while higher values (commonly around 2) increase the emphasis on hard examples. The class weight $\alpha$ balances the contribution of different classes, often by giving more weight to minority classes. A typical starting point is $\gamma = 2$ and $\alpha = 0.25$, the setting reported in the original focal loss paper, where $\alpha = 0.25$ was applied to the rare (foreground) class; because $\gamma$ already suppresses the many easy majority examples, the rare class needed less extra weight there. Adjust both values based on the severity of the imbalance and on validation performance.
Formula
For a sample with true class $y$ and predicted probability $p_y$ for that class:
$$ \text{FL}(p_y) = - \alpha_y (1 - p_y)^\gamma \log(p_y) $$
where:
- $\alpha_y$ = class weight (balances class frequencies)
- $\gamma $ = focusing parameter (controls how much to down-weight easy examples)
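To get a feel for the modulating factor $(1 - p_y)^\gamma$, the NumPy snippet below evaluates the formula for three hypothetical probabilities of the true class (an easy, a medium and a hard example), once with $\gamma = 0$ (plain cross-entropy) and once with $\gamma = 2$. It works on single numbers, so it is not a solution to the PyTorch exercise below.

```python
import numpy as np

def focal_term(p_y, gamma, alpha=1.0):
    """Focal loss of a sample, given the predicted probability p_y of its true class."""
    return -alpha * (1.0 - p_y) ** gamma * np.log(p_y)

p_true = np.array([0.95, 0.6, 0.1])    # easy, medium, hard examples
ce = focal_term(p_true, gamma=0.0)      # gamma = 0 reduces to cross-entropy
fl = focal_term(p_true, gamma=2.0)
for p, c, f in zip(p_true, ce, fl):
    print(f"p_y={p:.2f}  CE={c:.4f}  FL={f:.4f}  FL/CE={f / c:.4f}")
```

The ratio FL/CE is exactly $(1 - p_y)^2$: 0.0025 for the easy example, 0.16 for the medium one and 0.81 for the hard one, so easy examples are suppressed by a factor of 400 while hard ones keep most of their loss.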
Exercise 6.8: Implementing the focal loss
You are going to implement the focal loss. We will use the same $\alpha$ for all the classes. Follow the tips below!
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Reduction is how we aggregate the per-sample losses: do we average or sum?
        # In the end, the loss has to be a single number.
        self.reduction = reduction
    def forward(self, inputs, targets):
        """
        Calculate the focal loss.
        Parameters
        ----------
        inputs : torch.Tensor
            The predicted logits, shape (N, C).
        targets : torch.Tensor
            The true labels (one-hot vectors).
        """
        # Calculate the cross-entropy per sample (do not average or sum anything yet!)
        # Exercise 6.8: Calculate the REGULAR cross-entropy. Do not reduce.
        ce_loss = None
        # Get the probability of the true class back by exponentiating the cross-entropy loss
        # Exercise 6.8: Get the probabilities (exponentiate the negative of the cross-entropy loss)
        pt = None
        # Calculate the focal loss
        # Exercise 6.8: Check the formula above, implement it and multiply it with the cross-entropy
        focal_loss = None
        # Then reduce the tensor with the per-sample losses.
        return focal_loss.mean() if self.reduction == 'mean' else focal_loss.sum()
6.5 Data Augmentation
Data augmentation is the process by which you can artificially increase the diversity of your training data by applying transformations to existing images. This helps improve model robustness, reduces overfitting, and helps ameliorate the class imbalance problem by generating more diverse samples (though this last claim has come under challenge).
There are many types of data augmentation for computer vision. We will limit ourselves to simple techniques that you can apply directly, and some that have been specifically developed for medical image analysis. To keep this section short we will mention other advanced techniques, but not go into depth.
Basic Data augmentation techniques
Basic data augmentation in computer vision usually involves applying a combination of the following transformations to your training images:
- Geometric Transformations: Flips, rotations, translations, scaling, cropping.
- Color and Lighting Adjustments: Brightness, contrast, saturation, hue shifts.
- Noise Injection: Gaussian noise, blur, JPEG compression artifacts.
See below for examples of single transformations applied to an image.
# @title
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import torch
from skimage import data  # Import skimage for the astronaut image
# Load the astronaut image from skimage and convert it to a PIL image
try:
    img = Image.fromarray(data.astronaut())
    print("Astronaut image loaded successfully.")
except Exception as e:
    print(f"Error loading image: {e}")
    # Fallback to a dummy image if loading fails
    img = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
    print("Using a dummy image as a fallback.")
# Define a few data augmentation transformations (flips use p=1.0 so the effect is always visible)
augmentations = {
    "Original": transforms.Compose([transforms.ToTensor(), transforms.ToPILImage()]),  # Convert back to PIL for consistency
    "Random Horizontal Flip": transforms.RandomHorizontalFlip(p=1.0),
    "Random Vertical Flip": transforms.RandomVerticalFlip(p=1.0),
    "Random Rotation (30 deg)": transforms.RandomRotation(30),
    "Color Jitter (Brightness)": transforms.ColorJitter(brightness=0.5),
    "Color Jitter (Contrast)": transforms.ColorJitter(contrast=0.5),
    "Color Jitter (Saturation)": transforms.ColorJitter(saturation=0.5),
    "Color Jitter (Hue)": transforms.ColorJitter(hue=0.2),
    "Random Resized Crop": transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
    "Pad (50 pixels)": transforms.Pad(50, fill=0),  # Pad with black
    "RandomAffine (Translate)": transforms.RandomAffine(0, translate=(0.2, 0.2)),
    "RandomAffine (Scale)": transforms.RandomAffine(0, scale=(0.8, 1.2)),
    "RandomAffine (Shear)": transforms.RandomAffine(0, shear=30),
    "Gaussian Blur": transforms.GaussianBlur(kernel_size=5),
    # Note: Adding realistic noise can be more complex, but Gaussian Blur is a simple form of noise injection.
}
# Determine grid size (approximately square layout)
n_augmentations = len(augmentations)
n_cols = int(np.ceil(np.sqrt(n_augmentations)))
n_rows = int(np.ceil(n_augmentations / n_cols))
# Create a figure and axes for the grid display
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3))
axes = axes.flatten()  # Flatten the 2D array of axes for easy iteration
# Display original and augmented images
for i, (name, transform) in enumerate(augmentations.items()):
    try:
        augmented_img = transform(img)
        axes[i].imshow(augmented_img)
        axes[i].set_title(name)
        axes[i].axis('off')  # Hide axes
    except Exception as e:
        print(f"Error applying {name}: {e}")
        axes[i].set_title(f"Error: {name}")
        axes[i].axis('off')
# Hide any unused subplots
for j in range(i + 1, len(axes)):
    axes[j].axis('off')
plt.tight_layout()  # Adjust layout to prevent overlapping titles/labels
plt.show()
Astronaut image loaded successfully.
These methods simulate deformations and variations that can occur in real life, so you can use them to make your neural network "get used" to them. The key is that each transformation is applied with a certain probability, so the network sees a slightly different version of each image in every epoch.
In practice, we compose these augmentations into pipelines that randomly apply them to images. The example below uses a torchvision pipeline. If you want more sophisticated pipelines with different augmentation branches (where certain augmentations are or are not combined), you can use a library called Albumentations.
Example of combinations:
# @title
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import torch
from skimage import data  # Import skimage for the astronaut image
# Load the astronaut image from skimage (same loading code as in the previous cell)
try:
    img = Image.fromarray(data.astronaut())
    print("Astronaut image loaded successfully.")
except Exception as e:
    print(f"Error loading image: {e}")
    # Fallback to a dummy image if loading fails
    img = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
    print("Using a dummy image as a fallback.")
# Define a pipeline of random data augmentations; each step is applied with its own probability p
# We'll use some of the transformations from the previous example
augmentation_pipeline = transforms.Compose([
    # The flips get p=1.0 inside RandomApply, so the outer p=0.5 is the actual flip probability
    transforms.RandomApply([transforms.RandomHorizontalFlip(p=1.0)], p=0.5),
    transforms.RandomApply([transforms.RandomVerticalFlip(p=1.0)], p=0.5),
    transforms.RandomApply([transforms.RandomRotation(30)], p=0.5),
    transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)], p=0.5),
    transforms.RandomApply([transforms.RandomResizedCrop(img.size[0], scale=(0.8, 1.0))], p=0.5),  # Crop, then resize to a square of the original width
    transforms.RandomApply([transforms.Pad(20, fill=0)], p=0.2),  # Less chance of large padding
    transforms.RandomApply([transforms.RandomAffine(0, translate=(0.1, 0.1))], p=0.5),
    transforms.RandomApply([transforms.RandomAffine(0, scale=(0.9, 1.1))], p=0.5),
    transforms.RandomApply([transforms.RandomAffine(0, shear=10)], p=0.5),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),  # Less chance of blur
    transforms.ToTensor(),  # A training pipeline would typically end here, producing a tensor
    transforms.ToPILImage()  # Convert back to PIL for display
])
# Generate 6 randomly augmented samples
num_samples = 6
augmented_samples = [augmentation_pipeline(img) for _ in range(num_samples)]
# Display the samples on a grid
n_cols = 3
n_rows = int(np.ceil(num_samples / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3))
axes = axes.flatten()
for i, aug_img in enumerate(augmented_samples):
    axes[i].imshow(aug_img)
    axes[i].set_title(f"Sample {i+1}")
    axes[i].axis('off')
# Hide any unused subplots
for j in range(i + 1, len(axes)):
    axes[j].axis('off')
plt.tight_layout()
plt.show()
Astronaut image loaded successfully.
Augmentations for medical data
Many of the data augmentation techniques used in regular computer vision, specifically the geometric transformations, can also be applied to medical data. However, others, such as ColorJitter, which changes the values of the color channels in an RGB image, cannot always be applied directly.
When working with data augmentation for medical images you should take into consideration the physical processes by which they are generated.
As an example, I have plotted augmentations for Magnetic Resonance images below. You can see examples of ghosting, spiking, motion, bias field inhomogeneity, and elastic deformations.
# @title
!pip install nilearn
!pip install nibabel
!pip install torchio
# @title
import torchio as tio
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nb
import torch  # Used to build the fallback tensor
# Use the FPG dataset (a T1-weighted brain MRI that ships with TorchIO)
try:
    fpg_subject = tio.datasets.FPG()  # FPG is a tio.Subject that contains a 't1' image
    print("FPG dataset loaded successfully.")
except Exception as e:
    print(f"Error loading FPG dataset: {e}")
    # Fallback to a dummy subject if loading fails
    # Create a dummy tensor (1 channel, 3D spatial)
    dummy_tensor = torch.randn(1, 100, 100, 100)
    dummy_affine = np.eye(4)
    dummy_image = tio.ScalarImage(tensor=dummy_tensor, affine=dummy_affine)
    fpg_subject = tio.Subject(t1=dummy_image)  # Same key as the FPG subject
    print("Using a dummy subject as a fallback.")
slice_image = fpg_subject['t1']
unaltered_slice = slice_image.data.squeeze()
max_displacement = (20, 20, 5)
# Elastic Deformation. You could use this in regular computer vision too, but I've never seen it used before.
elastic_deformation = tio.RandomElasticDeformation(
    num_control_points=20,
    max_displacement=max_displacement,
    locked_borders=0,
    p=1.0,
)
elastically_deformed_volume = elastic_deformation(slice_image)
# Ghosting
ghosting = tio.RandomGhosting(intensity=2., num_ghosts=4, p=1.0)
ghosting_volume = ghosting(slice_image)
# Random spike
spiking = tio.RandomSpike(num_spikes=1, p=1.0)
spiked_volume = spiking(slice_image)
# Random bias field
bias_field = tio.RandomBiasField(p=1.0, coefficients=(-1., 1.), order=3)
inhomogenous_volume = bias_field(slice_image)
# Random motion
motion = tio.RandomMotion(num_transforms=6, image_interpolation='nearest')
moved_volume = motion(slice_image)
augmented_subjects = {
    "Original": slice_image,
    "Ghosting": ghosting_volume,
    "RandomBiasField": inhomogenous_volume,
    "RandomElasticDeformation": elastically_deformed_volume,
    "RandomMotion": moved_volume,
    "RandomSpike": spiked_volume,
}
FPG dataset loaded successfully.
# @title
import matplotlib.pyplot as plt
import numpy as np
import torch  # Import torch if not already imported
# Assumes the augmented_subjects dictionary was created in the previous cell
# Define the augmentations to display and their order
display_order = ["Original", "Ghosting", "RandomBiasField", "RandomElasticDeformation", "RandomMotion", "RandomSpike"]
n_augmentations_to_display = len(display_order)
n_cols = 3
n_rows = 2  # 2x3 grid
# Index of the sagittal slice to display
sagittal_slice_idx = 128
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4))
axes = axes.flatten()
for i, name in enumerate(display_order):
    if name in augmented_subjects:
        augmented_subject = augmented_subjects[name]
        # Access the image data tensor (from a Subject, or directly from a ScalarImage)
        # Squeeze removes dimensions of size 1 (such as the channel dimension)
        if 't1' in augmented_subject:
            img_data_tensor = augmented_subject['t1'].data.squeeze()
        else:
            img_data_tensor = augmented_subject.data.squeeze()
        # Get the specified sagittal slice (slice along the first spatial dimension)
        # Use .detach().cpu().numpy() to get a NumPy array from the tensor
        if img_data_tensor.ndim >= 3:  # Ensure it's at least 3D spatial data
            if sagittal_slice_idx < img_data_tensor.shape[0]:
                slice_to_display = img_data_tensor[sagittal_slice_idx, :, :].detach().cpu().numpy()
            else:
                # Fallback if the specified index is out of bounds: take the last slice
                slice_to_display = img_data_tensor[-1, :, :].detach().cpu().numpy()
                print(f"Warning: Sagittal slice index {sagittal_slice_idx} out of bounds for {name} (max index {img_data_tensor.shape[0]-1}). Using last slice.")
        elif img_data_tensor.ndim == 2:
            # Handle cases where squeezing results in 2D data unexpectedly
            slice_to_display = img_data_tensor.detach().cpu().numpy()
            print(f"Warning: Unexpected 2D tensor dimensions for {name}. Displaying as is.")
        else:
            # Handle unexpected dimensions by displaying a blank image
            print(f"Warning: Unexpected tensor dimensions for {name}: {img_data_tensor.ndim}")
            slice_to_display = np.zeros((100, 100))
        axes[i].imshow(slice_to_display, cmap='bone')  # Use 'bone' colormap
        axes[i].set_title(name)
        axes[i].axis('off')  # Hide axes
    else:
        print(f"Warning: Augmentation '{name}' not found in augmented_subjects.")
        axes[i].set_title(f"{name} (Not Found)")
        axes[i].axis('off')
# Hide any unused subplots (there shouldn't be any with a 2x3 grid and 6 augmentations)
for j in range(n_augmentations_to_display, len(axes)):
    axes[j].axis('off')
plt.tight_layout()
plt.show()
Advanced Data Augmentation Techniques
Though it is beyond the scope of this tutorial to go into these techniques, I do think it is important to mention them.
CutMix/MixUp
CutMix and MixUp are advanced data augmentation techniques that operate by combining pairs of images and their corresponding labels during training. MixUp linearly interpolates both images and their one-hot labels. For example, a mixed image might be $0.8 \times \text{image}_1 + 0.2 \times \text{image}_2$, with a corresponding label that is $0.8 \times \text{label}_1 + 0.2 \times \text{label}_2$. CutMix is a variation where a patch is cut from one image and pasted onto another, and the labels are mixed proportionally to the areas of the two images in the combined result.
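A minimal NumPy sketch of both ideas for a single pair of images with one-hot labels. The mixing coefficient $\lambda$ is drawn from a Beta distribution, as in the MixUp and CutMix papers; the `alpha` defaults are illustrative assumptions, not tuned values. For CutMix, $\lambda$ is recomputed from the actual box area (the box may be clipped at the border), so the label matches the fraction of pixels each image contributes. Recent torchvision versions also ship `transforms.v2.MixUp` and `transforms.v2.CutMix`, which operate on whole batches.

```python
import numpy as np

rng = np.random.default_rng(0)

def mixup(x1, y1, x2, y2, alpha=0.2):
    """Blend two images and their one-hot labels with lambda ~ Beta(alpha, alpha)."""
    lam = rng.beta(alpha, alpha)
    return lam * x1 + (1 - lam) * x2, lam * y1 + (1 - lam) * y2, lam

def cutmix(x1, y1, x2, y2, alpha=1.0):
    """Paste a random box from x2 into x1 and mix the labels by area. Images: (H, W, ...)."""
    h, w = x1.shape[:2]
    lam = rng.beta(alpha, alpha)
    # Box whose area is roughly (1 - lam) of the image
    cut_h, cut_w = int(h * np.sqrt(1 - lam)), int(w * np.sqrt(1 - lam))
    cy, cx = rng.integers(h), rng.integers(w)
    top, bottom = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    left, right = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)
    mixed = x1.copy()
    mixed[top:bottom, left:right] = x2[top:bottom, left:right]
    # Recompute lambda from the actual (possibly clipped) box area
    lam = 1 - (bottom - top) * (right - left) / (h * w)
    return mixed, lam * y1 + (1 - lam) * y2, lam

x1, x2 = np.zeros((32, 32, 3)), np.ones((32, 32, 3))   # toy "images"
y1, y2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])     # one-hot labels
_, y_mix, _ = mixup(x1, y1, x2, y2)
_, y_cut, _ = cutmix(x1, y1, x2, y2)
print(y_mix, y_cut)
```

With an all-zero and an all-one image, the mean pixel value of the mixed image equals the label weight of the second image, which is a quick way to check that images and labels are mixed consistently.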
AutoAugment/RandAugment
AutoAugment and RandAugment are advanced data augmentation techniques that move beyond manually selecting and tuning augmentation policies. AutoAugment learns an optimal data augmentation policy from the data itself using a search algorithm. This policy consists of a set of augmentation operations (like rotation, shearing, or color distortion) and the probabilities and magnitudes with which to apply them. RandAugment simplifies this by removing the search phase: it randomly selects a fixed number of augmentation operations from a predefined set and applies them with a single global magnitude (some implementations add random variation around it), so only two hyperparameters remain to be tuned. These methods can lead to significant improvements in model accuracy by finding effective augmentation strategies that might not be obvious through manual tuning.
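The core of RandAugment fits in a few lines. The NumPy sketch below uses four hypothetical toy operations on grayscale images in [0, 1] (the real method uses a larger list of image operations) and assumes the magnitude is expressed on a 0 to 1 scale. It samples `num_ops` operations uniformly with replacement and applies each with the same magnitude. In torchvision, the ready-made versions are `transforms.RandAugment(num_ops=2, magnitude=9)` and `transforms.AutoAugment`, the latter with learned policies.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy operations: each takes an image in [0, 1] and a magnitude m in [0, 1]
def brightness(img, m):
    return np.clip(img + 0.5 * m, 0.0, 1.0)

def contrast(img, m):
    mean = img.mean()
    return np.clip(mean + (1.0 + m) * (img - mean), 0.0, 1.0)

def flip_lr(img, m):
    return img[:, ::-1]  # this operation ignores the magnitude

def translate_x(img, m):
    return np.roll(img, int(m * img.shape[1] / 4), axis=1)

OPS = [brightness, contrast, flip_lr, translate_x]

def rand_augment(img, num_ops=2, magnitude=0.5):
    """Pick `num_ops` operations uniformly at random and apply each with the same magnitude."""
    chosen = rng.choice(len(OPS), size=num_ops, replace=True)
    for idx in chosen:
        img = OPS[idx](img, magnitude)
    return img, [OPS[i].__name__ for i in chosen]

image = rng.random((32, 32))
augmented, applied = rand_augment(image, num_ops=2, magnitude=0.5)
print(applied)
```

Because there is no search, tuning reduces to a small grid over `num_ops` and `magnitude`.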
6.6 Uncertainty Estimation in Deep Learning
Next, we are going to dive into uncertainty estimation, which is very important, but a bit more technical and abstract. Before we dive into uncertainty estimation, we need to know what uncertainty is. We will give a (very) brief introduction to the Bayesian approach to statistics. Then we will explain a regularization technique called dropout, and explain how it connects to uncertainty estimation in neural networks.
Frequentist vs Bayesian Perspectives on Uncertainty in Predictions
Most people are familiar with the frequentist view of probability: probability describes long-run frequencies over repeated experiments. In this framework, model parameters, like the weights in a neural network or the mean and variance of a normal distribution, are treated as fixed but unknown. Uncertainty arises from the variability in the data: if you trained the model on a different dataset, you might get different weights and predictions.
When making predictions, frequentist models typically output a point estimate (e.g., the most likely class) and sometimes a confidence score derived from the softmax output. However, this score is not a true measure of uncertainty: it reflects the relative likelihood of the classes, not how confident the model should be in its prediction. A model can output high softmax probabilities even when it is extrapolating to unfamiliar data. Moreover, frequentist methods don't naturally provide a way to express uncertainty about new predictions, beyond add-ons such as calibration or ensembles.
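The claim that softmax can be confidently wrong far from the data is easy to reproduce. The sketch below uses a hypothetical linear classifier with fixed weights (no training involved) and moves an input further and further along the same direction. Because the logits grow linearly with the input, the softmax saturates towards 1, even though the model has never seen anything like these inputs.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                     # subtract the max for numerical stability
    return np.exp(z) / np.exp(z).sum()

# Hypothetical linear classifier with 3 classes and 2 input features: logits = W @ x
W = np.array([[1.0, -0.5],
              [-0.3, 0.8],
              [0.2, 0.1]])
x_typical = np.array([0.5, 0.4])        # an input close to where training data might lie

for scale in [1, 10, 100]:
    x = scale * x_typical               # move further and further away from the data
    print(scale, softmax(W @ x).round(4))
```

The top-class probability grows from about 0.37 to essentially 1.0, which is exactly the overconfidence that a Bayesian treatment aims to expose.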
The Bayesian Perspective
The Bayesian perspective is fundamentally different:
- Parameters are random variables. We express our uncertainty about them using a probability distribution.
- We start with a prior $p(\theta)$ (our belief before seeing data), then update it with the observed data $D$ using Bayes' theorem: $p(\theta | D) = \frac{p(D | \theta)\, p(\theta)}{p(D)}$
- The result is a posterior distribution over parameters, which captures our updated uncertainty.
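Bayes' theorem can be applied numerically for a single parameter. The NumPy sketch below estimates the bias $\theta$ of a coin from a handful of hypothetical flips by evaluating prior times likelihood on a grid of candidate values and normalizing (the normalization plays the role of $p(D)$). With a flat prior and 6 heads out of 8 flips, the exact posterior is a Beta(7, 3) distribution with mean 0.7, so the grid result can be checked.

```python
import numpy as np

theta = np.linspace(0.001, 0.999, 999)           # grid of candidate parameter values
prior = np.ones_like(theta)                       # flat prior p(theta)
prior /= prior.sum()

data = np.array([1, 1, 0, 1, 1, 1, 0, 1])         # hypothetical coin flips (1 = heads)
heads, tails = data.sum(), len(data) - data.sum()
likelihood = theta**heads * (1 - theta)**tails    # p(D | theta)

posterior = likelihood * prior
posterior /= posterior.sum()                      # dividing by p(D) normalizes the posterior

post_mean = np.sum(theta * posterior)
cdf = np.cumsum(posterior)
lower = theta[np.searchsorted(cdf, 0.025)]        # 95% credible interval for theta
upper = theta[np.searchsorted(cdf, 0.975)]
print(f"posterior mean = {post_mean:.3f}, 95% credible interval = ({lower:.3f}, {upper:.3f})")
```

The result is not a single best $\theta$ but a whole distribution; its width shrinks as more flips are observed. For a neural network, $\theta$ has millions of dimensions, which is why such a grid is impossible there.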
But the real power of Bayesian inference lies in its treatment of predictive uncertainty. For a new input $x^*$, we don't just compute a single prediction, we integrate over all plausible models (i.e., parameter configurations) weighted by their posterior probability. In other words, we average predictions across all possible models, weighted by how likely each model is given the data:
$$p(y^* | x^*, D) = \int p(y^* | x^*, \theta) \, p(\theta | D) \, d\theta$$
This is the predictive distribution, and it captures our uncertainty about the output for $x^*$. From this distribution, we can derive:
- Predictive mean: the expected output.
- Predictive variance: how much predictions vary across plausible models.
- Credible interval: a Bayesian analogue to the confidence interval, which tells us the range within which a new observation is likely to fall with a given probability (e.g., 95%).
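In practice the integral is approximated by Monte Carlo: draw parameter samples $\theta_t \sim p(\theta | D)$, make a prediction with each, and summarize the predictions. The NumPy sketch below does this for a 1-D logistic regression, assuming a hypothetical Gaussian posterior over its weight and bias (in reality this posterior would come from inference on data). Here the mean, variance and 95% interval are computed over the predicted probability $p(y^* = 1 | x^*, \theta)$.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical Gaussian posterior over (w, b) of a 1-D logistic regression
post_mean = np.array([2.0, -1.0])
post_cov = np.array([[0.25, 0.0],
                     [0.0, 0.09]])

thetas = rng.multivariate_normal(post_mean, post_cov, size=10_000)  # theta_t ~ p(theta | D)
x_star = 1.5
preds = sigmoid(thetas[:, 0] * x_star + thetas[:, 1])              # p(y* = 1 | x*, theta_t)

predictive_mean = preds.mean()                  # Monte Carlo estimate of the integral
predictive_var = preds.var()                    # spread across plausible models
lower, upper = np.percentile(preds, [2.5, 97.5])  # 95% credible interval
print(f"mean={predictive_mean:.3f}, var={predictive_var:.4f}, 95% CI=({lower:.3f}, {upper:.3f})")
```

The same recipe (sample models, predict, summarize) is what MC dropout uses, with dropout masks taking the place of posterior samples.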
The problem is that for neural networks this integral is intractable: the posterior is a distribution over millions of weights, which we can neither compute exactly nor store cheaply, let alone integrate over. We will introduce a technique to approximate this posterior predictive distribution later.
Dropout
Dropout is a regularization technique (you can add it to your toolbelt too!) introduced to prevent overfitting in neural networks. During training, dropout randomly “drops” (sets to zero) a fraction of neurons in a layer on each forward pass. This forces the network to learn redundant representations, making it more robust and less reliant on specific neurons. The key idea is that at each training step, we subsample a different sub-network by randomly removing (zero-ing out) neurons in the network.
Mathematically, for a layer with activations $h$, dropout applies a binary mask $z \sim \text{Bernoulli}(p)$, such that $\tilde{h} = z \odot h$ where $p$ is the keep probability and $\odot$ denotes element-wise multiplication.
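A NumPy sketch of the dropout mask $\tilde{h} = z \odot h$ with $z \sim \text{Bernoulli}(p)$, where $p$ is the keep probability as in the text. It additionally uses the common "inverted dropout" convention (an addition to the formula above): surviving activations are divided by $p$ during training so that their expected value is unchanged and nothing needs to be rescaled at test time. Be aware that PyTorch's `nn.Dropout(p)` uses the opposite convention: its `p` is the probability of dropping a unit, and it scales the survivors by $1/(1-p)$.

```python
import numpy as np

rng = np.random.default_rng(0)

def dropout(h, keep_prob=0.8, training=True):
    """Inverted dropout: zero each unit with probability 1 - keep_prob, rescale survivors."""
    if not training:
        return h                                    # no dropout (and no scaling) at test time
    z = rng.binomial(1, keep_prob, size=h.shape)    # z ~ Bernoulli(keep_prob)
    return z * h / keep_prob

h = np.ones(10_000)
h_tilde = dropout(h, keep_prob=0.8)
print("fraction dropped:", (h_tilde == 0).mean(), "mean activation:", h_tilde.mean())
```

Roughly 20% of the units are zeroed on every call, and a different subset each time, while the average activation stays close to 1.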
In addition to regularization, work by Yarin Gal has shown that you can use dropout to approximate the posterior predictive distribution.
Dropout Variational Inference and Predictive Uncertainty
In practice, computing the full posterior is intractable for deep neural networks. Dropout variational inference (also called Monte Carlo dropout, or MC dropout) offers a scalable approximation:
- We apply dropout at test time and perform multiple stochastic forward passes.
- Each pass samples a different sub-network, approximating a draw from the posterior.
- The resulting predictions form an empirical predictive distribution.
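The tutorial's model below contains no BatchNorm, so it simply calls `model.train()` to keep dropout active at test time. In a model that does contain BatchNorm, `model.train()` would also make BatchNorm use batch statistics and update its running averages, which changes the predictions. A minimal sketch of one way around this, assuming the model uses the standard `nn.Dropout`/`nn.Dropout2d`/`nn.Dropout3d` layers (the helper names `enable_mc_dropout` and `mc_dropout_predict` and the toy model are hypothetical), puts the whole model in eval mode and re-enables only the dropout layers:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def enable_mc_dropout(model: nn.Module) -> None:
    """Put the model in eval mode, then switch only the dropout layers back to train mode."""
    model.eval()
    for module in model.modules():
        if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            module.train()


@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, num_samples: int = 30) -> torch.Tensor:
    """Return softmax probabilities of shape (num_samples, batch, num_classes)."""
    enable_mc_dropout(model)
    return torch.stack([torch.softmax(model(x), dim=-1) for _ in range(num_samples)])


# Hypothetical toy model with BatchNorm, to show that BatchNorm stays in eval mode
model = nn.Sequential(
    nn.Linear(4, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 3)
)
x = torch.randn(2, 4)
probs = mc_dropout_predict(model, x, num_samples=10)
mean_probs = probs.mean(dim=0)  # Monte Carlo estimate of the predictive distribution
print(probs.shape)  # torch.Size([10, 2, 3])
```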
From this, we can compute a credible interval over predictions, not just over parameters. This gives us a principled way to say:
"Given the data and model uncertainty, we can be 95% certain that the output for this input lies within this range."
This is especially useful in safety-critical applications (e.g., medical diagnosis, autonomous driving), where knowing when the model is uncertain can be as important as the prediction itself.
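Given $T$ stochastic predictions for an input, one simple way to form an empirical 95% interval is to take the 2.5th and 97.5th percentiles across the passes. The sketch below assumes a regression-style output and uses synthetic Gaussian numbers in place of real MC-dropout samples; the function name `empirical_credible_interval` is hypothetical. Note that this interval only reflects the spread between sub-networks; any observation noise the model is meant to capture would have to be added separately.

```python
import torch


def empirical_credible_interval(samples: torch.Tensor, level: float = 0.95) -> tuple[torch.Tensor, torch.Tensor]:
    """Percentile interval over the first dimension (the stochastic forward passes)."""
    alpha = 1.0 - level
    lower = torch.quantile(samples, alpha / 2, dim=0)
    upper = torch.quantile(samples, 1 - alpha / 2, dim=0)
    return lower, upper


# Synthetic stand-in for MC-dropout outputs: 200 passes, 4 inputs, 1 output each
torch.manual_seed(0)
samples = 2.0 + 0.5 * torch.randn(200, 4, 1)
lower, upper = empirical_credible_interval(samples, level=0.95)
print(lower.squeeze(), upper.squeeze())
```

With only a few dozen passes the extreme percentiles are noisy, so intervals estimated this way generally need more samples than the mean does.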
Important Caveat
While dropout variational inference is computationally efficient and often works well in practice, it's worth noting that it approximates a specific form of posterior distribution (Bernoulli dropout on weights), which may not perfectly capture the true posterior uncertainty. It's a practical approximation rather than an exact Bayesian solution, but it provides a good balance between theoretical grounding and computational feasibility.
Why is Uncertainty Estimation Important?
For engineers deploying deep learning models, understanding uncertainty is crucial for building robust and reliable systems:
- Reliability: In critical applications (like autonomous driving, medical diagnosis, or fraud detection), a model needs to know when it doesn't know. High uncertainty might signal that a prediction is unreliable and requires human review or a fallback mechanism.
- Out-of-Distribution Detection: Models often perform poorly on data that differs significantly from their training data (out-of-distribution data). High uncertainty can indicate that an input is out-of-distribution, allowing the system to handle it appropriately.
- Active Learning: Uncertainty can guide data collection. If a model is highly uncertain about predictions on certain types of data, it suggests that acquiring and labeling more data in those areas could significantly improve the model's performance.
- Model Improvement: Analyzing where and why a model is uncertain can provide insights into its weaknesses and guide further model development or data annotation efforts.
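Both the human-review and the active-learning use cases boil down to ranking inputs by an uncertainty score. A minimal sketch, assuming we already have MC-averaged class probabilities for a pool of inputs and using predictive entropy as the score (one common choice among several; the helper names and the three-row pool are hypothetical):

```python
import torch


def predictive_entropy(mean_probs: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Entropy of the MC-averaged class probabilities, one value per input."""
    return -(mean_probs * torch.log(mean_probs + eps)).sum(dim=-1)


def select_for_review(mean_probs: torch.Tensor, k: int) -> torch.Tensor:
    """Indices of the k inputs with the highest predictive entropy."""
    scores = predictive_entropy(mean_probs)
    return torch.topk(scores, k).indices


mean_probs = torch.tensor([
    [0.98, 0.01, 0.01],  # confident
    [0.34, 0.33, 0.33],  # very uncertain
    [0.60, 0.30, 0.10],  # somewhat uncertain
])
print(select_for_review(mean_probs, k=2))  # tensor([1, 2])
```

The selected indices could be sent to a human reviewer or, in active learning, to annotators.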
Example: Uncertainty estimation on CIFAR-10
In this example, we show how uncertainty estimation is implemented. As you read above, the setup is fairly simple: add dropout to your network and keep dropout active during inference.
This means that for each data point whose uncertainty we want to estimate, we run inference $N$ times (below, $N = 30$).
NOTE: we are using the CIFAR-10 dataset (10 classes), which is simpler than the CIFAR-100 dataset (100 classes).
NOTE: Do not run the code, just study it! Training takes quite a while.
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import Callable, List
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
class SimpleCNNWithDropout(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNNWithDropout, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(0.25)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout2 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        # Calculate the size of the flattened features
        # Input size: 32x32 -> after pool1 (stride 2): 16x16 -> after pool2 (stride 2): 8x8
        flattened_size = 64 * 8 * 8
        self.fc1 = nn.Linear(flattened_size, 128)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(0.5)  # Higher dropout for FC layer
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.dropout2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.dropout3(x)
        x = self.fc2(x)
        return x
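The hard-coded `flattened_size = 64 * 8 * 8` can be checked by pushing a dummy CIFAR-sized tensor through the same convolution and pooling layers. This sketch rebuilds only the feature extractor of the model above (dropout is left out because it does not change shapes):

```python
import torch
import torch.nn as nn

# Same convolution/pooling stack as SimpleCNNWithDropout, without dropout
features = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
)

dummy = torch.zeros(1, 3, 32, 32)  # one CIFAR-sized RGB image
with torch.no_grad():
    flattened_size = features(dummy).shape[1]
print(flattened_size)  # 4096
```

PyTorch also offers `nn.LazyLinear`, which infers its input size on the first forward pass, as an alternative to computing this number by hand.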
Helper functions for training and validation
def calculate_accuracy(labels: torch.Tensor, predictions: torch.Tensor) -> float:
    """
    Calculate the accuracy of the predictions.

    Parameters
    ----------
    labels : torch.Tensor
        The true labels.
    predictions : torch.Tensor
        The predicted logits from the model.

    Returns
    -------
    float
        The accuracy of the predictions as a fraction between 0 and 1.
    """
    # The predicted class is the index of the largest logit
    predicted = predictions.argmax(dim=1)
    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy
def train_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimizer: optim.Optimizer, loader: DataLoader, device: str) -> tuple[List[float], List[float]]:
    """
    Train the model for a single EPOCH on the data.

    Parameters
    ----------
    model : nn.Module
        The model to train.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    optimizer : optim.Optimizer
        The optimizer to use.
    loader : DataLoader
        The data loader for the training set.
    device : str
        The device to train on ('cuda' or 'cpu').

    Returns
    -------
    Tuple[List[float], List[float]]
        A tuple containing a list of losses and a list of accuracies for each iteration.
    """
    model.train()  # Set model to training mode (dropout is active)
    train_losses = []
    train_accuracies = []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        predictions = model(images)
        loss = loss_function(predictions, labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_accuracies.append(calculate_accuracy(labels, predictions))
    return train_losses, train_accuracies
def validate_epoch(model: nn.Module, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], loader: DataLoader, device: str) -> tuple[List[float], List[float]]:
    """
    Validate the model for a single EPOCH on the data.

    Parameters
    ----------
    model : nn.Module
        The model to validate.
    loss_function : Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        The loss function to use.
    loader : DataLoader
        The data loader for the validation set.
    device : str
        The device to validate on ('cuda' or 'cpu').

    Returns
    -------
    Tuple[List[float], List[float]]
        A tuple containing a list of losses and a list of accuracies for each iteration.
    """
    model.eval()  # Set model to evaluation mode (dropout is inactive)
    val_losses = []
    val_accuracies = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images)
            loss = loss_function(predictions, labels)
            val_losses.append(loss.item())
            val_accuracies.append(calculate_accuracy(labels, predictions))
    return val_losses, val_accuracies
def get_one_example_per_class(dataset):
    """
    Returns the first example of each class in the dataset.

    Args:
        dataset: A PyTorch Dataset object (e.g., torchvision.datasets.CIFAR10).

    Returns:
        A list of (image, label) tuples, one for each class, sorted by label.
    """
    class_indices = {}
    # Iterate through the dataset to find the first index for each class
    for i, (_, label) in enumerate(dataset):
        if label not in class_indices:
            class_indices[label] = i
        # Stop once we have found an index for all classes
        if len(class_indices) == len(dataset.classes):
            break
    # Get the actual samples using the collected indices
    samples = []
    # Sort by class label for consistent order
    for label in sorted(class_indices.keys()):
        index = class_indices[label]
        samples.append(dataset[index])
    return samples
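`get_one_example_per_class` iterates over the dataset itself, which loads and transforms every image it visits. Since torchvision's `CIFAR10` keeps its labels in a `targets` list, a lighter variant (a sketch; `get_one_index_per_class` is a hypothetical name) can search the labels only and load just the selected images:

```python
def get_one_index_per_class(targets, num_classes):
    """Return the first index of each class label, ordered by label, using only the label list."""
    first_index = {}
    for i, label in enumerate(targets):
        if label not in first_index:
            first_index[label] = i
            # Stop as soon as every class has been seen
            if len(first_index) == num_classes:
                break
    return [first_index[label] for label in sorted(first_index)]


# Toy label list: class 0 first appears at index 3, class 1 at 1, class 2 at 4, class 3 at 0
print(get_one_index_per_class([3, 1, 3, 0, 2, 1, 0], num_classes=4))  # [3, 1, 4, 0]
```

Usage with the test set would look like `samples = [testset[i] for i in get_one_index_per_class(testset.targets, len(testset.classes))]`, which yields the same `(image, label)` tuples in the same order.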
Dropout Variational Inference: for each input, run the forward pass $N$ times with dropout active, then compute the mean and the variance of the outputs.
def dropout_variational_inference(model: torch.nn.Module, data: list, num_repetitions: int, device: str):
    """
    Performs dropout variational inference on sampled data points.

    Args:
        model: The trained PyTorch model with dropout layers.
        data: A list of (image, label) tuples for inference.
        num_repetitions: The number of forward passes to perform for each data point.
        device: The device to perform inference on ('cuda' or 'cpu').

    Returns:
        A tuple containing:
        - sampled_data: The input data used for inference, as (image, label) tuples.
        - mean_predictions: Mean of the logits across repetitions, shape (num_samples, num_classes).
        - variance_predictions: Variance of the logits across repetitions, shape (num_samples, num_classes).
        - all_predictions: The logits of every forward pass, shape (num_samples, num_repetitions, num_classes).
    """
    # Training mode keeps dropout active during inference (safe here: the model has no BatchNorm)
    model.train()
    model.to(device)
    all_predictions = []
    original_labels = []
    input_images = []
    for image, label in data:
        # Ensure image has a batch dimension and move to device
        image = image.unsqueeze(0).to(device)
        original_labels.append(label)
        input_images.append(image.squeeze(0).cpu())  # Store original image (without batch dim, on CPU)
        predictions_for_sample = []
        for _ in range(num_repetitions):
            with torch.no_grad():  # No need to calculate gradients during inference
                output = model(image)
            # Store raw logits or probabilities depending on what you want to analyze
            predictions_for_sample.append(output.squeeze(0).cpu())  # Remove batch dim, move to CPU
        # Stack predictions for the current sample across repetitions
        predictions_for_sample_stacked = torch.stack(predictions_for_sample, dim=0)  # Shape (num_repetitions, num_classes)
        all_predictions.append(predictions_for_sample_stacked)
    # Stack predictions for all sampled data points
    # Shape (num_samples, num_repetitions, num_classes)
    all_predictions_stacked = torch.stack(all_predictions, dim=0)
    # Calculate mean and variance across repetitions (dim=1)
    mean_predictions = torch.mean(all_predictions_stacked, dim=1)  # Shape (num_samples, num_classes)
    variance_predictions = torch.var(all_predictions_stacked, dim=1)  # Shape (num_samples, num_classes)
    # Return original data, mean predictions, variance, and all individual predictions
    # Re-package sampled data for clarity
    sampled_data_output = [(input_images[i], original_labels[i]) for i in range(len(data))]
    return sampled_data_output, mean_predictions, variance_predictions, all_predictions_stacked
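The function above runs one forward pass per repetition. Because `nn.Dropout` draws an independent mask for every element of its input, stacking $N$ copies of the same image into one batch yields $N$ independent dropout samples in a single forward pass. A sketch of this batched variant follows (the name `mc_dropout_batched` and the toy model are hypothetical); like the function above, it relies on `model.train()` and therefore assumes a model without BatchNorm. For large $N$ or large models, the batch may need to be split into chunks to fit in memory.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def mc_dropout_batched(model: nn.Module, image: torch.Tensor, num_repetitions: int) -> torch.Tensor:
    """Run num_repetitions stochastic passes for one unbatched image in a single batch.

    Returns logits of shape (num_repetitions, num_classes).
    """
    model.train()  # keeps dropout active (assumes the model has no BatchNorm)
    device = next(model.parameters()).device
    # Repeat the image along a new batch dimension: (num_repetitions, *image.shape)
    batch = image.unsqueeze(0).repeat(num_repetitions, *([1] * image.dim())).to(device)
    return model(batch)


# Hypothetical toy model on 3x8x8 inputs
torch.manual_seed(0)
toy_model = nn.Sequential(
    nn.Flatten(), nn.Linear(3 * 8 * 8, 32), nn.ReLU(), nn.Dropout(0.5), nn.Linear(32, 10)
)
image = torch.randn(3, 8, 8)
logits = mc_dropout_batched(toy_model, image, num_repetitions=30)
mean_logits, var_logits = logits.mean(dim=0), logits.var(dim=0)
print(logits.shape)  # torch.Size([30, 10])
```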
Training Function
# Assumes train_epoch and validate_epoch are defined in earlier cells;
# the model, data loaders, and device are passed in as arguments
def run_training(
        model: nn.Module,
        learning_rate: float,
        batch_size: int,
        num_epochs: int,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        device: str) -> tuple[List[float], List[float], List[float], List[float]]:
    # NOTE: batch_size is not used here; the batch size is fixed when the DataLoaders are created
    # define the loss function
    loss_function = nn.CrossEntropyLoss()
    # define the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    # train the model
    train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = [], [], [], []
    for epoch in range(num_epochs):
        print(f'#---------------- START EPOCH {epoch+1} ----------------#')
        train_losses, train_accuracies = train_epoch(model, loss_function, optimizer, train_loader, device)
        val_losses, val_accuracies = validate_epoch(model, loss_function, validation_loader, device)
        # average epoch accuracy and loss
        train_accuracy = sum(train_accuracies)/len(train_accuracies)
        train_loss = sum(train_losses)/len(train_losses)
        validation_accuracy = sum(val_accuracies)/len(val_accuracies)
        validation_loss = sum(val_losses)/len(val_losses)
        # save them for a plot later
        train_epoch_accuracies.append(train_accuracy)
        train_epoch_losses.append(train_loss)
        val_epoch_accuracies.append(validation_accuracy)
        val_epoch_losses.append(validation_loss)
        # round the print statements to three decimal places
        train_accuracy = round(train_accuracy, 3)
        train_loss = round(train_loss, 3)
        validation_accuracy = round(validation_accuracy, 3)
        validation_loss = round(validation_loss, 3)
        # report it out
        print(f"Train Accuracy: {train_accuracy}, Train Loss: {train_loss}")
        print(f"Validation Accuracy: {validation_accuracy}, Validation Loss: {validation_loss}")
        print(f'#---------------- END EPOCH {epoch+1} -----------------#')
        print('\n')
    print('========================== FINISHED TRAINING ==========================')
    print('\n')
    return train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
# Data Loading for CIFAR-10 (adjusting mean/std for CIFAR-10)
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 mean and std
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 mean and std
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
# Instantiate the model
simple_cnn_model = SimpleCNNWithDropout(num_classes=10).to(device)
# Initialize weights (optional but recommended)
# Assuming initialize_weights function is defined in a previous cell
# initialize_weights(simple_cnn_model)
# Set hyperparameters
learning_rate = 0.001
batch_size = 128  # Matches the trainloader above; run_training does not use this value
num_epochs = 20
# Run the training
train_epoch_accuracies, train_epoch_losses, val_epoch_accuracies, val_epoch_losses = run_training(
simple_cnn_model,
learning_rate,
batch_size,
num_epochs,
trainloader,
testloader,
device
)
Device: cuda
#---------------- START EPOCH 1 ----------------# Train Accuracy: 0.355, Train Loss: 1.755 Validation Accuracy: 0.512, Validation Loss: 1.37 #---------------- END EPOCH 1 -----------------#
#---------------- START EPOCH 2 ----------------# Train Accuracy: 0.459, Train Loss: 1.492 Validation Accuracy: 0.554, Validation Loss: 1.251 #---------------- END EPOCH 2 -----------------#
#---------------- START EPOCH 3 ----------------# Train Accuracy: 0.499, Train Loss: 1.393 Validation Accuracy: 0.6, Validation Loss: 1.148 #---------------- END EPOCH 3 -----------------#
#---------------- START EPOCH 4 ----------------# Train Accuracy: 0.523, Train Loss: 1.333 Validation Accuracy: 0.628, Validation Loss: 1.082 #---------------- END EPOCH 4 -----------------#
#---------------- START EPOCH 5 ----------------# Train Accuracy: 0.541, Train Loss: 1.29 Validation Accuracy: 0.636, Validation Loss: 1.036 #---------------- END EPOCH 5 -----------------#
#---------------- START EPOCH 6 ----------------# Train Accuracy: 0.552, Train Loss: 1.253 Validation Accuracy: 0.652, Validation Loss: 1.009 #---------------- END EPOCH 6 -----------------#
#---------------- START EPOCH 7 ----------------# Train Accuracy: 0.558, Train Loss: 1.233 Validation Accuracy: 0.664, Validation Loss: 0.972 #---------------- END EPOCH 7 -----------------#
#---------------- START EPOCH 8 ----------------# Train Accuracy: 0.571, Train Loss: 1.208 Validation Accuracy: 0.674, Validation Loss: 0.942 #---------------- END EPOCH 8 -----------------#
#---------------- START EPOCH 9 ----------------# Train Accuracy: 0.578, Train Loss: 1.185 Validation Accuracy: 0.668, Validation Loss: 0.935 #---------------- END EPOCH 9 -----------------#
#---------------- START EPOCH 10 ----------------# Train Accuracy: 0.586, Train Loss: 1.167 Validation Accuracy: 0.678, Validation Loss: 0.93 #---------------- END EPOCH 10 -----------------#
#---------------- START EPOCH 11 ----------------# Train Accuracy: 0.596, Train Loss: 1.148 Validation Accuracy: 0.695, Validation Loss: 0.898 #---------------- END EPOCH 11 -----------------#
#---------------- START EPOCH 12 ----------------# Train Accuracy: 0.599, Train Loss: 1.139 Validation Accuracy: 0.701, Validation Loss: 0.875 #---------------- END EPOCH 12 -----------------#
#---------------- START EPOCH 13 ----------------# Train Accuracy: 0.604, Train Loss: 1.122 Validation Accuracy: 0.692, Validation Loss: 0.876 #---------------- END EPOCH 13 -----------------#
#---------------- START EPOCH 14 ----------------# Train Accuracy: 0.608, Train Loss: 1.115 Validation Accuracy: 0.703, Validation Loss: 0.868 #---------------- END EPOCH 14 -----------------#
#---------------- START EPOCH 15 ----------------# Train Accuracy: 0.616, Train Loss: 1.096 Validation Accuracy: 0.71, Validation Loss: 0.844 #---------------- END EPOCH 15 -----------------#
#---------------- START EPOCH 16 ----------------# Train Accuracy: 0.617, Train Loss: 1.092 Validation Accuracy: 0.709, Validation Loss: 0.849 #---------------- END EPOCH 16 -----------------#
#---------------- START EPOCH 17 ----------------# Train Accuracy: 0.618, Train Loss: 1.093 Validation Accuracy: 0.712, Validation Loss: 0.834 #---------------- END EPOCH 17 -----------------#
#---------------- START EPOCH 18 ----------------# Train Accuracy: 0.626, Train Loss: 1.074 Validation Accuracy: 0.714, Validation Loss: 0.822 #---------------- END EPOCH 18 -----------------#
#---------------- START EPOCH 19 ----------------# Train Accuracy: 0.626, Train Loss: 1.067 Validation Accuracy: 0.716, Validation Loss: 0.821 #---------------- END EPOCH 19 -----------------#
#---------------- START EPOCH 20 ----------------# Train Accuracy: 0.631, Train Loss: 1.055 Validation Accuracy: 0.719, Validation Loss: 0.805 #---------------- END EPOCH 20 -----------------#
========================== FINISHED TRAINING ==========================
Running dropout variational inference with 30 stochastic forward passes per image
# Perform dropout variational inference
# Assumes simple_cnn_model, testset, device, get_one_example_per_class,
# and dropout_variational_inference are defined in the cells above
sampled_data = get_one_example_per_class(testset)
num_repetitions = 30  # Number of forward passes for uncertainty estimation
sampled_data_output, mean_predictions, variance_predictions, individual_predictions = dropout_variational_inference(
    simple_cnn_model, sampled_data, num_repetitions, device
)
print("\nDropout Variational Inference Results:")
print(f"Number of repetitions: {num_repetitions}")
# Display results (mean and variance for each sampled example)
print("\nMean Predictions (Logits):")
# Display mean predictions with corresponding true labels
for i, (image, label) in enumerate(sampled_data_output):
    print(f" Class: {testset.classes[label]} (Label: {label}), Mean Logits: {mean_predictions[i]}")
print("\nVariance of Predictions (Logits):")
# Display variance with corresponding true labels
for i, (image, label) in enumerate(sampled_data_output):
    print(f" Class: {testset.classes[label]} (Label: {label}), Variance Logits: {variance_predictions[i]}")
# You can further analyze these mean and variance values to understand uncertainty.
# For example, calculate predictive entropy or plot credible intervals.
Dropout Variational Inference Results:
Number of repetitions: 30
Mean Predictions (Logits):
Class: airplane (Label: 0), Mean Logits: tensor([ 3.5094, 0.8242, -1.3315, -3.3670, -2.9582, -6.6626, -4.5293, -7.5883,
4.3340, -0.8344])
Class: automobile (Label: 1), Mean Logits: tensor([ 0.0550, 2.5855, -2.1015, 1.2792, -5.0368, 0.2036, -2.7366, -1.1705,
-3.5808, 0.6126])
Class: bird (Label: 2), Mean Logits: tensor([-0.4331, -3.9297, 1.6294, 0.6065, 0.3479, 0.2755, 0.5215, -0.4175,
-4.2110, -2.1361])
Class: cat (Label: 3), Mean Logits: tensor([-1.1487, -1.2076, -1.3034, 3.3054, -1.7904, 2.3609, -0.0461, -1.9562,
-0.5383, -2.2048])
Class: deer (Label: 4), Mean Logits: tensor([ 1.2106, -5.3113, 0.6577, 0.8817, 0.9424, -0.4755, -0.6741, -3.6512,
1.8178, -2.8601])
Class: dog (Label: 5), Mean Logits: tensor([-3.2166, -3.3860, 0.4186, 2.4774, 0.9617, 2.7145, -0.2068, 1.2996,
-4.4930, -3.4073])
Class: frog (Label: 6), Mean Logits: tensor([-5.3168, -4.3779, 2.4510, 2.1314, 1.8707, -1.7994, 7.3938, -7.6130,
-4.3833, -7.1434])
Class: horse (Label: 7), Mean Logits: tensor([-1.0943, -0.9160, -1.4390, 0.6503, -1.3384, 1.0053, -6.0266, 4.4254,
-5.8948, 1.0161])
Class: ship (Label: 8), Mean Logits: tensor([ 4.1348, 5.7241, -7.5598, -7.3506, -11.4690, -13.1288, -11.6456,
-13.7956, 8.5602, 1.8960])
Class: truck (Label: 9), Mean Logits: tensor([-2.0909, 6.5038, -6.7074, -4.1259, -9.7056, -6.8976, -6.9696, -6.8122,
-3.0989, 9.5639])
Variance of Predictions (Logits):
Class: airplane (Label: 0), Variance Logits: tensor([1.0973, 3.0408, 2.3093, 2.1334, 2.9185, 4.8024, 3.2409, 7.8506, 1.9107,
1.9083])
Class: automobile (Label: 1), Variance Logits: tensor([4.0170, 4.6201, 2.3189, 3.1160, 6.1984, 4.4562, 3.6094, 2.3936, 2.5049,
4.3916])
Class: bird (Label: 2), Variance Logits: tensor([0.9645, 3.3719, 0.4847, 0.6492, 0.5930, 0.9107, 0.7997, 1.2138, 4.7227,
1.7279])
Class: cat (Label: 3), Variance Logits: tensor([0.8865, 2.3110, 0.7305, 1.0724, 0.8719, 1.0963, 1.2275, 1.1656, 1.6821,
2.2675])
Class: deer (Label: 4), Variance Logits: tensor([0.9186, 3.1074, 0.4609, 0.3857, 0.9192, 0.8598, 0.9527, 1.4711, 0.9343,
1.3947])
Class: dog (Label: 5), Variance Logits: tensor([1.5671, 2.1630, 0.3538, 0.4313, 0.7279, 0.7206, 1.2008, 0.8842, 2.1373,
2.1359])
Class: frog (Label: 6), Variance Logits: tensor([3.3867, 4.7547, 1.9662, 2.1786, 2.7900, 1.8375, 3.5621, 7.6038, 5.1007,
6.8325])
Class: horse (Label: 7), Variance Logits: tensor([3.0044, 5.4386, 1.2361, 2.5925, 2.4964, 3.1557, 7.9031, 3.0750, 6.4222,
3.7980])
Class: ship (Label: 8), Variance Logits: tensor([ 2.6204, 8.4536, 8.5790, 9.1830, 17.4503, 22.1709, 13.0570, 28.8751,
9.4742, 4.3461])
Class: truck (Label: 9), Variance Logits: tensor([ 4.1846, 5.7801, 6.6072, 4.7271, 11.3084, 10.6725, 11.7905, 15.9101,
6.1381, 7.2400])
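The printed variances are per logit and hard to compare across classes. A common alternative summary converts each pass to probabilities first (averaging probabilities rather than logits, as the plotting cell below does) and then splits the predictive entropy into the expected entropy of the individual passes plus their mutual information with the model parameters, often called the "BALD" score. The mutual information is large only when the sub-networks disagree. This is a sketch; `uncertainty_decomposition` is a hypothetical helper, and the random logits stand in for real MC-dropout outputs:

```python
import torch


def uncertainty_decomposition(logits: torch.Tensor, eps: float = 1e-12):
    """logits: (num_samples, num_repetitions, num_classes) from MC dropout.

    Returns per-sample predictive entropy (total), expected entropy, and mutual information.
    """
    probs = torch.softmax(logits, dim=-1)
    mean_probs = probs.mean(dim=1)  # average probabilities over repetitions, not logits
    total = -(mean_probs * torch.log(mean_probs + eps)).sum(dim=-1)
    expected = -(probs * torch.log(probs + eps)).sum(dim=-1).mean(dim=1)
    mutual_info = total - expected  # disagreement between sub-networks
    return total, expected, mutual_info


# Random stand-in with the same shape as individual_predictions: 10 images, 30 passes, 10 classes
torch.manual_seed(0)
logits = 3 * torch.randn(10, 30, 10)
total, expected, mi = uncertainty_decomposition(logits)
print(total, mi)
```

On the real outputs this would be called as `total, expected, mi = uncertainty_decomposition(individual_predictions)`.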
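Plotting the results: the plot below shows the uncertainty estimates for each data point. Each subplot corresponds to one test image. The small blue dots are the class probabilities from the individual stochastic forward passes (spread horizontally with a small random jitter), and the larger red dot is the mean probability for each class. The ground-truth class on the x-axis is colored green if the mean prediction is correct and red otherwise.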
# Ensure these variables are available
try:
    sampled_data_output
    individual_predictions
    testset.classes
except NameError:
    print("Error: Required variables (sampled_data_output, individual_predictions, testset.classes) are not available.")
    print("Please run the previous cells for model training and dropout variational inference.")
    # Exit or return if variables are not defined to prevent further errors
    # For now, we'll just print an error and not proceed with plotting
    raise
num_samples, num_repetitions, num_classes = individual_predictions.shape
class_labels = testset.classes  # Assuming testset.classes is available
# Apply softmax to get probabilities for each repetition
individual_probabilities = torch.softmax(individual_predictions, dim=-1).cpu().numpy()  # (num_samples, num_repetitions, num_classes)
# Calculate mean probabilities across repetitions
mean_probabilities = np.mean(individual_probabilities, axis=1)  # (num_samples, num_classes)
# Set up the plot grid (one subplot per sampled image)
n_cols = 3  # Adjust as needed for layout
n_rows = (num_samples + n_cols - 1) // n_cols  # Calculate rows based on num_samples and cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 6, n_rows * 5))
axes = axes.flatten()  # Flatten for easy iteration
for i in range(num_samples):
    ax = axes[i]
    true_label = sampled_data_output[i][1]  # Get true label for the i-th sampled image
    true_class_name = class_labels[true_label]
    # Get probabilities and mean probabilities for the current sample
    sample_probabilities = individual_probabilities[i, :, :]  # (num_repetitions, num_classes)
    sample_mean_probs = mean_probabilities[i, :]  # (num_classes,)
    # Determine if the mean prediction is correct and get predicted class name
    predicted_label = np.argmax(sample_mean_probs)
    predicted_class_name = class_labels[predicted_label]
    is_correct = (predicted_label == true_label)
    # Plotting for each class
    for class_idx in range(num_classes):
        # Get probabilities for the current class across all repetitions
        class_probs_across_repetitions = sample_probabilities[:, class_idx]  # (num_repetitions,)
        # Scatter plot of individual probabilities with smaller dots
        jitter = np.random.rand(num_repetitions) * 0.2 - 0.1  # Small random jitter
        ax.scatter(class_idx + jitter, class_probs_across_repetitions, s=5, alpha=0.5, color='blue')  # Reduced dot size
        # Plot mean probability for the current class
        ax.plot(class_idx, sample_mean_probs[class_idx], 'ro', markersize=5)  # Red dot for mean
    # Customize plot
    ax.set_xticks(range(num_classes))
    ax.set_xticklabels(class_labels, rotation=90)
    ax.set_ylabel('Probability')
    # Update title to show true and predicted class
    ax.set_title(f'Sample {i+1} (True: {true_class_name}, Predicted: {predicted_class_name})')
    ax.set_ylim(-0.05, 1.05)  # Set y-axis limits for probabilities
    ax.grid(True, linestyle='--', alpha=0.6)
    # Color the x-axis tick label for the ground truth class
    xtick_labels = ax.get_xticklabels()
    xtick_labels[true_label].set_color('green' if is_correct else 'red')
# Hide any unused subplots
for j in range(i + 1, len(axes)):
    fig.delaxes(axes[j])
plt.tight_layout()
plt.show()
This was an introduction to uncertainty in deep learning using MC dropout. There is much more to this topic; in fact, we could spend several lectures on it. But if anybody ever mentions uncertainty when discussing deep learning, you now know where to start!
Conclusion
We have covered a lot of ground in this tutorial. We started by implementing a basic convolutional neural network. Then we implemented more advanced convolutional neural network architectures. We discussed attention mechanisms, loss functions, data augmentation, and uncertainty estimation. Pat yourself on the back; it's a lot of work.
In the next tutorial, we will discuss how to evaluate results in a classification setting.
Author: Riaan Zoetmulder