{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# TP2: Image-to-sequence (Handwritten Text Recognition)\n", "\n", "## Goals\n", "- Implement a transformer-based model for an image-to-sequence task\n", "- Visualize attention maps\n", "\n", "\n", "We will use a modified version of the MNIST dataset with multiple digits within the same image. The goal is to train a model to recognize all digits in the correct order (from left to right)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Config used: torch==2.4.1\n", "\n", "# Download the data\n", "!wget -nc https://people.irisa.fr/Denis.Coquenet/courses/content/M2-DLV/TP2/mnist_variable_len_1k.data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Make the experiment reproducible\n", "import torch\n", "import numpy as np\n", "\n", "def set_deterministic():\n", " torch.manual_seed(0)\n", " torch.cuda.manual_seed(0)\n", " np.random.seed(0)\n", " torch.backends.cudnn.benchmark = False\n", " torch.backends.cudnn.deterministic = True\n", "\n", "set_deterministic()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Dataset\n", "\n", "This part defines a custom dataset of images of variable widths (each containing between 1 and 5 concatenated MNIST digits). \n", "\n", "There is nothing to code here. \n", "\n", "TODO: read and understand how the dataset is handled. 
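
To make the label handling concrete, here is a minimal sketch in plain PyTorch (independent of the dataset file) of what `preformat_label`, `pad_labels` and the training loop do to target sequences: each digit sequence is framed by a start token and an end token, shorter sequences are right-padded, and the decoder input and target are obtained by shifting the padded batch by one step. The special-token indices `END`, `START` and `PAD` are hypothetical values chosen only for this illustration.

```python
import torch

# Hypothetical special-token indices, chosen for illustration only
END, START, PAD = 10, 11, 12

def frame(digits):
    # Prepend the start token and append the end token (what preformat_label does)
    return torch.tensor([START] + digits + [END], dtype=torch.long)

def pad_batch(labels, pad_value=PAD):
    # Right-pad variable-length label tensors to the longest one (what pad_labels does)
    max_len = max(label.size(0) for label in labels)
    batch = torch.full((len(labels), max_len), pad_value, dtype=torch.long)
    for i, label in enumerate(labels):
        batch[i, :label.size(0)] = label
    return batch

y = pad_batch([frame([4, 2]), frame([7, 0, 3])])
decoder_input = y[:, :-1]   # tokens fed to the decoder (teacher forcing)
decoder_target = y[:, 1:]   # tokens it must predict, shifted by one step
print(y.tolist())
```

The padded position at the end of the first target row is the one that `CrossEntropyLoss(ignore_index=MNISTDataset.pad_label_value)` skips during training.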
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import Dataset\n", "\n", "class MNISTDataset(Dataset):\n", " \"\"\"\n", " Custom dataset class to manipulate a specific version of the MNIST dataset in which each sample is the horizontal concatenation of single digits from the original MNIST dataset.\n", " \"\"\"\n", "\n", " pad_label_value = None\n", "\n", " def __init__(self, filepath, set_name):\n", " self.set_name = set_name\n", " data = torch.load(filepath)\n", " self.classes = data[\"classes\"] # associate each class to an int\n", " self.samples = data[f\"{set_name}_samples\"] # the input images (containing between 1 and 5 digits)\n", " self.targets = data[f\"{set_name}_targets\"] # the expected output (numpy array of class indices)\n", " \n", " self.token_set = [c[0] for c in self.classes] + [\"\", \"\", \"

\"] # digits + begin / end / padding special tokens\n", " if MNISTDataset.pad_label_value is None:\n", " MNISTDataset.pad_label_value = self.token_set.index(\"

\")\n", "\n", " self.samples = [torch.tensor(s, dtype=torch.float).unsqueeze(0) for s in self.samples] # preprocess the inputs once and for all\n", " self.targets = [self.preformat_label(t) for t in self.targets] # preprocess the targets once and for all\n", "\n", " def __len__(self):\n", " \"\"\"\n", " Compute the number of samples in the dataset\n", " \"\"\"\n", " return len(self.samples)\n", "\n", " def __getitem__(self, idx):\n", " \"\"\"\n", " Return the sample at index idx (input and ground truth) as a dict\n", " \"\"\"\n", " return {\n", " \"input_img\": self.samples[idx], \n", " \"ground_truth\": self.targets[idx]\n", " }\n", " \n", " def preformat_label(self, label):\n", " \"\"\"\n", " Format a label for the image-to-sequence task.\n", " It casts the label from numpy to torch, adds a start-of-sequence token before\n", " the label tokens and an end-of-sequence token after them.\n", " \"\"\"\n", " new_label = torch.ones((len(label)+2), dtype=torch.long)\n", " new_label[0] = self.token_set.index(\"\")\n", " new_label[1:-1] = torch.tensor(label)\n", " new_label[-1] = self.token_set.index(\"\")\n", " return new_label\n", " \n", " def decode_tokens(self, tokens):\n", " \"\"\"\n", " tokens: iterable of int\n", " Return the list of token strings of the token sequence (indices above 10 are discarded)\n", " \"\"\"\n", " return [self.token_set[i] for i in tokens if i <= 10]\n", "\n", " @staticmethod\n", " def pad_images(img_list, padding_value):\n", " \"\"\"\n", " Function that puts some torch images together to process them as a mini-batch.\n", " Narrower images are right-padded to the maximum width.\n", " A torch image size is as follows (C, H, W)\n", " H: height, W: width, C: number of channels (1 for grayscale, 3 for RGB)\n", " \"\"\"\n", " num_imgs = len(img_list)\n", " channels = img_list[0].size(0)\n", " height = img_list[0].size(1)\n", " max_width = max([img.size(2) for img in img_list])\n", " batch_imgs = torch.full(size=(num_imgs, channels, 
height, max_width), \n", " fill_value=padding_value,\n", " dtype=img_list[0].dtype)\n", " for i, 
img in enumerate(img_list):\n", " batch_imgs[i, :, :, :img.size(2)] = img # copy the image on the left, the right part stays padded\n", " return batch_imgs\n", " \n", " @classmethod\n", " def pad_labels(cls, label_list):\n", " \"\"\"\n", " Function that puts some labels together to process them as a mini-batch.\n", " Shorter labels are right-padded with the padding token.\n", " label_list: list of 1D tensors of token indices\n", " \"\"\"\n", " num_labels = len(label_list)\n", " max_len = max([label.size(0) for label in label_list])\n", " batch_labels = torch.full(size=(num_labels, max_len),\n", " fill_value=cls.pad_label_value,\n", " dtype=label_list[0].dtype)\n", " for i, label in enumerate(label_list):\n", " batch_labels[i, :label.size(0)] = label \n", " return batch_labels\n", " \n", " @staticmethod\n", " def batch_samples(batch_data):\n", " \"\"\"\n", " Function that puts some samples together to process them as a mini-batch.\n", " batch_data: list of samples\n", " \"\"\"\n", " input_imgs = [data[\"input_img\"] for data in batch_data]\n", " ground_truths = [data[\"ground_truth\"] for data in batch_data]\n", " return {\n", " \"imgs\": MNISTDataset.pad_images(input_imgs, padding_value=0),\n", " \"labels\": MNISTDataset.pad_labels(ground_truths),\n", " \"original_widths\": [data[\"input_img\"].size(2) for data in batch_data],\n", " \"label_lengths\": [data[\"ground_truth\"].size(0) for data in batch_data],\n", " }" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "train_dataset = MNISTDataset(set_name=\"train\", filepath=\"./mnist_variable_len_1k.data\")\n", "train_dataloader = DataLoader(dataset=train_dataset, \n", " batch_size=512, \n", " shuffle=True,\n", " collate_fn=MNISTDataset.batch_samples)\n", "\n", "\n", "# Let's check the first mini-batch\n", "for batch_data in train_dataloader:\n", " x = batch_data[\"imgs\"]\n", " y = batch_data[\"labels\"]\n", " 
print(x.size(), y.size())\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": 
[], "source": [ "# Let's display some samples\n", "import matplotlib.pyplot as plt\n", "\n", "fig, axes = plt.subplots(nrows=5, ncols=2, figsize=(10, 10))\n", "for i in range(5):\n", " for j in range(2):\n", " sample_index = 5*j+i\n", " axes[i][j].imshow(x[sample_index, 0], cmap=\"gray\")\n", " axes[i][j].set_title(f\"Target: {y[sample_index].numpy()}\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Architecture\n", "\n", "We will follow the architecture of the Document Attention Network, i.e., we will rely on an FCN encoder and a transformer decoder to handle this image-to-sequence task. \n", "\n", "We will follow a step-by-step strategy to design the network, successively focusing on the different components: \n", "- The fully convolutional network encoder \n", "- The transformer decoder \n", "- The positional encoding \n", "- Finally, merging all of these together \n", "\n", "
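
As a rough illustration of the data flow on the encoder side, the sketch below traces tensor shapes through a stack of four convolutions with kernel size 2 and stride 2 (the layer configuration of `FCN_Encoder`, with its normalization layers omitted), followed by the flatten/permute step that turns the 2D feature map into a sequence of image tokens for the transformer. The batch of three 3-digit images and the 128 output channels are arbitrary choices for the example.

```python
import torch
from torch import nn

# Same convolution configuration as FCN_Encoder (kernel 2, stride 2); norms omitted for brevity
embed_dim = 128  # assumed channel width, equal to the I2SModel default
convs = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=2, stride=2), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=2, stride=2), nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=2, stride=2), nn.ReLU(),
    nn.Conv2d(128, embed_dim, kernel_size=2, stride=2), nn.ReLU(),
)

x = torch.rand(3, 1, 32, 3 * 32)   # batch of 3 images holding 3 digits each
features = convs(x)                # (B, C, Hf, Wf)
B, C, Hf, Wf = features.shape
# Flatten the feature grid row by row, then put the sequence axis first: (Li, B, C)
tokens = features.flatten(start_dim=2).permute(2, 0, 1)
print(features.shape, tokens.shape)
```

Each of the Li = Hf * Wf image tokens corresponds to one cell of the feature grid, which is why the cross-attention weights (B, Lt, Li) can later be reshaped back to (B, Lt, Hf, Wf) for visualization.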
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Encoder\n", "\n", "We rely on a rather simple encoder made of a few convolutional layers. It extracts 2D features from the input image. \n", "\n", "**TODO**: based on the convolutional layer configurations, compute the downsampling factor applied by the encoder, i.e., the factor by which the height and the width of the input image are divided (replace the ??? in the following cell).\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from torch.nn import Module, InstanceNorm2d, Conv2d, ReLU\n", "\n", "class FCN_Encoder(Module):\n", " \"\"\"\n", " This encoder takes as input images of size (32, N*32), N being the number of digits in the image,\n", " and outputs feature maps of size (32/f, N*32/f), f being the downsampling factor\n", " \"\"\"\n", "\n", " downsampling_factor = ???\n", "\n", " def __init__(self, out_dim):\n", " super().__init__()\n", " start_dim=32\n", " self.conv1 = Conv2d(in_channels=1, out_channels=start_dim, kernel_size=2, stride=2, padding=0) \n", " self.conv2 = Conv2d(in_channels=start_dim, out_channels=2*start_dim, kernel_size=2, stride=2, padding=0) \n", " self.conv3 = Conv2d(in_channels=2*start_dim, out_channels=4*start_dim, kernel_size=2, stride=2, padding=0) \n", " self.conv4 = Conv2d(in_channels=4*start_dim, out_channels=out_dim, kernel_size=2, stride=2, padding=0) \n", "\n", " self.norm1 = InstanceNorm2d(2*start_dim)\n", " self.norm2 = InstanceNorm2d(4*start_dim)\n", " self.activation = ReLU()\n", "\n", "\n", " def forward(self, x):\n", " x = self.activation(self.conv1(x))\n", " x = self.activation(self.conv2(x))\n", " x = self.norm1(x)\n", " x = self.activation(self.conv3(x))\n", " x = self.norm2(x)\n", " x = self.activation(self.conv4(x))\n", " return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Transformer Decoder\n", "\n", "The goal now is to implement a transformer decoder layer (we will only use one layer in 
the network). 
Its role is to iteratively apply the attention mechanisms so as to focus on a specific part of the input image at each decoding step. \n", "\n", "To do so, the transformer decoder relies on two kinds of attention: \n", "- the self-attention, which is applied between a set of tokens and itself (queries, keys and values come from the same sequence, here the text tokens). \n", "- the cross-attention, which is applied between two different sets of tokens (queries come from the text tokens, while keys and values come from the image feature tokens). \n", "\n", "When computing these attentions, some tokens must be masked so that padded positions are never attended to, whether they belong to the text sequence (target) or to the image features (memory). \n", "\n", "Also, with teacher forcing the whole target sequence is fed at once, so \"the future\" must be masked: the self-attention mask must be causal. \n", "\n", "TODO: complete/implement these mask functions.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# L: length of the token sequence, B: mini-batch size, H: height of the image features, W: width of the image features\n", "\n", "class MaskHelper:\n", "\n", " @staticmethod\n", " def compute_tgt_padding_mask(target_text_sequence, target_lengths):\n", " \"\"\"\n", " Compute the padding mask of the target text sequence\n", " target_text_sequence: tensor of size (L, B, C)\n", " target_lengths: list of int\n", " True = padded location, False = real text token location\n", " Return tensor of size (B, L)\n", " \"\"\"\n", " L, B, C = target_text_sequence.size()\n", " mask = torch.ones((B, L), dtype=torch.bool)\n", " # TODO\n", " return mask\n", " \n", " @staticmethod\n", " def compute_image_padding_mask(batch_img_features, original_widths, downsampling_factor):\n", " \"\"\"\n", " Compute the padding mask of the image features\n", " batch_img_features: tensor of size (B, C, H, W)\n", " original_widths: list of int\n", " 
downsampling_factor: int\n", " Must return a boolean tensor of shape (B, H, 
W)\n", " True = padded location, False = real image location\n", " \"\"\"\n", " # TODO\n", " pass\n", " \n", " \n", " @staticmethod\n", " def compute_tgt_attn_mask(target_text_sequence):\n", " \"\"\"\n", " target_text_sequence: tensor of size (L, B, C)\n", " Return a mask tensor of size (L, L) indicating which text token can attend to which text tokens in self-attention\n", " mask[i, j] specifies if token i can attend to token j\n", " False = can attend, True = cannot attend\n", " Hint: you can use torch.triu function\n", " \"\"\"\n", " # TODO\n", " pass\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check the mask functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# visualizing tgt_attn_mask\n", "\n", "L = 5\n", "B = 4\n", "C = 256\n", "H = 32\n", "W = 160\n", "Hf, Wf = H // FCN_Encoder.downsampling_factor, W // FCN_Encoder.downsampling_factor\n", "\n", "target = torch.zeros((L, B, C)) \n", "target_len = torch.tensor([5, 2, 2, 3]) # (B, )\n", "images = torch.zeros((B, C, Hf, Wf))\n", "original_widths = torch.tensor([160, 64, 64, 96]) # (B, )\n", "\n", "tgt_attn_mask = MaskHelper.compute_tgt_attn_mask(target) # (L, L)\n", "\n", "print(tgt_attn_mask)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# visualizing tgt_pad_mask\n", "tgt_pad_mask = MaskHelper.compute_tgt_padding_mask(target, target_len) # (B, L)\n", "print(tgt_pad_mask)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# visualizing memory_pad_mask\n", "memory_pad_mask = MaskHelper.compute_image_padding_mask(images, original_widths, FCN_Encoder.downsampling_factor) # (B, Hf, Wf)\n", "torch.set_printoptions(profile=\"full\")\n", "print(memory_pad_mask)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Transformer decoder layer\n", "\n", "We now focus on the decoder layer itself. 
We rely on PyTorch's multi-head attention implementation. \n", "\n", "In the following cell, the layers are already initialized with appropriate parameters. \n", "\n", "**TODO**: based on the Document Attention Network figure (gray box), implement the forward function. \n", "\n", "Documentation for the layers used can be found here: [MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), [Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear), [LayerNorm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html#torch.nn.LayerNorm), [ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from torch.nn import MultiheadAttention, Linear, LayerNorm, ReLU\n", "\n", "class TransformerDecoder(Module):\n", "\n", " def __init__(self, embed_dim, num_heads=2, dim_feedforward=128):\n", " super().__init__()\n", " self.cross_attention = MultiheadAttention(embed_dim, \n", " num_heads=num_heads)\n", " \n", " self.norm1 = LayerNorm(embed_dim)\n", " self.self_attention = MultiheadAttention(embed_dim, \n", " num_heads=num_heads)\n", " self.norm2 = LayerNorm(embed_dim)\n", " \n", " # feedforward is defined as an MLP\n", " self.linear1 = Linear(embed_dim, dim_feedforward)\n", " self.activation = ReLU()\n", " self.linear2 = Linear(dim_feedforward, embed_dim)\n", "\n", " self.norm3 = LayerNorm(embed_dim)\n", "\n", " def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,\n", " tgt_key_padding_mask=None, memory_key_padding_mask=None):\n", " \"\"\"\n", " tgt: target text sequence (Lt, B, C)\n", " memory: flattened image features (Li, B, C)\n", " tgt_mask: mask for the target (which item can attend to which one, useful for teacher forcing) - bool tensor (Lt, Lt)\n", " memory_mask: attention mask of the cross-attention (Lt, Li) - not needed here, since each text token may attend to every 
non-padded image position\n", " 
tgt_key_padding_mask: mask to discard padded positions in the target - bool tensor (B, Lt)\n", " memory_key_padding_mask: mask to discard padded positions in the image features - bool tensor (B, Li)\n", "\n", " Output: tuple including\n", " - the output sequence (Lt, B, C)\n", " - the attention weights of the cross-attention layer (B, Lt, Li)\n", " \"\"\"\n", " # TODO\n", " pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Positional encoding\n", "\n", "The multi-head attention is permutation-equivariant, so positional information must be injected into the input sets. \n", "Each position is encoded as a positional embedding vector such that $\\text{PE}(p, k)$ is the $k^\\text{th}$ component of the positional embedding vector for position $p$.\n", "\n", "The original positional encoding (from the Transformer paper) is defined as follows:\n", "\\begin{equation*}\n", "\\begin{split}\n", " & \\mathrm{PE}(p, 2k) = \\sin(w_k \\cdot p)\\\\\n", " & \\mathrm{PE}(p, 2k+1) = \\cos(w_k \\cdot p)\\\\\n", "\\end{split}\n", "\\end{equation*} \n", "\n", "$\\forall k \\in \\{0, \\ldots, d/2 - 1\\}$, with $w_k = 1/10000^{2k/d}$, \n", "\n", "where $d$ is the dimension of the token embedding.\n", "\n", "**TODO**: complete the forward function of the following Module."
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from torch.nn import Parameter\n", "\n", "class PositionalEncoding(Module):\n", "\n", " def __init__(self, dim):\n", " super().__init__()\n", " self.dim = dim\n", "\n", " # We use exponential properties for numerical stability\n", " # w_k = e^(-alog(b)) = e^(log(1/b^a)) = 1/(b^a)\n", " # with b = 10,000 and a = 2k/d\n", " a = torch.arange(0., dim, 2) / dim\n", " b = torch.tensor(10000.0)\n", " self.wk = Parameter(torch.exp(-a * torch.log(b)).unsqueeze(0), requires_grad=False) # (1, dim/2)\n", "\n", " def forward(self, indices):\n", " \"\"\"\n", " Input: indices (B, L) - e.g.: [0, 1, 2, 3, ..., L-1]\n", " Output: positional embedding (B, L, C)\n", " \"\"\"\n", " emb_indices = torch.zeros((indices.size(0), indices.size(1), self.dim), device=indices.device) # (B, L, C)\n", " emb_indices[:, :, ::2] = # TODO\n", " emb_indices[:, :, 1::2] = # TODO\n", " return emb_indices" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "index_sequence = torch.arange(0,500).unsqueeze(0) # (1, 500)\n", "C = 1000\n", "\n", "positional_encoding = PositionalEncoding(C) \n", "positional_emb = positional_encoding(index_sequence)[0] # (500, 1000)\n", "\n", "plt.figure()\n", "plt.imshow(positional_emb)\n", "plt.colorbar()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Now, let's put all of these components together!\n", "\n", "The following cell implements the end-to-end image-to-sequence model. It takes as input the image, and outputs the predicted text sequence. \n", "\n", "TODO: implement the *decode* function." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "from torch.nn import Embedding\n", "\n", "class I2SModel(Module):\n", "\n", " \"\"\"\n", " The whole image-to-sequence model.\n", " \"\"\"\n", "\n", " def __init__(self, token_set, embed_dim=128, max_preds=6):\n", " super().__init__()\n", " self.max_preds = max_preds # set a maximum number of predictions to avoid infinite loops\n", " self.token_set = token_set # list of all tokens (digits + special tokens , and

)\n", "\n", " self.encoder = FCN_Encoder(embed_dim) # to extract image features\n", " self.decoder = TransformerDecoder(embed_dim=embed_dim) # to iteratively focus on a subpart of the features\n", " self.positional_encoding = PositionalEncoding(embed_dim) # to inject positional information in the attention process\n", " self.emb_layer = Embedding(num_embeddings=len(token_set), embedding_dim=embed_dim) # to associate a vector embedding to each text token\n", " self.decision_layer = Linear(in_features=embed_dim, out_features=len(token_set)-2) # to make token predictions (-2: no need to predict and

)\n", "\n", "\n", " def forward(self, x, original_widths, tgt=None, tgt_lengths=None):\n", " \"\"\"\n", " x: a mini-batch input tensor of images of variable widths (B, C, H, W)\n", " original_widths: a list of int (useful for masking)\n", " tgt: a target mini-batch output tensor of token indices (for training only = teacher forcing) (B, L)\n", " tgt_lengths: a list of int (target sequences can be of variable lengths, for training only)\n", " \"\"\"\n", "\n", " # Encoder part (performed only once)\n", " features = self.encoder(x) # extracting features from image (B, dim, H/16, W/16)\n", " flattened_features = torch.flatten(features, start_dim=2, end_dim=3).permute(2, 0, 1) # flatten for transformer requirements (Li, B, dim) with Li=(H/16)*(W/16)\n", " Li, B, C = flattened_features.size()\n", " img_position = self.generate_positional_embedding(B, Li) # generate positional embedding\n", " flattened_features = flattened_features + img_position # adding positional embedding\n", " img_pad_mask = MaskHelper.compute_image_padding_mask(features, \n", " original_widths, \n", " downsampling_factor=self.encoder.downsampling_factor) # generate mask for features\n", " img_pad_mask = torch.flatten(img_pad_mask, start_dim=1, end_dim=2) # flatten mask for transformer requirements (B, Li)\n", " \n", " # Decoder part (iterative)\n", " # The iterative prediction process can be parallelized at training time through teacher forcing, but not at evaluation time\n", " if tgt is None:\n", " # eval mode \n", " B = x.size(0)\n", " tgt = torch.full((B, 1), dtype=torch.long, fill_value=self.token_set.index(\"\")) # start with the token to initiate the decoding stage\n", " tgt_lengths = [1 for _ in range(B)] # keep track for prediction lengths\n", " for i in range(1, self.max_preds+1): # repeat the decoding step x times\n", " output, attn_weights = self.decode(tgt, tgt_lengths, flattened_features, img_pad_mask) # perform one decoding step\n", " new_pred = torch.argmax(output[:, -1], dim=-1, 
keepdim=True) # get the most probable token\n", " tgt = torch.cat([tgt, new_pred], dim=-1) # add this most probable token to the target/query sequence\n", " for j in range(B): # check for each sample in the mini-batch\n", " if tgt_lengths[j] == i and new_pred[j].item() != self.token_set.index(\"\"): # increase prediction length only if sequence is not already ended (a token has been predicted)\n", " tgt_lengths[j] += 1\n", " \n", " # Reshape attention weights to get back in 2 dimensions\n", " B, _, H, W = x.size()\n", " _, Lt, _ = attn_weights.size() # (B, Lt, Li)\n", " attn_weights = attn_weights.reshape(B, Lt, H//self.encoder.downsampling_factor, W//self.encoder.downsampling_factor) # (B, Lt, H/16, W/16)\n", "\n", " return tgt, tgt_lengths, attn_weights\n", " else: \n", " # Training mode\n", " # All predictions are performed at once based on the ground truth, i.e., we use the ground truth as if it was what we predicted so far\n", " return self.decode(tgt, tgt_lengths, flattened_features, img_pad_mask)\n", " \n", " def generate_positional_embedding(self, batch_size, length):\n", " \"\"\"\n", " Compute a positional embedding tensor for position starting from 0 to length-1\n", " The same positions are used for each sample in the mini-batch.\n", " \"\"\"\n", " positions = torch.arange(0, length) # [0,1,2,..., length-1]\n", " batch_positions = torch.repeat_interleave(positions.unsqueeze(1), repeats=batch_size, dim=1) # (length, batch_size)\n", " batched_position_emb = self.positional_encoding(batch_positions) # (length, batch_size, dim)\n", " return batched_position_emb\n", " \n", "\n", " def decode(self, tgt, tgt_lengths, features, img_pad_mask):\n", " \"\"\"\n", " tgt: tensor of token indices (B, Lt)\n", " tgt_lengths: list of int\n", " features: flattened image features extracted from the encoder (Li, B, C)\n", " img_pad_mask: bool mask tensor for the flattened features (B, Li)\n", "\n", " Returns a tuple:\n", " - the predictions of size (B, Lt, N)\n", " - the 
attention weights of size (B, Lt, Li)\n", " \"\"\"\n", "\n", " # Token indices to token embeddings\n", " # TODO\n", "\n", " # Adding positional encoding to token embedding\n", " # TODO\n", "\n", " # Compute tgt masks\n", " # TODO\n", "\n", " # Apply transformer decoder layer\n", " # TODO\n", "\n", " # Prediction\n", " # TODO\n", "\n", " return output, attn_weights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Training\n", "\n", "We now have a dataset and a network. The last stage is to train the network on that dataset! \n", "\n", "TODO: implement the *accuracy* function to compute the top-1 accuracy metric, carefully ignoring padding tokens. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.optim import Adam\n", "from torch.nn import CrossEntropyLoss\n", "from tqdm import tqdm\n", "import numpy as np\n", "\n", "def accuracy(x, y, padding_value):\n", " \"\"\"\n", " Compute the top-1 accuracy between prediction scores x (B, L, N) and ground truth token indices y (B, L)\n", " B: mini-batch size; L: target sequence length; N: number of output classes\n", " Output: float\n", " \"\"\"\n", " # TODO\n", " pass\n", "\n", "def train(model, dataloader, num_epochs=40):\n", " \"\"\"\n", " Training loop\n", " \"\"\"\n", " optimizer=Adam(model.parameters(), lr=0.001)\n", " loss_fn = CrossEntropyLoss(ignore_index=MNISTDataset.pad_label_value)\n", " model.train()\n", " for i in range(num_epochs):\n", " losses = []\n", " accuracies = []\n", " progress_bar = tqdm(dataloader)\n", " for batch_data in progress_bar:\n", " # remove previous gradients (from previous mini-batch)\n", " optimizer.zero_grad()\n", "\n", " # get data of current mini-batch\n", " x = batch_data[\"imgs\"]\n", " y = batch_data[\"labels\"]\n", " widths = batch_data[\"original_widths\"]\n", " lens = batch_data[\"label_lengths\"]\n", "\n", " # forward pass\n", " output, _ = model(x, original_widths=widths, tgt=y[:, :-1], tgt_lengths=lens)\n", "\n", " # loss 
computation\n", 
" loss = loss_fn(output.permute(0, 2, 1), y[:, 1:])\n", "\n", " # backward pass\n", " loss.backward()\n", "\n", " # weight update\n", " optimizer.step()\n", " \n", " # display loss and metric\n", " losses.append(loss.item())\n", " accuracies.append(accuracy(output, y[:, 1:], MNISTDataset.pad_label_value))\n", " progress_bar.set_description(f\"EPOCH {i} - loss: {np.mean(losses):.4f} ; acc : {np.mean(accuracies)*100:.2f}\")\n", "\n", "model = I2SModel(train_dataset.token_set)\n", "train(model, train_dataloader)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Prediction\n", "\n", "The network should reach nearly perfect results on the training set. Let's make a prediction." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def display_image(x):\n", " plt.subplots()\n", " plt.imshow(x[0], cmap=\"gray\")\n", " plt.show()\n", "\n", "def predict(model, x):\n", " model.eval()\n", " x = x.unsqueeze(0)\n", " widths = [x.size(3)]\n", " output, tgt_lengths, attn_weights = model(x, widths)\n", " output = output[0][:tgt_lengths[0]+1]\n", " token_pred = train_dataset.decode_tokens(output)\n", " str_pred = \"\".join([p for p in token_pred if p != \"\"])\n", " return token_pred, str_pred, attn_weights[0]\n", "\n", "test_dataset = MNISTDataset(set_name=\"test\", filepath=\"./mnist_variable_len_1k.data\")\n", "data = test_dataset[3]\n", "x = data[\"input_img\"]\n", "display_image(x)\n", "token_pred, str_pred, attn_weights = predict(model, x)\n", "print(f\"Prediction: {str_pred}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Visualization\n", "\n", "Now, let's analyze the attention weights, iteration by iteration. \n", "\n", "Try it on a few examples. What kinds of errors does the model make?"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import cv2\n", "\n", "def show_iterations(img, weights, preds):\n", " \"\"\"\n", " img: Tensor of size (1, H, W)\n", " weights: Tensor of size (L, H, W)\n", " preds: sequence of L predicted tokens\n", " \"\"\" \n", " img = img.detach().numpy()[0]\n", " weights = [cv2.resize(w.detach().numpy(), (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) for w in weights]\n", " preds = list(preds)\n", " num_preds = len(preds)\n", " fig, axes = plt.subplots(nrows=num_preds, figsize=(5*num_preds,10))\n", "\n", " if not isinstance(axes, np.ndarray):\n", " axes = [axes]\n", "\n", " for i in range(num_preds):\n", " axes[i].imshow(img, alpha=0.5)\n", " attn_map = axes[i].imshow(weights[i], cmap='viridis', alpha=0.6)\n", " axes[i].set_title(f\"Prediction #{i+1}: '{preds[i]}'\")\n", " plt.colorbar(attn_map)\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "show_iterations(x, attn_weights, token_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# To go further: attention implementation\n", "\n", "Reminder: the multi-head attention is defined as follows:\n", "\n", "$$Q_1 = XW^Q_1 \\quad \\ldots \\quad Q_h = XW^Q_h$$\n", "$$K_1 = YW^K_1 \\quad \\ldots \\quad K_h = YW^K_h$$\n", "$$V_1 = YW^V_1 \\quad \\ldots \\quad V_h = YW^V_h$$\n", "\n", "$$O_1 = \\text{softmax} \\left(\\frac{Q_1K_1^T}{\\sqrt{d}}\\right) V_1 \\quad \\ldots \\quad O_h = \\text{softmax}\\left(\\frac{Q_hK_h^T}{\\sqrt{d}}\\right) V_h$$\n", "\n", "$$Z = \\text{concat}(O_1, \\ldots, O_h)W^O$$\n", "\n", "where $X$ is the query-side input sequence and $Y$ the key/value-side input sequence ($X=Y$ for self-attention), and $d$ is the number of dimensions of one head.\n", "\n", "Note that, for efficiency, the query (resp. key, value) projections of all heads can be computed at once with a single dense layer of output size $h \\cdot d$.\n", "\n", "TODO: complete the following class which implements this attention mechanism."
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "from torch.nn import Linear, Module\n", "\n", "\n", "class TransformerAttention(Module):\n", "\n", " def __init__(self, dim, num_heads):\n", " \"\"\"\n", " dim: total number of dimensions C\n", " num_heads: number of attention heads h\n", " Each head must have the same number of dimensions d\n", " The number of dimensions is preserved from input to output\n", " \"\"\"\n", " super().__init__()\n", " assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n", " self.num_heads = num_heads\n", " self.head_dim = dim // num_heads\n", " self.scale_factor = self.head_dim ** -0.5\n", "\n", " # head projections can be merged in a single Linear to parallelize computations\n", " self.q_proj = Linear(dim, dim, bias=False)\n", " self.k_proj = Linear(dim, dim, bias=False)\n", " self.v_proj = Linear(dim, dim, bias=False)\n", "\n", " self.out_proj = Linear(dim, dim, bias=False) # W^O\n", "\n", " def forward(self, query_seq, key_seq, value_seq, key_padding_mask=None, attn_mask=None):\n", " \"\"\"\n", " Input: query_seq (B, L_t, C)\n", " key_seq (B, L_s, C)\n", " value_seq (B, L_s, C)\n", " key_padding_mask (B, L_s): True = padded source token, to be ignored\n", " attn_mask (L_t, L_s): attn_mask[i, j] = True if target token i cannot attend to source token j (e.g., future tokens with teacher forcing)\n", " Output: tuple (attn_weights, x), x being of size (B, L_t, C)\n", " \"\"\"\n", " B, target_len, C = query_seq.size()\n", " source_len = key_seq.size(1)\n", "\n", " #1 Projection must lead to Q (B, L_t, C), K (B, L_s, C), V (B, L_s, C)\n", " q = # TODO\n", " k = # TODO\n", " v = # TODO\n", "\n", " #2 We treat all heads in parallel by folding them into the batch dimension: it must lead to Q (B*h, L_t, d), K (B*h, L_s, d), V (B*h, L_s, d)\n", " #Be careful with the data order: do not mix representations of different tokens together\n", " # Hint: you will need permute and reshape functions\n", " # 
TODO\n", "\n", " #3 Compute scores S=QK^T/sqrt(d), 
leading to S (B*h, L_t, L_s)\n", " # S[i, j, k] = how well token j matches with token k with respect to sample/head i\n", " # TODO\n", "\n", " #4 We need to take into account padded sequences\n", " # Scores of padding tokens must be set to -inf so that the softmax gives them zero weight\n", " # Hint: use masked_fill function\n", " if key_padding_mask is not None:\n", " pass \n", " # TODO\n", "\n", " #5 We need to take into account the causality of the decoder part\n", " #The model must not attend to tokens that have not been predicted yet\n", " if attn_mask is not None:\n", " pass\n", " # TODO\n", "\n", " #6 Apply softmax to get attention weights alpha (B*h, L_t, L_s)\n", " # TODO\n", "\n", " #7 Compute alpha·V (B*h, L_t, d)\n", " # TODO\n", "\n", " #8 Merge the heads back into (B, L_t, C), then apply the output projection W^O\n", " # TODO\n", " \n", " return attn_weights, x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check that it behaves the same as the PyTorch implementation" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "from torch.nn import MultiheadAttention\n", "\n", "B = 2\n", "L = 10\n", "C = 64\n", "H = 2\n", "\n", "# Initialize modules\n", "custom_attn = TransformerAttention(dim=C, num_heads=H)\n", "torch_attn = MultiheadAttention(embed_dim=C, num_heads=H, bias=False, add_bias_kv=False, batch_first=True)\n", "\n", "# Ensure they have the same weights\n", "torch_state_dict = torch_attn.state_dict()\n", "custom_state_dict = {\n", " \"q_proj.weight\": torch_state_dict[\"in_proj_weight\"][:C],\n", " \"k_proj.weight\": torch_state_dict[\"in_proj_weight\"][C:2*C],\n", " \"v_proj.weight\": torch_state_dict[\"in_proj_weight\"][2*C:],\n", " \"out_proj.weight\": torch_state_dict[\"out_proj.weight\"]\n", "}\n", "custom_attn.load_state_dict(custom_state_dict)\n", "\n", "# Generate a fake example\n", "fake_sample = torch.randn((B, L, C), dtype=torch.float)\n", "padding_mask = torch.zeros((B, L), 
dtype=torch.bool)\n", 
"padding_mask[0, -2:] = True\n", "mask = torch.zeros((L, L), dtype=torch.bool)\n", "mask[0, 7] = mask[1, 1] = True\n", "\n", "# forward pass\n", "out_torch, weights_torch = torch_attn(fake_sample, fake_sample, fake_sample, key_padding_mask=padding_mask, attn_mask=mask)\n", "weights_custom, out_custom = custom_attn(fake_sample, fake_sample, fake_sample, key_padding_mask=padding_mask, attn_mask=mask)\n", "\n", "# check equal\n", "assert torch.allclose(out_torch, out_custom, atol=1e-6)" ] } ], "metadata": { "kernelspec": { "display_name": "course_DLV", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 }