{ "cells": [ { "cell_type": "markdown", "id": "c6fb63ad96f4741b", "metadata": { "collapsed": false, "id": "c6fb63ad96f4741b" }, "source": [ "# TP1: PyTorch Basics\n", "\n", "## Goals\n", "- Understand the implementation of a classification task\n", "    - Data formatting and manipulation\n", "    - Architecture implementation\n", "    - Training/evaluation\n", "    - Loading/saving models\n", "\n", "- Adapt an architecture to a new dataset (transfer learning / fine-tuning)\n", "\n", "\n", "In the first part, the implementation code is given for training a LeNet-like architecture on a classification task using the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) (only digits 0 to 4 are considered).\n", "In the second part, you have to adapt this implementation to two use cases: transfer learning and fine-tuning. The goal is to adapt to a new set of classes: from digits 0 to 4 to digits 0 to 9." ] }, { "cell_type": "markdown", "id": "15c6f7a3b17034a7", "metadata": { "collapsed": false, "id": "15c6f7a3b17034a7" }, "source": [ "## I - Formatting the dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "c7b16500a4c84b8a", "metadata": { "id": "c7b16500a4c84b8a" }, "outputs": [], "source": [ "import torch\n", "from torchvision.datasets import MNIST\n", "from torch.utils.data import DataLoader, Dataset" ] }, { "cell_type": "markdown", "id": "a97837b2b8e5d5a", "metadata": { "collapsed": false, "id": "a97837b2b8e5d5a" }, "source": [ "Documentation on [Datasets and DataLoader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)  \n", "First, we create a custom Dataset class to handle the MNIST dataset. 
\n", "A Dataset object must implement two methods:  \n", "```__len__()```: which gives the number of samples in the dataset  \n", "```__getitem__(i)```: which returns the necessary information about the i-th sample (generally the input and the expected output)\n", "\n", "The original training set (60,000 images) is split into a new training set (50,000 images) and a validation set (10,000 images).\n", "\n", "**Question 1**: Why do we add a validation split, in addition to the test set?\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b5c0427ce3485002", "metadata": { "id": "b5c0427ce3485002" }, "outputs": [], "source": [ "class MNISTDataset(Dataset):\n", "    def __init__(self, set_name, labels=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), num_samples_per_label=None):\n", "        self.set_name = set_name\n", "        self.labels = labels\n", "        self.mnist = MNIST(root=\"./cache\", train=set_name in (\"train\", \"val\"), download=True)\n", "        self.samples = self.format_samples(num_samples_per_label)\n", "\n", "    def format_samples(self, num_samples_per_label=None):\n", "        samples = list()\n", "        match self.set_name:\n", "            case \"train\":\n", "                indices = list(range(5000, 55000))\n", "            case \"val\":\n", "                indices = list(range(0, 5000)) + list(range(55000, 60000))\n", "            case \"test\":\n", "                indices = list(range(0, 10000))\n", "            case _:\n", "                raise ValueError(f\"Unknown set name: {self.set_name}\")\n", "\n", "        num_samples_per_label_dict = dict()\n", "        for label in self.labels:\n", "            num_samples_per_label_dict[label] = 0\n", "\n", "        for i in indices:\n", "            label = int(self.mnist.targets[i])\n", "            if label not in self.labels:\n", "                continue\n", "            if num_samples_per_label is not None and num_samples_per_label_dict[label] >= num_samples_per_label:\n", "                continue\n", "            image = self.mnist.data[i].to(torch.float).unsqueeze(0)\n", "            samples.append({\n", "                \"image\": image,\n", "                \"label\": self.mnist.targets[i]\n", "            })\n", "            num_samples_per_label_dict[label] += 1\n", "        return samples\n", "\n", "    def __len__(self):\n", "        return len(self.samples)\n", "\n", "    def __getitem__(self, idx):\n", "        
return self.samples[idx][\"image\"], self.samples[idx][\"label\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "82443ce0eac82a4e", "metadata": { "id": "82443ce0eac82a4e" }, "outputs": [], "source": [ "labels = (0, 1, 2, 3, 4)\n", "\n", "# Dataset instantiations\n", "train_dataset = MNISTDataset(set_name=\"train\", labels=labels)\n", "val_dataset = MNISTDataset(set_name=\"val\", labels=labels)\n", "test_dataset = MNISTDataset(set_name=\"test\", labels=labels)" ] }, { "cell_type": "code", "execution_count": null, "id": "410a8b3734fee5b5", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "410a8b3734fee5b5", "outputId": "fa37dc77-61a1-40d2-ebcf-97240e4e5d35" }, "outputs": [], "source": [ "# len(dataset) is a shortcut for dataset.__len__()\n", "print(f\"# samples for training: {len(train_dataset)}\")\n", "print(f\"# samples for validation: {len(val_dataset)}\")\n", "print(f\"# samples for test: {len(test_dataset)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "77612143075a6224", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 504 }, "id": "77612143075a6224", "outputId": "916bfa72-19fb-479e-9a3c-271bac36d1c9" }, "outputs": [], "source": [ "# One can use the matplotlib package to show the first training sample\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "# dataset[i] is a shortcut for dataset.__getitem__(i)\n", "image, label = train_dataset[0]\n", "plt.figure()\n", "plt.axis('off')\n", "# Permutation required to go from PyTorch format (C, H, W) to matplotlib format (H, W, C)\n", "plt.imshow(image.permute(1, 2, 0), cmap=\"gray\")\n", "plt.tight_layout()\n", "plt.show()\n", "print(f\"Target class: {label}\")" ] }, { "cell_type": "markdown", "id": "9ee35f67e3fca870", "metadata": { "collapsed": false, "id": "9ee35f67e3fca870" }, "source": [ "## II - Architecture implementation\n", "\n", "We will use a modified version of the LeNet-5 architecture, which takes as input grayscale images of size 
(1, 28, 28).  \n", "Documentation for [Layers and Losses](https://pytorch.org/docs/stable/nn.html)" ] }, { "cell_type": "code", "execution_count": null, "id": "2c379364452ee2ef", "metadata": { "id": "2c379364452ee2ef" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "class LeNet(nn.Module):\n", "    def __init__(self):\n", "        super(LeNet, self).__init__()\n", "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, stride=1, padding=1)\n", "        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)\n", "        self.fc1 = nn.Linear(in_features=400, out_features=1024)\n", "        self.fc2 = nn.Linear(in_features=1024, out_features=84)\n", "        self.fc3 = nn.Linear(in_features=84, out_features=5)\n", "\n", "        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)\n", "\n", "    @property\n", "    def device(self):\n", "        return next(self.parameters()).device\n", "\n", "    def forward(self, x):\n", "        out = torch.tanh(self.conv1(x))\n", "        out = self.max_pool(out)\n", "        out = torch.tanh(self.conv2(out))\n", "        out = self.max_pool(out)\n", "        out = out.reshape(out.size(0), -1)  # flatten the representation (from 2D image to 1D vector)\n", "        out = torch.tanh(self.fc1(out))\n", "        out = torch.tanh(self.fc2(out))\n", "        out = self.fc3(out)\n", "        return out" ] }, { "cell_type": "markdown", "id": "c3269f47075c9ab1", "metadata": { "collapsed": false, "id": "c3269f47075c9ab1" }, "source": [ "**Question 2**: How many kernels are used in conv1 and conv2?\n", "\n", "**Question 3**: What is the decision layer?" 
] }, { "cell_type": "code", "execution_count": null, "id": "de2ffbb7ae1d968a", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "de2ffbb7ae1d968a", "outputId": "fec445c1-70ea-470e-b0f6-a36b7e51bac8" }, "outputs": [], "source": [ "# Check if GPU available\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "# Instantiation of the model\n", "net = LeNet().to(device=device)\n", "\n", "from torchsummary import summary\n", "summary(net, (1, 28, 28))" ] }, { "cell_type": "markdown", "id": "229715ffa0e0ee13", "metadata": { "collapsed": false, "id": "229715ffa0e0ee13" }, "source": [ "**Question 4**: What is the meaning of the \"-1\" value in the output shape?\n", "\n", "**Question 5**: Which layers are parametric?" ] }, { "cell_type": "code", "execution_count": null, "id": "bbf9b346b264aa61", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bbf9b346b264aa61", "outputId": "e2ae8dd0-9ece-4bbf-9401-4c7f4ccc4d55" }, "outputs": [], "source": [ "print(list(net.conv1.named_parameters()))" ] }, { "cell_type": "markdown", "id": "d8ee95d8df3a756f", "metadata": { "collapsed": false, "id": "d8ee95d8df3a756f" }, "source": [ "**Question 6**: How many tensors of weights are stored for the conv1 layer?" 
] }, { "cell_type": "code", "execution_count": null, "id": "d7fe143a4260442", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "d7fe143a4260442", "outputId": "43a11093-d97e-43f8-f9b9-4ed9cf55bddc" }, "outputs": [], "source": [ "image, label = test_dataset[0]\n", "image = image.to(device)\n", "\n", "# inference (forward pass)\n", "net.eval()\n", "with torch.inference_mode():\n", "    output = net(image.unsqueeze(0))\n", "    print(output, output.size())\n", "\n", "    output = torch.softmax(output, dim=1)\n", "    print(output, output.size())" ] }, { "cell_type": "markdown", "id": "168cc2392cbcbba2", "metadata": { "collapsed": false, "id": "168cc2392cbcbba2" }, "source": [ "**Question 7**: What is the goal of the softmax function?\n", "\n", "**Question 8**: What is the meaning of the obtained values?\n", "\n", "**Question 9**: What would be the predicted class?" ] }, { "cell_type": "markdown", "id": "747076a30e87d564", "metadata": { "collapsed": false, "id": "747076a30e87d564" }, "source": [ "## III - Training\n", "\n", "Training consists of alternating between optimizing the model on the training set and evaluating it on the validation set, in order to track how performance evolves on unseen data. One must define appropriate functions for training and evaluation. The main difference lies in the computation of the loss and of the gradients (backpropagation) and in the weight update, which are only performed at training time." ] }, { "cell_type": "code", "execution_count": null, "id": "17f0e584b965ba7d", "metadata": { "id": "17f0e584b965ba7d" }, "outputs": [], "source": [ "import numpy as np\n", "from tqdm import tqdm\n", "\n", "# Training hyperparameters\n", "num_epochs = 25\n", "batch_size = 1000\n", "learning_rate = 0.01\n", "optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)\n", "\n", "# Loss\n", "loss_fn = torch.nn.CrossEntropyLoss()\n", "\n", "# A dataloader is an iterator over the dataset. 
It is useful to perform an epoch.\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n", "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "dc6ce40acf07e239", "metadata": { "id": "dc6ce40acf07e239" }, "outputs": [], "source": [ "# function to compute Top-1 accuracy metric\n", "def compute_top1_acc(prediction, ground_truth):\n", "    # prediction (B, N), ground_truth (B)\n", "    best_prediction = torch.argmax(prediction, dim=1)\n", "    return torch.mean(torch.eq(best_prediction, ground_truth), dtype=torch.float)" ] }, { "cell_type": "code", "execution_count": null, "id": "5debe32bd5b5df6c", "metadata": { "id": "5debe32bd5b5df6c" }, "outputs": [], "source": [ "# function to perform one gradient descent step on a batch\n", "def train_batch(x, y, net, optimizer, loss_function):\n", "    x, y = x.to(net.device), y.to(net.device)  # put model weights and inputs on the same device (CPU/GPU)\n", "    optimizer.zero_grad()  # zero the gradient buffers\n", "    output = net(x)  # inference (forward-pass)\n", "    loss = loss_function(output, y)  # compute loss\n", "    loss.backward()  # compute gradients (backward-pass)\n", "    optimizer.step()  # update the weights using the gradients (optimizer step)\n", "    top1_acc = compute_top1_acc(output, y)  # compute metric\n", "    return loss.item(), top1_acc.item()" ] }, { "cell_type": "code", "execution_count": null, "id": "af667fe421895cbd", "metadata": { "id": "af667fe421895cbd" }, "outputs": [], "source": [ "# function to train over all the training samples through mini-batch gradient descent\n", "def train_epoch(dataloader, net, optimizer, loss_fn):\n", "    epoch_loss = list()\n", "    epoch_top1_acc = list()\n", "    net.train()\n", "    progress_bar = tqdm(dataloader)\n", "    for x, y in progress_bar:\n", "        progress_bar.set_description(\"Training\")\n", "        batch_loss, batch_top1_acc = train_batch(x, y, net, 
optimizer, loss_fn)\n", "        epoch_loss.append(batch_loss)\n", "        epoch_top1_acc.append(batch_top1_acc)\n", "    current_loss = np.mean(epoch_loss)\n", "    current_top1_acc = 100 * np.mean(epoch_top1_acc)\n", "    return current_loss, current_top1_acc" ] }, { "cell_type": "code", "execution_count": null, "id": "98cb4ed04cb2e55d", "metadata": { "id": "98cb4ed04cb2e55d" }, "outputs": [], "source": [ "# function to evaluate performance over a batch (forward pass only)\n", "def eval_batch(x, y, net):\n", "    x, y = x.to(net.device), y.to(net.device)\n", "    output = net(x)\n", "    top1_acc = compute_top1_acc(output, y)\n", "    return top1_acc.item()" ] }, { "cell_type": "code", "execution_count": null, "id": "3d0cf9b951e5f578", "metadata": { "id": "3d0cf9b951e5f578" }, "outputs": [], "source": [ "# function to evaluate performance over a whole set (forward pass only)\n", "def eval(dataloader, net):\n", "    top1_acc = list()\n", "    net.eval()\n", "    with torch.inference_mode():  # prevent tracking gradient-related operation\n", "        for x, y in dataloader:\n", "            batch_top1_acc = eval_batch(x, y, net)\n", "            top1_acc.append(batch_top1_acc)\n", "    return 100 * np.mean(top1_acc)" ] }, { "cell_type": "code", "execution_count": null, "id": "9265b4ffb0d44b92", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9265b4ffb0d44b92", "outputId": "d5a1acc9-2040-4ee1-e87e-68ba08e3a00a" }, "outputs": [], "source": [ "val_acc = eval(val_loader, net)\n", "print(f\"top-1 accuracy: {val_acc:.2f}%\")" ] }, { "cell_type": "markdown", "id": "b1c976db3fa2ece7", "metadata": { "collapsed": false, "id": "b1c976db3fa2ece7" }, "source": [ "**Question 10**: Explain the obtained result. Was it expected?" 
] }, { "cell_type": "code", "execution_count": null, "id": "61459923e64fe852", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "61459923e64fe852", "outputId": "212544cb-3cc4-4c03-f442-52398053cf42" }, "outputs": [], "source": [ "# Weights are constantly updated through the training process\n", "# At some point, performance on the validation set may decrease due to over-fitting\n", "# That is why it is important to regularly evaluate the model on the validation set and to save the associated weights\n", "# If it is computationally affordable, this can be done between each epoch\n", "\n", "metrics = {\n", "    \"train_loss\": list(),\n", "    \"train_accuracy\": list(),\n", "    \"val_accuracy\": list()\n", "}\n", "for epoch in range(num_epochs):\n", "    train_loss, train_acc = train_epoch(train_loader, net, optimizer, loss_fn)\n", "    metrics[\"train_loss\"].append(train_loss)\n", "    metrics[\"train_accuracy\"].append(train_acc)\n", "    print(f\"Train epoch {epoch+1}: loss: {train_loss:.4f} ; top-1 accuracy: {train_acc:.2f}%\")\n", "\n", "    val_acc = eval(val_loader, net)\n", "    if epoch == 0 or max(metrics[\"val_accuracy\"]) < val_acc:\n", "        torch.save(net.state_dict(), \"best_model_weights.pth\")\n", "    metrics[\"val_accuracy\"].append(val_acc)\n", "    print(f\"Eval epoch {epoch+1}: top-1 accuracy: {val_acc:.2f}%\")\n" ] }, { "cell_type": "markdown", "id": "8debb01bf3dc6d12", "metadata": { "collapsed": false, "id": "8debb01bf3dc6d12" }, "source": [ "**Question 11**: How many back-propagations are performed per epoch? Why?" 
] }, { "cell_type": "code", "execution_count": null, "id": "29dd277f02742cfd", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "29dd277f02742cfd", "outputId": "2d31c28f-d6c6-4456-ebf3-b72cc173dee9" }, "outputs": [], "source": [ "# Drawing training curves\n", "def plot_curve(title, metric):\n", "    plt.figure()\n", "    plt.title(title)\n", "    plt.plot(np.arange(len(metric))+1, metric)\n", "    plt.show()\n", "\n", "plot_curve(\"Training loss\", metrics[\"train_loss\"])\n", "plot_curve(\"Training accuracy\", metrics[\"train_accuracy\"])\n", "plot_curve(\"Validation accuracy\", metrics[\"val_accuracy\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "89c2e9a94fdc2437", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "89c2e9a94fdc2437", "outputId": "a7028cc2-fe78-4235-a751-eaee2c0d3fed" }, "outputs": [], "source": [ "# Retrieve the best weights\n", "checkpoint = torch.load(\"best_model_weights.pth\")\n", "net.load_state_dict(checkpoint)\n", "\n", "# Evaluate on the test set\n", "test_acc = eval(test_loader, net)\n", "print(f\"top-1 accuracy: {test_acc:.2f}%\")" ] }, { "cell_type": "markdown", "id": "4f4fdf497f0923e", "metadata": { "collapsed": false, "id": "4f4fdf497f0923e" }, "source": [ "# Your turn\n", "\n", "## Exercise on transfer learning and fine-tuning\n", "\n", "We will use the model weights pre-trained on digits 0 to 4 to initialize a new model that performs classification over all 10 digits.\n", "\n", "- Generate new datasets with all the digits\n", "- Adapt the architecture\n", "- Compare performance when training\n", "    - from scratch\n", "    - from pre-trained weights (fine-tuning)\n", "    - from pre-trained weights with conv layers frozen (transfer learning)\n", "\n", "One can freeze a layer by setting the \"requires_grad\" attribute of all its parameters to False:\n", "```param.requires_grad = False```" ] }, { "cell_type": "markdown", "id": "7c51046925f48bd0", "metadata": { 
"collapsed": false, "id": "7c51046925f48bd0" }, "source": [ "## Exercise on architecture\n", "\n", "Improve the architecture to have better results\n", "\n", "- Change the number of layers\n", "- Change the kind of layers, the number of channels/neurons per layer\n", "- Change the [Optimizer](https://pytorch.org/docs/stable/optim.html) / the learning rate\n", "- Add regularization techniques (normalization, dropout)\n", "- Change the activation functions (tanh, relu)" ] }, { "cell_type": "markdown", "id": "79c7ab4ab35d3cf", "metadata": { "collapsed": false, "id": "79c7ab4ab35d3cf" }, "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }