# TP1: PyTorch Basics

## Goals
- Understand the implementation of a classification task:
  - Data formatting and manipulation
  - Architecture implementation
  - Training/evaluation
  - Loading/saving models
- Adapt an architecture to a new dataset (transfer learning / fine-tuning)


In the first part, we provide code to train a LeNet-like architecture on a classification task using the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) (only digits 0 to 4 are considered).
In the second part, you will adapt this implementation to two use cases: transfer learning and fine-tuning. The goal is to move to a new set of classes: from digits 0 to 4 to digits 0 to 9.

## I - Formatting the dataset

```python
import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset
```

Documentation on [Datasets and DataLoader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)

First, we create a custom Dataset class to handle the MNIST dataset. A Dataset object must implement two methods:
- `__len__()`: returns the number of samples in the dataset
- `__getitem__(i)`: returns the information needed for the i-th sample (generally the input and the expected output)
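To make the two required methods concrete, here is a minimal sketch of a map-style dataset, assuming a hypothetical `SquaresDataset` whose i-th sample is the pair (`[i]`, `i ** 2`). Wrapping it in a `DataLoader` shows how the default collate function stacks individual samples into batches.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (tensor([i]), i ** 2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # number of samples in the dataset
        return self.n

    def __getitem__(self, i):
        # (input, expected output) pair for the i-th sample
        return torch.tensor([float(i)]), i ** 2


dataset = SquaresDataset(5)
loader = DataLoader(dataset, batch_size=2, shuffle=False)

# each batch is [inputs of shape (B, 1), targets of shape (B,)]
batches = list(loader)
for x, y in batches:
    print(x.shape, y)
```

The last batch holds a single sample because 5 is not a multiple of the batch size; `DataLoader` keeps it unless `drop_last=True` is passed.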

The original training set (60,000 images) is split into a new training set (50,000 images) and a validation set (10,000 images). The test set (10,000 images) is used as is.
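A quick sanity check of the index ranges hard-coded in `format_samples` (in the `MNISTDataset` class): the two splits are disjoint and together cover the 60,000 original training images. These counts are computed before filtering by label; keeping only digits 0 to 4 leaves fewer samples in each split.

```python
# Index ranges copied from MNISTDataset.format_samples
train_idx = set(range(5000, 55000))
val_idx = set(range(0, 5000)) | set(range(55000, 60000))

assert len(train_idx) == 50_000
assert len(val_idx) == 10_000
assert train_idx.isdisjoint(val_idx)             # no image appears in both splits
assert train_idx | val_idx == set(range(60000))  # together they cover the original training set
```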

**Question 1**: Why do we add a validation split, in addition to the test set?


```python
class MNISTDataset(Dataset):
    def __init__(self, set_name, labels=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), num_samples_per_label=None):
        self.set_name = set_name
        self.labels = labels
        # "train" and "val" are both drawn from the original MNIST training set
        self.mnist = MNIST(root="./cache", train=set_name in ("train", "val"), download=True)
        self.samples = self.format_samples(num_samples_per_label)

    def format_samples(self, num_samples_per_label=None):
        samples = list()
        # match/case requires Python >= 3.10
        match self.set_name:
            case "train":
                indices = list(range(5000, 55000))
            case "val":
                indices = list(range(0, 5000)) + list(range(55000, 60000))
            case "test":
                indices = list(range(0, 10000))
            case _:
                raise ValueError(f"Unknown set_name: {self.set_name!r}")

        # number of samples kept so far for each label
        num_samples_per_label_dict = dict()
        for label in self.labels:
            num_samples_per_label_dict[label] = 0

        for i in indices:
            label = int(self.mnist.targets[i])
            if label not in self.labels:
                continue  # skip digits outside the requested classes
            if num_samples_per_label is not None and num_samples_per_label_dict[label] >= num_samples_per_label:
                continue  # optional cap on the number of samples per class
            image = self.mnist.data[i].to(torch.float).unsqueeze(0)  # (28, 28) -> (1, 28, 28)
            samples.append({
                "image": image,
                "label": self.mnist.targets[i]
            })
            num_samples_per_label_dict[label] += 1
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]["image"], self.samples[idx]["label"]
```

```python
labels = (0, 1, 2, 3, 4)

# Dataset instantiations
train_dataset = MNISTDataset(set_name="train", labels=labels)
val_dataset = MNISTDataset(set_name="val", labels=labels)
test_dataset = MNISTDataset(set_name="test", labels=labels)
```

```python
# len(dataset) is a shortcut for dataset.__len__()
print(f"# samples for training: {len(train_dataset)}")
print(f"# samples for validation: {len(val_dataset)}")
print(f"# samples for test: {len(test_dataset)}")
```

```python
# Use matplotlib to display the first training sample

import matplotlib.pyplot as plt

# dataset[i] is a shortcut for dataset.__getitem__(i)
image, label = train_dataset[0]
plt.figure()
plt.axis('off')
# Permutation required to go from PyTorch format (C, H, W) to matplotlib format (H, W, C)
plt.imshow(image.permute(1, 2, 0), cmap="gray")
plt.tight_layout()
plt.show()
print(f"Target class: {label}")
```

## II - Architecture implementation

We use a modified version of the LeNet-5 architecture, which takes as input grayscale images of size (1, 28, 28).

Documentation for [Layers and Losses](https://pytorch.org/docs/stable/nn.html)

```python
from torch import nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.fc1 = nn.Linear(in_features=400, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=5)

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    @property
    def device(self):
        # device (CPU/GPU) on which the model parameters are stored
        return next(self.parameters()).device

    def forward(self, x):
        out = torch.tanh(self.conv1(x))  # (B, 6, 28, 28)
        out = self.max_pool(out)  # (B, 6, 14, 14)
        out = torch.tanh(self.conv2(out))  # (B, 16, 10, 10)
        out = self.max_pool(out)  # (B, 16, 5, 5)
        out = out.reshape(out.size(0), -1)  # flatten the representation (from 2D image to 1D vector): (B, 400)
        out = torch.tanh(self.fc1(out))
        out = torch.tanh(self.fc2(out))
        out = self.fc3(out)  # (B, 5)
        return out
```
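The `in_features=400` of `fc1` follows from the spatial size along the network. The sketch below applies the standard output-size formula for convolution and max-pooling (dilation 1), `out = floor((in + 2 * padding - kernel) / stride) + 1`, to a 28 x 28 input. `conv_out` is a hypothetical helper, not a PyTorch function.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # output size of Conv2d / MaxPool2d along one spatial dimension (dilation=1)
    return (size + 2 * padding - kernel) // stride + 1

s = 28
s = conv_out(s, kernel=3, stride=1, padding=1)  # conv1: 28 -> 28
s = conv_out(s, kernel=2, stride=2)             # max_pool: 28 -> 14
s = conv_out(s, kernel=5, stride=1, padding=0)  # conv2: 14 -> 10
s = conv_out(s, kernel=2, stride=2)             # max_pool: 10 -> 5
flat = 16 * s * s                               # 16 channels of 5 x 5 -> 400 features
print(s, flat)
```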

**Question 2**: How many kernels are used in conv1 and conv2?

**Question 3**: What is the decision layer?

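```python
from torchsummary import summary

# Use the GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiation of the model
net = LeNet().to(device=device)

# summary() runs a dummy input of shape (1, 28, 28) through the model on the given device
summary(net, (1, 28, 28), device=device)
```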
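The "Param #" column printed by `summary` can be recomputed by hand: a `Conv2d` layer stores an `out_channels x in_channels x kernel_size x kernel_size` weight tensor plus one bias per output channel, and a `Linear` layer stores an `out_features x in_features` weight matrix plus one bias per output neuron. The helper names below are hypothetical; the total should match `sum(p.numel() for p in net.parameters())`.

```python
def conv2d_params(in_ch, out_ch, k):
    # weights (out_ch x in_ch x k x k) + biases (out_ch)
    return out_ch * in_ch * k * k + out_ch

def linear_params(in_f, out_f):
    # weights (out_f x in_f) + biases (out_f)
    return out_f * in_f + out_f

counts = {
    "conv1": conv2d_params(1, 6, 3),
    "conv2": conv2d_params(6, 16, 5),
    "fc1": linear_params(400, 1024),
    "fc2": linear_params(1024, 84),
    "fc3": linear_params(84, 5),
}
total = sum(counts.values())
print(counts, total)
```

Max-pooling and tanh do not appear in the count because they have no learnable parameters.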

**Question 4**: What is the meaning of the "-1" value in the output shape?

**Question 5**: Which layers are parametric?

```python
print(list(net.conv1.named_parameters()))
```

**Question 6**: How many tensors of weights are stored for the conv1 layer?

```python
image, label = test_dataset[0]
image = image.to(device)

# inference (forward pass)
net.eval()
with torch.inference_mode():
    output = net(image.unsqueeze(0))  # add a batch dimension: (1, 1, 28, 28)
    print(output, output.size())

    output = torch.softmax(output, dim=1)
    print(output, output.size())
```

**Question 7**: What is the goal of the softmax function?

**Question 8**: What is the meaning of the obtained values?

**Question 9**: What would be the predicted class?
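For reference, a plain-Python sketch of the softmax function (the notebook itself uses `torch.softmax`): it exponentiates each score and divides by the sum, so the outputs are positive and sum to 1 while keeping the ranking of the scores. Subtracting the maximum score first is a common trick to avoid overflow and does not change the result. The scores below are made up for illustration.

```python
import math

def softmax(logits):
    # subtract the max for numerical stability; the result is unchanged
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, -1.0, 0.5, 0.0, 1.0]  # hypothetical raw scores for 5 classes
probs = softmax(logits)
# index of the highest probability (same as the index of the highest score)
predicted = max(range(len(probs)), key=probs.__getitem__)
print([round(p, 3) for p in probs], predicted)
```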

## III - Training

Training alternates between optimizing the model on the training set and evaluating it on the validation set, to monitor how performance evolves on unseen data. We therefore define separate functions for training and evaluation. The main difference between them is that the loss, the gradients and the backpropagation step are computed only at training time.

```python
import numpy as np
from tqdm import tqdm

# Training hyperparameters
num_epochs = 25
batch_size = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

# Loss
loss_fn = torch.nn.CrossEntropyLoss()

# A DataLoader iterates over a dataset in batches; one full pass over the training data is an epoch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
```

```python
# function to compute Top-1 accuracy metric
def compute_top1_acc(prediction, ground_truth):
    # prediction (B, N), ground_truth (B)
    best_prediction = torch.argmax(prediction, dim=1)
    return torch.mean(torch.eq(best_prediction, ground_truth), dtype=torch.float)
```
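A small worked example of top-1 accuracy on a batch of 4 samples and 3 classes, with made-up scores. It applies `.float().mean()` to the boolean comparison, which computes the same fraction of correct predictions as `compute_top1_acc`.

```python
import torch

# hypothetical scores for B=4 samples over N=3 classes
prediction = torch.tensor([[0.1, 2.0, 0.3],
                           [1.5, 0.2, 0.1],
                           [0.0, 0.1, 3.0],
                           [0.9, 0.8, 0.7]])
ground_truth = torch.tensor([1, 0, 1, 2])

best_prediction = torch.argmax(prediction, dim=1)  # predicted classes: [1, 0, 2, 0]
correct = best_prediction == ground_truth          # [True, True, False, False]
top1_acc = correct.float().mean()                  # 2 correct out of 4
print(best_prediction, top1_acc)
```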

```python
# function to perform one (mini-)batch gradient descent step
def train_batch(x, y, net, optimizer, loss_function):
    x, y = x.to(net.device), y.to(net.device)  # move inputs to the device holding the model weights (CPU/GPU)
    optimizer.zero_grad()  # reset the gradients accumulated at the previous step
    output = net(x)  # inference (forward pass)
    loss = loss_function(output, y)  # compute loss
    loss.backward()  # compute gradients (backward pass)
    optimizer.step()  # update the weights using the gradients
    top1_acc = compute_top1_acc(output, y)  # compute metric
    return loss.item(), top1_acc.item()
```
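To see what the last steps of `train_batch` do, the sketch below (assuming a hypothetical `nn.Linear` model and random data) checks that, for plain `torch.optim.SGD` without momentum or weight decay, `optimizer.step()` updates each parameter as `p <- p - lr * p.grad`, using the gradients filled in by `loss.backward()`.

```python
import torch
from torch import nn

torch.manual_seed(0)
lr = 0.01
model = nn.Linear(3, 2)  # hypothetical tiny model
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
x, y = torch.randn(4, 3), torch.tensor([0, 1, 1, 0])

optimizer.zero_grad()
loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()  # fills p.grad for every parameter

# expected result of a plain SGD step, computed by hand before calling step()
expected = [(p - lr * p.grad).detach().clone() for p in model.parameters()]
optimizer.step()

for p, e in zip(model.parameters(), expected):
    print(torch.allclose(p, e))
```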

```python
# function to train over all the training samples through batch gradient descent
def train_epoch(dataloader, net, optimizer, loss_fn):
    epoch_loss = list()
    epoch_top1_acc = list()
    net.train()
    progress_bar = tqdm(dataloader)
    for x, y in progress_bar:
        progress_bar.set_description("Training")
        batch_loss, batch_top1_acc = train_batch(x, y, net, optimizer, loss_fn)
        epoch_loss.append(batch_loss)
        epoch_top1_acc.append(batch_top1_acc)
    current_loss = np.mean(epoch_loss)
    current_top1_acc = 100 * np.mean(epoch_top1_acc)
    return current_loss, current_top1_acc
```

```python
# function to evaluate performance over a batch (forward pass only)
def eval_batch(x, y, net):
    x, y = x.to(net.device), y.to(net.device)
    output = net(x)
    top1_acc = compute_top1_acc(output, y)
    return top1_acc.item()
```

```python
# function to evaluate performance over a whole set (forward pass only)
# note: this name shadows Python's built-in eval()
def eval(dataloader, net):
    top1_acc = list()
    net.eval()
    with torch.inference_mode():  # disable gradient tracking
        for x, y in dataloader:
            batch_top1_acc = eval_batch(x, y, net)
            top1_acc.append(batch_top1_acc)
    return 100 * np.mean(top1_acc)
```
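`eval` averages per-batch accuracies, so every batch has the same weight. Unless the set size is a multiple of `batch_size`, the last batch is smaller and the result can differ slightly from the accuracy over all samples. A sketch of a sample-weighted variant, using the hypothetical name `eval_weighted` and a toy identity model so the numbers can be checked by hand:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def eval_weighted(dataloader, net):
    """Hypothetical variant of eval(): accuracy over all samples, not mean of batch accuracies."""
    device = next(net.parameters()).device
    correct, total = 0, 0
    net.eval()
    with torch.inference_mode():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            correct += (net(x).argmax(dim=1) == y).sum().item()
            total += y.numel()
    return 100 * correct / total


# Toy "model" whose prediction is the argmax of its 2-d input
net = nn.Linear(2, 2)
with torch.no_grad():
    net.weight.copy_(torch.eye(2))
    net.bias.zero_()

x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
y = torch.tensor([0, 1, 1, 0, 0])  # batches of 2: [right, right], [wrong, wrong], [right]
loader = DataLoader(TensorDataset(x, y), batch_size=2, shuffle=False)

# what averaging per-batch accuracies (as eval does) gives
batch_accs = []
with torch.inference_mode():
    for xb, yb in loader:
        batch_accs.append((net(xb).argmax(dim=1) == yb).float().mean().item())
per_batch_mean = 100 * sum(batch_accs) / len(batch_accs)

weighted = eval_weighted(loader, net)  # 3 correct out of 5 samples
print(per_batch_mean, weighted)
```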

```python
val_acc = eval(val_loader, net)
print(f"top-1 accuracy: {val_acc:.2f}%")
```

**Question 10**: Explain the obtained result. Was it expected?

```python
# Weights are constantly updated during training.
# At some point, performance on the validation set may decrease due to over-fitting.
# It is therefore important to regularly evaluate the model on the validation set and to save the best weights.
# If it is computationally affordable, this can be done after each epoch.

metrics = {
    "train_loss": list(),
    "train_accuracy": list(),
    "val_accuracy": list()
}
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(train_loader, net, optimizer, loss_fn)
    metrics["train_loss"].append(train_loss)
    metrics["train_accuracy"].append(train_acc)
    print(f"Train epoch {epoch+1}: loss: {train_loss:.4f} ; top-1 accuracy: {train_acc:.2f}%")

    val_acc = eval(val_loader, net)
    # save the weights whenever the validation accuracy improves
    if epoch == 0 or max(metrics["val_accuracy"]) < val_acc:
        torch.save(net.state_dict(), "best_model_weights.pth")
    metrics["val_accuracy"].append(val_acc)
    print(f"Eval epoch {epoch+1}: top-1 accuracy: {val_acc:.2f}%")
```


**Question 11**: How many back-propagations are performed per epoch? Why?

```python
# Drawing training curves
def plot_curve(title, metric):
    plt.figure()
    plt.title(title)
    plt.plot(np.arange(len(metric)) + 1, metric)  # epochs are numbered from 1
    plt.xlabel("Epoch")
    plt.show()

plot_curve("Training loss", metrics["train_loss"])
plot_curve("Training accuracy", metrics["train_accuracy"])
plot_curve("Validation accuracy", metrics["val_accuracy"])
```

```python
# Retrieve the best weights (map_location loads them onto the current device)
checkpoint = torch.load("best_model_weights.pth", map_location=device)
net.load_state_dict(checkpoint)

# Evaluate on the test set
test_acc = eval(test_loader, net)
print(f"top-1 accuracy: {test_acc:.2f}%")
```
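`torch.save(net.state_dict(), ...)` stores only the parameter tensors, indexed by name, so the architecture must be re-created before `load_state_dict` is called. A minimal round-trip sketch, assuming a small hypothetical `nn.Sequential` model and a temporary file:

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))  # hypothetical small model
path = os.path.join(tempfile.mkdtemp(), "weights.pth")

# state_dict() maps parameter names (e.g. "0.weight") to tensors; no code is saved
torch.save(model.state_dict(), path)

# a fresh instance of the same architecture starts with different random weights...
restored = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
# ...until the saved state_dict is loaded (map_location="cpu" also works for GPU checkpoints)
restored.load_state_dict(torch.load(path, map_location="cpu"))

x = torch.randn(2, 4)
print(list(model.state_dict().keys()))
print(torch.equal(model(x), restored(x)))
```

`Tanh` has no parameters, which is why index `1` is missing from the keys.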

# Your turn

## Exercise on transfer learning and fine-tuning

We will use the weights of the model pre-trained on digits 0 to 4 to initialize a new model that classifies all 10 digits.

- Generate new datasets with all the digits
- Adapt the architecture
- Compare performance when training:
  - from scratch
  - from pre-trained weights (fine-tuning)
  - from pre-trained weights with conv layers frozen (transfer learning)

A layer can be frozen by setting the `requires_grad` attribute of all its parameters to `False`:
`param.requires_grad = False`
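One possible way to set up these experiments, sketched on a hypothetical `TinyNet` (a stand-in for `LeNet` with a single conv layer and a linear head `fc`): copy every pre-trained tensor whose shape still matches, leave the new 10-class head randomly initialized, and freeze the convolutional parameters for transfer learning. This is one approach among others; loading the 5-class weights first and then replacing the last layer with a new `nn.Linear` also works.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for LeNet: one conv layer and a linear decision head."""

    def __init__(self, num_classes):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4 * 28 * 28, num_classes)

    def forward(self, x):
        out = torch.tanh(self.conv1(x))
        return self.fc(out.flatten(1))


pretrained = TinyNet(num_classes=5)  # pretend these weights were trained on digits 0 to 4
model = TinyNet(num_classes=10)

# copy every pre-trained tensor whose shape matches; the 5-way head is skipped
target_state = model.state_dict()
state = {k: v for k, v in pretrained.state_dict().items() if v.shape == target_state[k].shape}
model.load_state_dict(state, strict=False)  # strict=False: fc.weight / fc.bias stay random

# transfer learning: freeze the conv layer (skip this loop for fine-tuning)
for param in model.conv1.parameters():
    param.requires_grad = False

# give the optimizer only the parameters that are still trainable
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.01)
print(model(torch.zeros(2, 1, 28, 28)).shape)
```

For training from scratch, skip both `load_state_dict` and the freezing loop.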

## Exercise on architecture

Improve the architecture to obtain better results:

- Change the number of layers
- Change the kind of layers and the number of channels/neurons per layer
- Change the [Optimizer](https://pytorch.org/docs/stable/optim.html) and/or the learning rate
- Add regularization techniques (normalization, dropout)
- Change the activation functions (e.g. tanh, ReLU)
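As an illustration of several items in the list above, here is a hypothetical convolutional block using more channels, batch normalization, ReLU and dropout, optimized with Adam. Dropout and batch normalization behave differently in `train()` and `eval()` mode, which is why `train_epoch` and `eval` switch modes explicitly. The loss used below is a dummy, only there to show one optimization step.

```python
import torch
from torch import nn

torch.manual_seed(0)
# hypothetical conv block combining several of the suggested changes
block = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1),  # more channels than conv1
    nn.BatchNorm2d(32),                          # normalization
    nn.ReLU(),                                   # different activation function
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Dropout(p=0.25),                          # regularization
)
optimizer = torch.optim.Adam(block.parameters(), lr=1e-3)  # different optimizer

x = torch.randn(8, 1, 28, 28)

block.train()                   # dropout active, BatchNorm uses batch statistics
out_train = block(x)            # (8, 32, 14, 14)
loss = out_train.pow(2).mean()  # dummy loss, for illustration only
optimizer.zero_grad()
loss.backward()
optimizer.step()

block.eval()                    # dropout disabled, BatchNorm uses its running statistics
with torch.inference_mode():
    out_eval_1 = block(x)
    out_eval_2 = block(x)       # identical to out_eval_1: eval mode is deterministic
print(out_train.shape, torch.equal(out_eval_1, out_eval_2))
```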