Обучение трансформера на синтетическом датасете

от автора

Библиотек будет использовано минимум.

import random import math  import numpy as np import pandas as pd import matplotlib.pyplot as plt from tqdm.notebook import tqdm  import torch import torch.utils.data as tdutils from torch import nn, optim 

Определение dataset’а

Обучение любой нейронки начинается именно с этого. Для целей поиграться датасет можно взять синтетический. Например, нагенерить из арифметических выражений. Если в выражении «2+2=4» случайным образом заменить один символ, получится достаточно простая, но нетривиальная задача по коррекции ошибок. Чтобы задать датасет в пригодном для использовании в pytorch виде, нужно создать класс наследник IterableDataset и переопределить метод __iter__

OPS = '+-*/%' DIGITS = '0123456789' CHARS = ' ' + DIGITS + OPS + '=' OPS_METHODS = {     '+': lambda v1, v2: v1 + v2,     '-': lambda v1, v2: v1 - v2,     '*': lambda v1, v2: v1 * v2,     '/': lambda v1, v2: 0 if v2 == 0 else v1 // v2,     '%': lambda v1, v2: v1 % v2  }  class SampleSet(tdutils.IterableDataset):     def __init__(self, val_min=0, val_max=99):         self.val_min, self.val_max = val_min, val_max         assert val_min > 0         max_res = val_max * val_max         self.str_size = len(f'{val_max}*{val_max}={max_res}')              def __iter__(self):         while True:             yield self.make_sample()                  def to_tensor(self, str_value):         res = torch.zeros([self.str_size], dtype=torch.uint8)         converted = torch.tensor([             CHARS.index(char) for char in str_value         ])         res[0:len(converted)] = converted         return res                           def make_sample(self):         val1 = random.randint(self.val_min, self.val_max)         val2 = random.randint(self.val_min, self.val_max)         op = OPS[random.randint(0, len(OPS) - 1)]         res = OPS_METHODS[op](val1, val2)          original = f'{val1}{op}{val2}={res}'         lst = list(original)         lst[random.randint(0, len(original) - 1)] = CHARS[random.randint(0, len(CHARS)-1)]         replaced = ''.join(lst)                                  return {             'task': self.to_tensor(replaced),             'answer': self.to_tensor(original)         }               # код для проверки _sample = SampleSet(1, 99).make_sample() _sample # output {   'task': tensor([ 4, 10, 11,  6,  6, 16,  2,  2,  5,  0], dtype=torch.uint8),   'answer': tensor([ 6, 10, 11,  6,  6, 16,  2,  2,  5,  0], dtype=torch.uint8) }

Полученный датасет уже можно скармливать в DataLoader

_dataset = tdutils.DataLoader(     dataset=SampleSet(1, 9),     batch_size=8 ) _dataset_iter = iter(_dataset) _batch = next(_dataset_iter) _batch  # output {'task': tensor([[ 7,  2,  7, 16,  2,  0],          [ 8, 13,  2, 16,  8,  0],          [ 4, 12,  4, 16, 15,  0],          [ 7, 13,  7, 18,  4,  7],          [ 2, 15,  2, 16,  2,  0],          [ 2, 15,  9, 16,  2,  0],          [ 6, 13,  3,  5,  2,  1],          [10, 13,  3, 14,  2,  9]], dtype=torch.uint8),  'answer': tensor([[ 7, 14,  7, 16,  2,  0],          [ 8, 13,  2, 16,  8,  0],          [ 4, 12,  4, 16,  1,  0],          [ 7, 13,  7, 16,  4,  7],          [ 2, 15,  4, 16,  2,  0],          [ 2, 15,  5, 16,  2,  0],          [ 6, 13,  3, 16,  2,  1],          [10, 13,  3, 16,  2,  9]], dtype=torch.uint8)}

А еще для работы с датасетом не помешают функции отображения

def tensor_to_str(tensor):         res = ''.join([         CHARS[val] for val in tensor     ])     return res.strip(' ')           def show_sample(dct):     task = tensor_to_str(dct['task'])     answer = tensor_to_str(dct['answer'])     return f'{task}->{answer}'  show_sample(_sample) # output '39+55=114->59+55=114'

Embeddings

Вся магия трансформера начинается с перевода входной последовательности в векторное представление. Сделать это можно встроенным в pytorch модулем nn.Embedding. Он поддерживает внутри себя словарик (тензор размером d_chars * d_models)

class Embed(nn.Module):     def __init__(self, d_chars, d_model):         super().__init__()                         self.embedding = nn.Embedding(d_chars, d_model)              def forward(self, batch):         return self.embedding(batch['task'].long())         _embed = Embed(len(CHARS), 32)(_batch) _embed.shape # output torch.Size([8, 6, 32])

Размерность пространства d_model это гиперпараметр, который нужно будет подбирать вне процедуры градиентного спуска. Размерность данных не меняется на протяжении всего пути их прохождения через модель и на выходе понадобится обратное преобразование. Его впринципе можно сделать на основе того же словарика, но проще просто выучить отдельным полносвязным слоем

class DecodeEmbed(nn.Module):     def __init__(self, d_chars, d_model):         super().__init__()         self.decode = nn.Linear(d_model, d_chars)         self.softmax = nn.Softmax(dim=-1)              def forward(self, embed):                 return torch.argmax(self.softmax(self.decode(embed)), dim=-1)      DecodeEmbed(len(CHARS), 32)(_embed).shape # output torch.Size([8, 6])

Attention

Внимание это механизм, позволяющий сетке направить пристальный взгляд на какой-то из входных элемент обрабатываемой последовательности. Eсли более формально, то каждому элементу в последовательности приписывается некоторый ключ K и значение V. Дальше векторным запросом Q можно запросить нужную информацию. Существует много способов как конкретно это сделать. В трансформе за основу взят DotProduct Attention. В нем в качестве весов, с которыми нужно складывать значение, берется softmax произведения Q и K

\mathrm{softmax}(Q \cdot K^T) \cdot V

Как на пальцах это работает. Путь информация о том, что именно лежит в символах последовательности закодирована в двухмерном пространстве.

_k = torch.tensor([     [1, 1], [-1, 1], [0.01, 0.02] ]).float() _q = torch.tensor([     [-1, 1], [1, 1], [0, 1] ]).float() * 10  _v = torch.tensor([     [0, 1, 2, 3],     [4, 5, 6, 7],     [8, 9, 10, 11] ]).float()

Вектора K первого и второго символа ортогональны. В третьем лежит какой-то шум. Вектор _q имеет похожую структуру

torch.matmul(_q, _k.T) # output tensor([[ 0.0000, 20.0000,  0.1000],         [20.0000,  0.0000,  0.3000],         [10.0000, 10.0000,  0.2000]])

После матричного умножения в видно, что в первый символ нужно записать содержание второго, а во второй первый. Запрос в третьей позиции лежит посередине между первым и вторым символом. В матрице он представлен равными весами для первого и второго. Третий символ (последний столбец) не вносит заметно вклада в итоговый результат. Картину еще более сглаживает применение softmax’а

_atw = nn.Softmax(dim=-1)(torch.matmul(_q, _k.T)) (_atw * 100).numpy().astype(int) #output  array([[  0, 100,   0],        [100,   0,   0],        [ 49,  49,   0]])

Результат предсказуем

_att_res = torch.matmul(_atw, _v) _att_res.cpu().numpy().astype(int) # output array([[4, 5, 6, 7],        [0, 1, 2, 3],        [2, 3, 4, 5]])

В ответе первые и вторые строки переставлены, а в последней строчке лежит их среднее

Но это еще не все. В трансформере используется multihead attention. Оно состоит из dot product головок пониженной размерности. Кол-во головок это еще один гиперпараметр. Выводы отдельных голов конкатенируются и выучиваемым преобразованием трансформируются в исходное пространство. А чтобы сетке было проще обучаться в каждой из головок добавленно деление на корень из размерности. Просто, чтобы дисперсия на выходе была такая же, как и на входе.

class Attention(nn.Module):     def forward(self, q, k, v):         sel = torch.matmul(q, k.transpose(-1, -2))         weights = nn.Softmax(dim=-1)(sel / math.sqrt(k.shape[-1]))         return torch.matmul(weights, v)                  class ProjectedAttention(nn.Module):     def __init__(self, d_qk1, d_qk2, d_v1, d_v2):         super().__init__()         self.keys = nn.Linear(d_qk1, d_qk2)         self.queries = nn.Linear(d_qk1, d_qk2)         self.values = nn.Linear(d_v1, d_v2)         self.att = Attention()              def forward(self, q, k, v):         return self.att(             self.queries(q),             self.keys(k),             self.values(v)         )                   class MultiHeadAttention(nn.Module):     def __init__(self, d_model, h):         super().__init__()         self.heads = nn.ModuleList([             ProjectedAttention(                 d_model, d_model // h,                  d_model, d_model // h             )              for _ in range(h)         ])                    self.final = nn.Linear(d_model, d_model)              def forward(self, q, k, v):         head_res = torch.cat([             head(q, k, v)             for head in self.heads         ], dim=-1)         return self.final(head_res)          _model = MultiHeadAttention(d_model=32, h=2) _multi = _model(_embed, _embed, _embed) _multi.shape # output torch.Size([8, 6, 32])

Positional Encoding

Трансформер не имеет никакой другой связи с информацией, записанной в соседних токенах, кроме механизма внимания и, чтобы снабдить его возможностью запрашивать содержимое соседних ячеек, к эмбендингам добавляется позиционное кодирование. Как это делается проще понять в комплексной нотации: вещественное пространство четной размерности v_len можно представить как комплексное размерностью в двое меньше. Для позиции pos в компоненте l прибавляется вектор

e^{\left( \mathrm{pos} \cdot M^{l} \cdot \sqrt{-1} \right)}

где M — некоторое маленькое число. При переходе к вещественным числам там возникает sin и cos. При таком кодировании обозначение позиций со смещением в ту или другую сторону могут быть получены из текущего линейным преобразованием.

e^{\left( (\mathrm{pos}+ k) \cdot M^{l} \cdot \sqrt{-1} \right)}=e^{\left( \mathrm{pos} \cdot M^{l} \cdot \sqrt{-1} \right)}\cdot e^{\left( k \cdot M^{l} \cdot \sqrt{-1} \right)}

M = 1/10000  def pos_tensor(seq_len, v_len):     power = 2 * torch.arange(v_len // 2).float() / v_len     arg = torch.outer(         torch.arange(seq_len),         M ** power              )     res = torch.cat([torch.sin(arg), torch.cos(arg)], dim=-1)     return res      class PositionInfo(nn.Module):     def forward(self, data):         pos_info = pos_tensor(data.shape[-2], data.shape[-1]).to(data.device)         return data + pos_info      PositionInfo()(_embed).shape

Encoders и Decoders

Трансформер состоит из Encoder и Decoder блоков. Encoder это просто multihead self attention слой с последующим feed-forward слоем. Feed forward состоит из двух полносвязных слоев с одинаковыми для всех позиций весами (свертка с ядром 1). Осложнена эта картина skip connection’ами и пакетной нормализацией.

class BigBatch(nn.Module):     def __init__(self, net):         super().__init__()         self.net = net              def forward(self, batch):         first_size = batch.shape[0]         second_size = batch.shape[1]         big_batch = batch.reshape(             first_size * second_size, *tuple(batch.shape[2:])         )         big_res = self.net(big_batch)         return big_res.reshape(             first_size, second_size, *tuple(big_res.shape[1:])         )          class Encoder(nn.Module):     def __init__(self, d_model, h):         super().__init__()         self.self_att = MultiHeadAttention(d_model, h)         self.norm1 = BigBatch(nn.BatchNorm1d(d_model))         self.feed_forward = nn.Sequential(             BigBatch(nn.Linear(d_model, d_model)),             nn.ReLU(),             BigBatch(nn.Linear(d_model, d_model)),         )         self.norm2 = BigBatch(nn.BatchNorm1d(d_model))                        def forward(self, data):         res1 = self.self_att(data, data, data)         res1r = self.norm1(data + res1)         res2 = self.feed_forward(res1r)         res2r = self.norm2(res1r + res2)         return res2r          _encode = Encoder(32, 2)(_embed) _encode.shape  # output torch.Size([8, 6, 32])

Decoder чуть сложнее. Кроме self attention’а у него есть слой внимания в выхлопу encoder’а

class Decoder(nn.Module):     def __init__(self, d_model, h):         super().__init__()         self.self_att = MultiHeadAttention(d_model, h)         self.norm1 = BigBatch(nn.BatchNorm1d(d_model))                  self.src_att = MultiHeadAttention(d_model, h)         self.norm2 = BigBatch(nn.BatchNorm1d(d_model))                  self.feed_forward = nn.Sequential(             BigBatch(nn.Linear(d_model, d_model)),             nn.ReLU(),             BigBatch(nn.Linear(d_model, d_model)),         )         self.norm3 = BigBatch(nn.BatchNorm1d(d_model))                       def forward(self, src, tgt):         res1r = self.self_att(tgt, tgt, tgt)         res1 = self.norm1(res1r + tgt)         res2r = self.src_att(res1, res1, src)         res2 = self.norm2(res1 + res2r)         res3r = self.feed_forward(res2)         res3 = self.norm3(res2 + res3r)         return res3          _decode = Decoder(32, 2)(_embed, _embed) _decode.shape # output torch.Size([8, 6, 32])

Все в сборе

Модель состоит из embedding’ов, несколько слоев Encoder’а, Decoder’а и финального выходного слоя.

class Model(nn.Module):     def __init__(self, d_chars, d_model, h, n_layers=2):         super().__init__()         self.embed = Embed(d_chars, d_model)         self.pos = PositionInfo()         self.encoders = nn.ModuleList([             Encoder(d_model, h) for _ in range(n_layers)         ])         self.decoders = nn.ModuleList([             Decoder(d_model, h) for _ in range(n_layers)         ])         self.embed_decoder = nn.Linear(d_model, d_chars)                       def forward(self, batch):         enc_out = self.pos(self.embed(batch))         for layer in self.encoders:             enc_out = layer(enc_out)         dec_out = enc_out         for layer in self.decoders:             dec_out = layer(enc_out, dec_out)         char_out = self.embed_decoder(dec_out)         return char_out               _model_out = Model(len(CHARS), 32, 2, 2)(_batch) _model_out.shape # output torch.Size([8, 6, 19])

Маленькая оговорка: все, да не все. Здесь нет dropout’ов и mask’ed attention’а. В оригинале декодеру позволено запрашивать только предыдущие значения, чтобы сделать модель авторегрессионной. Здесь же, для такой простой задачи, не очень понятно, зачем это может понадобиться.

Процедура обучения

Первое что нужно сделать: описать функцию потерь. Для такой задачи можно взять cross entropy loss для отдельных символов.

def mean_batch_loss(answer, model_out):     return nn.CrossEntropyLoss()(         model_out.permute(0, 2, 1),         answer.long()     )      mean_batch_loss(_batch['answer'], _model_out) # output tensor(2.7526, grad_fn=<NllLoss2DBackward>)

Чтобы не путаться, все параметры обучения можно загнать в отдельный объект.

DEVICE = 'cuda:0'  class Context():     def __init__(self):         self.train_epoch = 10000         self.val_samples = 10000         self.device = DEVICE         self.batch_loss = 0         self.batch_discont = 0.8         self.history = []          _ctx = Context() _ctx.model = Model(len(CHARS), 32, 2, 2).to(_ctx.device) _ctx.opt = optim.SGD(_ctx.model.parameters(), lr=0.01) _ctx.epoch_size = 1000 _ctx.val_samples = 1000

Учится будем на GPU и нам пригодится функция, отправляющая туда данные

def to_device(device, data):     if isinstance(data, torch.Tensor):         return data.to(device)     elif isinstance(data, dict):         return {             key: to_device(device, value)             for key, value in data.items()         }      to_device(DEVICE, _batch) # output {'task': tensor([[ 7,  2,  7, 16,  2,  0],          [ 8, 13,  2, 16,  8,  0],          [ 4, 12,  4, 16, 15,  0],          [ 7, 13,  7, 18,  4,  7],          [ 2, 15,  2, 16,  2,  0],          [ 2, 15,  9, 16,  2,  0],          [ 6, 13,  3,  5,  2,  1],          [10, 13,  3, 14,  2,  9]], device='cuda:0', dtype=torch.uint8),  'answer': tensor([[ 7, 14,  7, 16,  2,  0],          [ 8, 13,  2, 16,  8,  0],          [ 4, 12,  4, 16,  1,  0],          [ 7, 13,  7, 16,  4,  7],          [ 2, 15,  4, 16,  2,  0],          [ 2, 15,  5, 16,  2,  0],          [ 6, 13,  3, 16,  2,  1],          [10, 13,  3, 16,  2,  9]], device='cuda:0', dtype=torch.uint8)}

Основа основ — процедура скармливания отдельного batch’а в сетку

def feed_batch(ctx, batch):     ctx.opt.zero_grad()     on_device = to_device(ctx.device, batch)     model_out = ctx.model(on_device)     loss = mean_batch_loss(on_device['answer'], model_out)     loss.backward()     ctx.opt.step()     ctx.batch_loss = (         (1 - ctx.batch_discont) * ctx.batch_loss         + ctx.batch_discont * loss.detach().cpu().numpy()     )   feed_batch(_ctx, _batch)

Всю нудную работу по расчету градиентов pytorch берет на себя. Для каждого тензора (если явно не указать обратное) он поддерживает значение текущего накопленного градиента а также функции, делающие back propagation. Нам же остается сбросить его перед началом обработки пакета (zero_grad), и вызвать loss.backward(), opt.step() после применения модели.

Процедура обучения обычно долгая. На текущее значение лоса полезно поглядывать. Вдруг он ушел в вверх или застопорился. Для этого можно хранить средний batch_loss c экспоненциальным backoff’ом.

Но им одним дело не ограничивается. Хорошо бы периодически честно пересчитывать метрики. В качестве метрик можно взять точность по символам и точность по семплам. Функция feed_epoch служим именно этой цели. Она скармливает в сетку некоторое кол-во семплов, считает метрики и записывает их для истории.

def calc_metrics(ctx, dataset_iter):     ctx.model.eval()     counters = {         'batch': 0,         'batch_loss': 0.0,         'batch_char_acc': 0.0,         'sample_acc': 0.0     }     with torch.no_grad():         with tqdm(total=ctx.val_samples, leave=False) as pbar:             num_samples = 0             while num_samples < ctx.val_samples:                 batch = next(dataset_iter)                 batch_size = len(batch['task'])                 on_device = to_device(ctx.device, batch)                                  pred = ctx.model(on_device)                 counters['batch'] += 1                 counters['batch_loss'] += mean_batch_loss(                     on_device['answer'],                      pred                 ).cpu().numpy()                 char_pred = torch.argmax(pred, dim=-1)                 correct = char_pred == on_device['answer']                                                   counters['batch_char_acc'] += correct.float().mean().cpu().numpy()                 counters['sample_acc'] += correct.min(dim=-1).values.float().mean().cpu().numpy()                                  num_samples += batch_size                 pbar.update(batch_size)                      return {         'loss': counters['batch_loss'] / counters['batch'],         'char_acc': counters['batch_char_acc'] / counters['batch'],          'sample_acc': counters['sample_acc'] / counters['batch']     }                      calc_metrics(_ctx, _dataset_iter)
def feed_epoch(ctx, dataset_iter):     ctx.model.train()     num_samples = 0     with tqdm(total=ctx.train_epoch, leave=False) as pbar:         while num_samples < ctx.train_epoch:             batch = next(dataset_iter)             feed_batch(ctx, batch)                      batch_size = len(batch['task'])             num_samples += batch_size             pbar.update(batch_size)                                        pbar.set_postfix(batch_loss = ctx.batch_loss)     ctx.history.append(calc_metrics(ctx, dataset_iter))                    _ctx.train_epoch = 10000             feed_epoch(_ctx, _dataset_iter)

Чистовое обучение

Теперь все готово, чтобы все взять и обучить. Размер батча и learning rate подобраны для Тесла V100. Для карточки по-скромнее размер пакета нужно брать по-меньше (чтобы по памяти не вылететь), и learning rate, соответственно, тоже (чтобы градиент нешибко шатало).

dataset = tdutils.DataLoader(     dataset=SampleSet(1, 99999999),     batch_size=1024 * 2,     num_workers=8 ) dataset_iter = iter(dataset)
ctx = Context() ctx.train_epoch = 10000 ctx.val_samples = 5000 ctx.model = Model(len(CHARS), 256, h=8, n_layers=3).to(ctx.device) ctx.opt = optim.SGD(ctx.model.parameters(), lr=1.0) ctx.seq_len = dataset.dataset.str_size ctx.history = []
%%time _num_epoches = 100 for _ in tqdm(range(_num_epoches), leave=False):     feed_epoch(ctx, dataset_iter) # output CPU times: user 4min 34s, sys: 1min 47s, total: 6min 22s Wall time: 6min 20s

Вывод истории

Самая простая модель, которую можно взять для baseline’а — тупое копирование.

def plot_history(ctx):     plt.subplot(1, 3, 1)     loss = [ record['loss'] for record in ctx.history]         plt.plot(loss)     plt.title('loss')          plt.subplot(1, 3, 2)     char_acc = [ record['char_acc'] for record in ctx.history]     plt.plot(char_acc)     baseline = (ctx.seq_len - 1) / ctx.seq_len + 1/ctx.seq_len * 1/len(CHARS)     plt.plot([0, len(char_acc) - 1], [baseline, baseline], c='g')     plt.plot([0, len(char_acc) - 1], [1, 1], c='g')      plt.title('char_acc')          plt.subplot(1, 3, 3)     baseline = (1/len(CHARS))     sample_acc = [record['sample_acc'] for record in ctx.history]     plt.plot(sample_acc)     plt.plot([0, len(sample_acc) - 1], [baseline, baseline], c='g')     plt.title('sample_acc')      plt.figure(figsize=(15,5)) plot_history(ctx)
Вывод истории обучения
Вывод истории обучения

Модель учится 🙂

ссылка на оригинал статьи https://habr.com/ru/post/542334/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *