Я построил Vision Transformer с нуля — и научил его обращать внимание

от автора

Vision Transformer (ViT) — это архитектура, которая буквально произвела революцию в том, как машины «видят» мир.

В этой статье я не просто объясню, что такое ViT — я покажу вам, как создать эту магию своими руками, шаг за шагом, даже если вы никогда раньше не работали с трансформерами для задач с изображениями.

Для начала давайте взглянем на архитектуру Vision Transformer:

Vision Transformer architecture

Vision Transformer architecture

Мы напишем код полностью с нуля, а затем обучим модель на датасете CIFAR-10.

Давайте начнём с реализации Patch Embedding:

class PatchEmbedding(nn.Module):     def __init__(self, img_size = 32, patch_size = 4, in_channels = 3, embed_dim=256):         super().__init__()          assert img_size % patch_size == 0, "img_size must be divisible by patch_size"         self.patch_size = patch_size         self.num_patches = (img_size//patch_size)**2         self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)      def forward(self, x):         x = self.conv(x) #(B, embed_dim, H/patch_size, W/patch_size)         x = x.flatten(2).transpose(1, 2) #(B, num_patches, embed_dim)         return x

Изображение будет разделено на патчи, и размер каждого патча можно задать с помощью параметра patch_size. При этом изображение не просто разбивается на патчи, но и пропускается через свёрточные ядра (CNN). В итоге мы получаем не просто патчи изображения — а встраивания (эмбеддинги) этих патчей.

Следующий шаг — реализовать самую интересную часть этой модели — механизм внимания (attention).

Self-Attention Mechanism

Self-Attention Mechanism

Q (Query) формально задаёт вопрос от каждого патча к другим патчам, K (Key) показывает, есть ли у каждого патча ответ на этот вопрос, а V (Value) содержит «значения» — фактические данные каждого патча, которые используются для формирования итогового представления.

Предположим, у нас есть X и Y, и мы хотим, чтобы X обращал внимание на Y. В этом случае матрица Query умножается на X, а матрицы Key и Value — на Y. Вместо прямого умножения на матрицы мы используем линейные слои.

attn_probs — это матрицы внимания, которые показывают, насколько токен i должен «обращать внимание» на токен j. Далее мы умножаем их на V, чтобы получить эмбеддинги изображения с учётом весов внимания attn_probs. V фактически хранит значения каждого патча изображения, а attn_probs показывает, сколько информации каждый патч должен получить от остальных патчей.

Вот как работает одна голова внимания; затем значения с всех голов объединяются. Такая конструкция основана на идее, что каждая голова фокусируется на разных аспектах.

class MultiHeadAttention(nn.Module):     def __init__(self, dim, num_heads, dropout):         super().__init__()         assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"         self.num_heads = num_heads          self.head_dim = dim // num_heads          self.qkv = nn.Linear(dim, dim * 3, bias = False)         self.out = nn.Linear(dim, dim, bias = False)          self.scale = 1.0 / (self.head_dim ** 0.5)          self.attn_dropout = nn.Dropout(dropout)      def forward(self,  x, mask = None,  return_attn=False):         B, num_patches, embed_dim = x.shape          qkv = self.qkv(x) # (B, num_patches, 3*embed_dim)         qkv = qkv.reshape(B, num_patches, 3, self.num_heads, self.head_dim)         qkv = qkv.permute(2, 0, 3, 1, 4) #(3, B, num_heads, num_patches, head_dim)          q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, num_heads, num_patches, head_dim)                                          #How important it is for token i to pay attention to token j.         attn_scores = (q @ k.transpose(-2, -1)) * self.scale #[B, num_heads, N, N]          if mask is not None:             # mask: (B, 1, N, N) or (1, 1, N, N)             attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))          attn_probs = attn_scores.softmax(dim=-1) #[B, num_heads, N, N]         attn_probs = self.attn_dropout(attn_probs)         attn_output = attn_probs @ v  # (B, num_heads, num_patches, head_dim)         attn_output = attn_output.transpose(1, 2).reshape(B, num_patches, embed_dim)          if return_attn:           return self.out(attn_output), attn_probs         else:           return self.out(attn_output) #(B, num_patches, embed_dim)

Давайте перейдём к сборке блока Transformer Encoder:

class TransformerEncoderBlock(nn.Module):     def __init__(self, dim, num_heads, mlp_dim, dropout):         super().__init__()         self.norm1 = nn.LayerNorm(dim)         self.attn = MultiHeadAttention(dim, num_heads, dropout)         self.norm2 = nn.LayerNorm(dim)          self.mlp = nn.Sequential(             nn.Linear(dim, mlp_dim),             nn.GELU(),             nn.Dropout(dropout),             nn.Linear(mlp_dim, dim),             nn.Dropout(dropout)         )         self.dropout = nn.Dropout(dropout)      def forward(self, x, return_attn=False):       if return_attn:         attn_out, attn_weights = self.attn(self.norm1(x), return_attn=True)         x = x + self.dropout(attn_out)         x = x + self.dropout(self.mlp(self.norm2(x)))         return x, attn_weights       else:         x = x + self.dropout(self.attn(self.norm1(x)))         x = x + self.dropout(self.mlp(self.norm2(x)))         return x

Здесь мы просто следуем архитектуре нашей сети — все необходимые блоки мы уже реализовали.

Мы уже почти на финишной прямой — теперь соберём сам Vision Transformer:

class VisualTransformer(nn.Module):     def __init__(self,num_classes, img_size=32, patch_size=4, in_channels=3, embed_dim=256,                  num_layers=6, num_heads=7, mlp_dim=512, dropout=0.1):         super().__init__()          self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)         self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))         self.pos_embed = nn.Parameter(torch.randn(1, 1 + self.patch_embed.num_patches, embed_dim))         self.dropout = nn.Dropout(dropout)          self.encoder_blocks = nn.ModuleList([                 TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout)                 for _ in range(num_layers)             ])          self.norm = nn.LayerNorm(embed_dim)          self.mlp_head = nn.Sequential(             nn.LayerNorm(embed_dim),             nn.Linear(in_features=embed_dim, out_features=num_classes)         )      def forward(self, x, return_attn = False):         B = x.size(0)         x = self.patch_embed(x)  # (B, N, D)          cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)         x = torch.cat((cls_tokens, x), dim=1)  # (B, 1+N, D)         x = x + self.pos_embed         x = self.dropout(x)          attn_maps = []         for block in self.encoder_blocks:           if return_attn:             x, attn = block(x, return_attn=True)             attn_maps.append(attn)  # (B, heads, N, N)           else:             x = block(x) # (B, 1+N, D)          x = self.norm(x)          out = self.mlp_head(x[:, 0, :])          if return_attn:           return out, attn_maps         else:           return out

Здесь нужно уточнить несколько моментов. Что такое cls_token? Это специальный токен, который мы добавляем вручную, и он имеет тот же размер, что и патчи изображения. Его задача — использоваться позже для классификации изображения. Идея в том, что, проходя через блоки внимания, этот токен собирает информацию обо всём изображении.

Далее посмотрим на pos_embed. Поскольку мы делим изображение на патчи и выстраиваем их в последовательность — как будто работаем с текстом — модель изначально не понимает пространственные взаимосвязи между патчами. Чтобы это исправить, мы добавляем позиционную информацию к патчам. В нашем случае pos_embed — это обучаемый параметр.

Что касается mlp_head, здесь всё просто: он берёт cls_token, пропускает его через линейный слой и классифицирует изображение.

После сборки нашей модели давайте перейдём к обучению.

Для обучения мы будем использовать следующие гиперпараметры:

BATCH_SIZE = 128 EPOCHS = 80 LEARNING_RATE = 3e-4 PATCH_SIZE = 4 NUM_CLASSES = 10 IMAGE_SIZE = 32 CHANNELS = 3 EMBED_DIM = 256 NUM_HEADS = 8 DEPTH = 6 MLP_DIM = 512 DROP_RATE = 0.1

Давайте посмотрим на количество параметров:

Total parameters: 3,189,514 Trained parameters: 3,189,514

А также следующие аугментации:

train_transforms = transforms.Compose([     transforms.Resize((70, 70)),     transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),     transforms.RandomHorizontalFlip(p=0.5),     transforms.RandomRotation(15),     transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),     transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),     transforms.ToTensor(),     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),     transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'), ])

Мы обучаем модель и получаем следующие результаты:

 Графики процесса обучения

Графики процесса обучения

Давайте посмотрим на предсказания модели:

Предсказания модели

Предсказания модели

Вот метрики получившейся модели:

Метрики модели

Метрики модели
 Матрица ошибок (Confusion Matrix)

Матрица ошибок (Confusion Matrix)

А теперь к самой интересной части — вниманию. Давайте посмотрим, на что наша модель обращает внимание во время классификации:

 Карта внимания (Attention Map)

Карта внимания (Attention Map)
 Карта внимания (Attention Map)

Карта внимания (Attention Map)

В этой статье мы подробно рассмотрели реализацию Vision Transformer и его механизма внимания. Мы изучили, на что способна эта модель и как она «смотрит» на изображение с помощью механизма внимания. Vision Transformer открыл новые направления в исследовании компьютерного зрения, объединив идеи из NLP и обработки изображений. В будущем мы обязательно применим эту модель для задачи генерации подписей к изображениям (Image Captioning).

Полный код и процесс обучения вы можете найти на моём Kaggle:

https://www.kaggle.com/code/nickr0ot/visual-transformer-from-scratch


ссылка на оригинал статьи https://habr.com/ru/articles/925050/