Дистилляция BERT для задачи классификации

от автора

Большие языковые модели это конечно хорошо, но иногда требуется использовать что-то маленькое и быстрое.

Постановка задачи

Дистилляция будет проводиться для модели BERT, обученной на задачу бинарной классификации. В качестве данных был выбран открытый корпус русскоязычных твитов. Вдохновлялся двумя статьями: по дистилляции данных из BERT в BiLSTM, и собственно по дистилляции BERT. Нового ничего не добавлю, хочется все причесать и сделать пошаговый туториал для простого использования. Весь код на github.

План работ

  1. Baseline 1: TF-IDF + RandomForest

  2. Baseline 2: BiLSTM

  3. Дистилляция BERT > BiLSTM

  4. Дистилляция BERT > tinyBERT

TF-IDF + RandomForest

Все стандартно: нижний регистр, лемматизация, удаление стоп-слов. Полученные вектора классифицируем RandomForest. Получаем F1 чуть больше 0.75.

Как обучить TF-IDF + RF
import re import pandas as pd from pymystem3 import Mystem  # get data data = pd.read_csv('data.csv')  texts = list(data['comment']) labels = list(map(int, data['toxic'].values))  # clean texts texts = [re.sub('[^а-яё ]', ' ', str(t).lower()) for t in texts] texts = [re.sub(r" +", " ", t).strip() for t in texts]  # lemmatize mstm = Mystem()  normalized = [''.join(mstm.lemmatize(t)[:-1]) for t in texts]  # remove stopwords with open('./stopwords.txt') as f:     stopwords = [line.rstrip('\n') for line in f]  def drop_stop(text):     tokens = text.split(' ')     tokens = [t for t in tokens if t not in stopwords]     return ' '.join(tokens)  normalized = [drop_stop(text) for text in normalized]  # new dataset df = pd.DataFrame() df['text'] = texts df['norm'] = normalized df['label'] = labels  # train-valid-test-split from sklearn.model_selection import train_test_split  train, test = train_test_split(df, test_size=0.3, random_state=42) valid, test = train_test_split(test, test_size=0.5, random_state=42)  # tf-idf from sklearn.feature_extraction.text import TfidfVectorizer  model_tfidf = TfidfVectorizer(max_features=5000)  train_tfidf = model_tfidf.fit_transform(train['norm'].values) valid_tfidf = model_tfidf.transform(valid['norm'].values) test_tfidf = model_tfidf.transform(test['norm'].values)  # RF from sklearn.ensemble import RandomForestClassifier  cls = RandomForestClassifier(random_state=42) cls.fit(train_tfidf, train['label'].values)  # prediction predictions = cls.predict(test_tfidf)  # score from sklearn.metrics import f1_score  f1_score(predictions, test['label'].values)

BiLSTM

Попробуем улучшить бэйзлайн с помощью нейросетевого подхода. Все стандартно: учим токенизатор, учим сетку. В качестве базовой архитектуры берем BiLSTM. Получаем F1 чуть больше 0.79. Небольшой, но прирост есть.

Как обучить BiLSTM
# get data import pandas as pd  train = pd.read_csv('train.csv') valid = pd.read_csv('valid.csv') test = pd.read_csv('test.csv')  # create tokenizer from tokenizers import Tokenizer from tokenizers import ByteLevelBPETokenizer from tokenizers.pre_tokenizers import Whitespace  tokenizer = ByteLevelBPETokenizer() tokenizer.pre_tokenizer = Whitespace() tokenizer.enable_padding(pad_id=0, pad_token='<pad>')  texts_path = 'texts.txt'  with open(texts_path, 'w') as f:     for text in list(train['text'].values):         f.write("%s\n" % text)  tokenizer.train(     files=[texts_path],     vocab_size=5_000,     min_frequency=2,     special_tokens=['<pad>', '<unk>']     )  # create dataset import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader  class CustomDataset(Dataset):      def __init__(self, tokens, labels, max_len):         self.tokens = tokens         self.labels = labels         self.max_len = max_len       def __len__(self):         return len(self.tokens)       def __getitem__(self, idx):         label = self.labels[idx]         label = torch.tensor(label)         tokens = self.tokens[idx]         out = torch.zeros(self.max_len, dtype=torch.long)         out[:len(tokens)] = torch.tensor(tokens, dtype=torch.long)[:self.max_len]         return out, label  max_len = 64 BATCH_SIZE = 16  train_labels = list(train['label']) train_tokens = [tokenizer.encode(text).ids for text in list(train['text'])] train_dataset = CustomDataset(train_tokens, train_labels, max_len) train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)  test_labels = list(test['label']) test_tokens = [tokenizer.encode(text).ids for text in list(test['text'])] test_dataset = CustomDataset(test_tokens, test_labels, max_len) test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)  # create BiLSTM class LSTM_classifier(nn.Module):       def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.3, n_classes=2):         super().__init__()         self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)         self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)         self.dropout_layer = nn.Dropout(dropout)                 self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)         self.batchnorm = nn.BatchNorm1d(linear_dim)         self.relu = nn.ReLU()         self.out_layer = nn.Linear(linear_dim, n_classes)       def forward(self, inputs):         batch_size = inputs.size(0)         embeddings = self.embedding_layer(inputs)         lstm_out, (ht, ct) = self.lstm_layer(embeddings)         out = ht.transpose(0, 1)         out = out.reshape(batch_size, -1)         out = self.fc_layer(out)         out = self.batchnorm(out)         out = self.relu(out)         out = self.dropout_layer(out)         out = self.out_layer(out)         out = torch.squeeze(out, 1)         out = torch.sigmoid(out)         return out  def init_weights(m):     if isinstance(m, nn.Linear):         torch.nn.init.xavier_uniform_(m.weight)         m.bias.data.fill_(0.01)  def eval_nn(model, data_loader):     predicted = []     labels = []     model.eval()     with torch.no_grad():         for data in data_loader:             x, y = data             x = x.to(device)             outputs = model(x)             _, predict = torch.max(outputs.data, 1)             predict = predict.cpu().detach().numpy().tolist()             predicted += predict             labels += y         score = f1_score(labels, predicted, average='binary')     return score  def train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=20):     best_score = 0     for epoch in range(epochs):         model.train()         for inputs, labels in tqdm(train_loader):             inputs, labels = inputs.to(device), labels.to(device)             optimizer.zero_grad()             predict = model(inputs)             loss = loss_function(predict, labels)             loss.backward()             optimizer.step()        score = eval_nn(model, test_loader)         print(epoch, 'valid:', score)         if score > best_score:             torch.save(model.state_dict(),'lstm.pt')             best_score = score     return best_score  # fit NN device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  model = LSTM_classifier(hidden_dim=256, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.1)     model.apply(init_weights)  model.to(device)  optimizer = optim.AdamW(model.parameters())  loss_function = nn.CrossEntropyLoss().to(device)  train_nn(model, optimizer, loss_function, train_loader, valid_loader, device, epochs=20)  eval_nn(model, test_loader)

Учим BERT

Обучим модель-учитель. В качестве учителя я выбрал героя вышеупомянутой статьи по дистилляцииrubert-tiny от @cointegrated. Получаем F1 чуть больше 0.91. Я особо не игрался с обучением, можно думаю было получить метрику и получше, особенно если использовать большой BERT, но и так достаточно показательно. Как обучить BERT на бинарную классификацию можно глянуть в моей прошлой статье, или прямо тут:

как обучить BERT
import torch from torch.utils.data import Dataset  class BertDataset(Dataset):    def __init__(self, texts, targets, tokenizer, max_len=512):     self.texts = texts     self.targets = targets     self.tokenizer = tokenizer     self.max_len = max_len    def __len__(self):     return len(self.texts)    def __getitem__(self, idx):     text = str(self.texts[idx])     target = self.targets[idx]      encoding = self.tokenizer.encode_plus(         text,         add_special_tokens=True,         max_length=self.max_len,         return_token_type_ids=False,         padding='max_length',         return_attention_mask=True,         return_tensors='pt',         truncation=True     )      return {       'text': text,       'input_ids': encoding['input_ids'].flatten(),       'attention_mask': encoding['attention_mask'].flatten(),       'targets': torch.tensor(target, dtype=torch.long)     }  from tqdm import tqdm import numpy as np import torch from transformers import BertTokenizer, BertForSequenceClassification from torch.utils.data import Dataset, DataLoader from transformers import AdamW, get_linear_schedule_with_warmup from sklearn.metrics import precision_recall_fscore_support   class BertClassifier:      def __init__(self, path, n_classes=2):         self.path = path         self.model = BertForSequenceClassification.from_pretrained(path)         self.tokenizer = BertTokenizer.from_pretrained(path)         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")         self.max_len = 512         self.out_features = self.model.bert.encoder.layer[1].output.dense.out_features         self.model.classifier = torch.nn.Linear(self.out_features, n_classes)         self.model.to(self.device)           def preparation(self, X_train, y_train, epochs):         # create datasets         self.train_set = BertDataset(X_train, y_train, self.tokenizer)         # create data loaders         self.train_loader = DataLoader(self.train_set, batch_size=2, shuffle=True)         # helpers initialization         self.optimizer = AdamW(             self.model.parameters(),             lr=2e-5,             weight_decay=0.005,             correct_bias=True             )         self.scheduler = get_linear_schedule_with_warmup(                 self.optimizer,                 num_warmup_steps=500,                 num_training_steps=len(self.train_loader) * epochs             )         self.loss_fn = torch.nn.CrossEntropyLoss().to(self.device)       def fit(self):         self.model = self.model.train()         losses = []         correct_predictions = 0          for data in tqdm(self.train_loader):             input_ids = data["input_ids"].to(self.device)             attention_mask = data["attention_mask"].to(self.device)             targets = data["targets"].to(self.device)              outputs = self.model(                 input_ids=input_ids,                 attention_mask=attention_mask                 )              preds = torch.argmax(outputs.logits, dim=1)             loss = self.loss_fn(outputs.logits, targets)              correct_predictions += torch.sum(preds == targets)              losses.append(loss.item())                          loss.backward()             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)             self.optimizer.step()             self.scheduler.step()             self.optimizer.zero_grad()          train_acc = correct_predictions.double() / len(self.train_set)         train_loss = np.mean(losses)         return train_acc, train_loss           def train(self, X_train, y_train, X_valid, y_valid, X_test, y_test, epochs=1):         print('*' * 10)         print(f'Model: {self.path}')         self.preparation(X_train, y_train, epochs)         for epoch in range(epochs):             print(f'Epoch {epoch + 1}/{epochs}')             train_acc, train_loss = self.fit()             print(f'Train loss {train_loss} accuracy {train_acc}')             predictions_valid = [self.predict(x) for x in X_valid]             precision, recall, f1score = precision_recall_fscore_support(y_valid, predictions_valid, average='macro')[:3]             print('Valid:')             print(f'precision: {precision}, recall: {recall}, f1score: {f1score}')             predictions_test = [self.predict(x) for x in X_test]             precision, recall, f1score = precision_recall_fscore_support(y_test, predictions_test, average='macro')[:3]             print('Test:')             print(f'precision: {precision}, recall: {recall}, f1score: {f1score}')         print('*' * 10)          def predict(self, text):         self.model = self.model.eval()         encoding = self.tokenizer.encode_plus(             text,             add_special_tokens=True,             max_length=self.max_len,             return_token_type_ids=False,             truncation=True,             padding='max_length',             return_attention_mask=True,             return_tensors='pt',         )                  out = {               'text': text,               'input_ids': encoding['input_ids'].flatten(),               'attention_mask': encoding['attention_mask'].flatten()           }                  input_ids = out["input_ids"].to(self.device)         attention_mask = out["attention_mask"].to(self.device)                  outputs = self.model(             input_ids=input_ids.unsqueeze(0),             attention_mask=attention_mask.unsqueeze(0)         )                  prediction = torch.argmax(outputs.logits, dim=1).cpu().numpy()[0]          return prediction  import pandas as pd  train = pd.read_csv('train.csv') valid = pd.read_csv('valid.csv') test = pd.read_csv('test.csv')  classifier = BertClassifier(     path='cointegrated/rubert-tiny',     n_classes=2 )  classifier.train(         X_train=list(train['text']),         y_train=list(train['label']),         X_valid=list(valid['text']),         y_valid=list(valid['label']),         X_test=list(test['text']),         y_test=list(test['label']),         epochs=1 )  path = './trainer' classifier.model.save_pretrained(path) classifier.tokenizer.save_pretrained(path)

Дистилляция BERT > BiLSTM

Основная идея — приближение BiLSTM-учеником выхода BERT-учителя. Для этого при обучении используем функцию ошибки MSE. Можно использовать совместно с обучением на метках и CrossEntropyLoss. Подробнее можно почитать в статье по ссылке. На моих тестовых данных дистилляция докинула всего пару процентов: F1 чуть больше 0.82.

Код дистилляции
import torch import torch.nn as nn import torch.optim as optim  from torch.utils.data import Dataset, DataLoader  from tokenizers import Tokenizer from tokenizers import ByteLevelBPETokenizer from tokenizers.pre_tokenizers import Whitespace  from transformers import BertTokenizer, BertForSequenceClassification from transformers import AdamW, get_linear_schedule_with_warmup  from sklearn.metrics import precision_recall_fscore_support from sklearn.metrics import f1_score  import numpy as np import pandas as pd  ### data  train = pd.read_csv('train.csv') test = pd.read_csv('test.csv')  ### tokenizer: train  tokenizer = ByteLevelBPETokenizer() tokenizer.pre_tokenizer = Whitespace() tokenizer.enable_padding(pad_id=0, pad_token='<pad>')  texts_path = 'texts.txt'  with open(texts_path, 'w') as f:     for text in list(train['text'].values):         f.write("%s\n" % text)  tokenizer.train(     files=[texts_path],     vocab_size=5_000,     min_frequency=2,     special_tokens=['<pad>', '<unk>']     )  ### load BERT tokenizer  tokenizer_bert = BertTokenizer.from_pretrained('./rubert-tiny')  ### dataset  class CustomDataset(Dataset):      def __init__(self, tokens, labels, max_len):         self.tokens = tokens         self.labels = labels         self.max_len = max_len       def __len__(self):         return len(self.tokens)       def __getitem__(self, idx):         label = self.labels[idx]         label = torch.tensor(label)         tokens = self.tokens[idx]         out = torch.zeros(self.max_len, dtype=torch.long)         out[:len(tokens)] = torch.tensor(tokens, dtype=torch.long)[:self.max_len]         return out, label  max_len = 64 BATCH_SIZE = 16  train_labels = list(train['label']) train_tokens = [tokenizer.encode(str(text)).ids for text in list(train['text'])] train_dataset = CustomDataset(train_tokens, train_labels, max_len) train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)  test_labels = list(test['label']) test_tokens = [tokenizer.encode(str(text)).ids for text in list(test['text'])] test_dataset = CustomDataset(test_tokens, test_labels, max_len) test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)  class LSTM_classifier(nn.Module):       def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.3, n_classes=2):         super().__init__()         self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)         self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)         self.dropout_layer = nn.Dropout(dropout)                 self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)         self.batchnorm = nn.BatchNorm1d(linear_dim)         self.relu = nn.ReLU()         self.out_layer = nn.Linear(linear_dim, n_classes)       def forward(self, inputs):         batch_size = inputs.size(0)         embeddings = self.embedding_layer(inputs)         lstm_out, (ht, ct) = self.lstm_layer(embeddings)         out = ht.transpose(0, 1)         out = out.reshape(batch_size, -1)         out = self.fc_layer(out)         out = self.batchnorm(out)         out = self.relu(out)         out = self.dropout_layer(out)         out = self.out_layer(out)         out = torch.squeeze(out, 1)         out = torch.sigmoid(out)         return out  ########  def init_weights(m):     if isinstance(m, nn.Linear):         torch.nn.init.xavier_uniform_(m.weight)         m.bias.data.fill_(0.01)  def eval_nn(model, data_loader):     predicted = []     labels = []     model.eval()     with torch.no_grad():         for data in data_loader:             x, y = data             x = x.to(device)             outputs = model(x)             _, predict = torch.max(outputs.data, 1)             predict = predict.cpu().detach().numpy().tolist()             predicted += predict             labels += y         score = f1_score(labels, predicted, average='binary')     return score  def train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=20):     best_score = 0     for epoch in range(epochs):         model.train()         for inputs, labels in train_loader:             inputs, labels = inputs.to(device), labels.to(device)             optimizer.zero_grad()             predict = model(inputs)             loss = loss_function(predict, labels)             loss.backward()             optimizer.step()         score = eval_nn(model, test_loader)         print(epoch, 'valid:', score)         if score > best_score:             torch.save(model.state_dict(), 'lstm.pt')             best_score = score     return best_score  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  model = LSTM_classifier(hidden_dim=256, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.1)     model.apply(init_weights);  model.to(device);  optimizer = optim.AdamW(model.parameters())  loss_function = nn.CrossEntropyLoss().to(device)  train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=3)  #####  class DistillDataset(Dataset):      def __init__(self, texts, labels, tokenizer_bert, tokenizer_lstm, max_len):         self.texts = texts         self.labels = labels         self.tokenizer_bert = tokenizer_bert         self.tokenizer_lstm = tokenizer_lstm         self.max_len = max_len       def __len__(self):         return len(self.texts)       def __getitem__(self, idx):         text = self.texts[idx]         label = self.labels[idx]         label = torch.tensor(label)         # lstm         tokens_lstm = self.tokenizer_lstm.encode(str(text)).ids         out_lstm = torch.zeros(self.max_len, dtype=torch.long)         out_lstm[:len(tokens_lstm)] = torch.tensor(tokens_lstm, dtype=torch.long)[:self.max_len]         # bert         encoding = self.tokenizer_bert.encode_plus(             str(text),             add_special_tokens=True,             max_length=self.max_len,             return_token_type_ids=False,             truncation=True,             padding='max_length',             return_attention_mask=True,             return_tensors='pt',         )                  out_bert = {               'input_ids': encoding['input_ids'].flatten(),               'attention_mask': encoding['attention_mask'].flatten()         }         return out_lstm, out_bert, label  train_dataset_distill = DistillDataset(     list(train['text']),     list(train['label']),     tokenizer_bert,     tokenizer,     max_len )  train_loader_distill = DataLoader(train_dataset_distill, batch_size=BATCH_SIZE, shuffle=True)  ### BERT-teacher model  class BertTrainer:      def __init__(self, path_model, n_classes=2):         self.model = BertForSequenceClassification.from_pretrained(path_model, num_labels=n_classes)         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")         self.max_len = 512         self.model.to(self.device)         self.model = self.model.eval()          def predict(self, inputs):              input_ids = inputs["input_ids"].to(self.device)         attention_mask = inputs["attention_mask"].to(self.device)         with torch.no_grad():             outputs = self.model(                 input_ids=input_ids,                 attention_mask=attention_mask             )         return outputs.logits  teacher = BertTrainer('./rubert-tiny')  ### BiLSTM-student model  class CustomLSTM(nn.Module):       def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.3, n_classes=2):         super().__init__()         self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)         self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)         self.dropout_layer = nn.Dropout(dropout)                 self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)         self.batchnorm = nn.BatchNorm1d(linear_dim)         self.relu = nn.ReLU()         self.out_layer = nn.Linear(linear_dim, n_classes)       def forward(self, inputs):         batch_size = inputs.size(0)         embeddings = self.embedding_layer(inputs)         lstm_out, (ht, ct) = self.lstm_layer(embeddings)         out = ht.transpose(0, 1)         out = out.reshape(batch_size, -1)         out = self.fc_layer(out)         out = self.batchnorm(out)         out = self.relu(out)         out = self.dropout_layer(out)         out = self.out_layer(out) #         out = torch.squeeze(out, 1) #         out = torch.sigmoid(out)         return out  def loss_function(output, teacher_prob, real_label, a=0.5):     criterion_mse = torch.nn.MSELoss()     criterion_ce = torch.nn.CrossEntropyLoss()     return a * criterion_ce(output, real_label) + (1 - a) * criterion_mse(output, teacher_prob)  def init_weights(m):     if isinstance(m, nn.Linear):         torch.nn.init.xavier_uniform_(m.weight)         m.bias.data.fill_(0.01)  def eval_nn(model, data_loader):     predicted = []     labels = []     model.eval()     with torch.no_grad():         for data in data_loader:             x, y = data             x = x.to(device)             outputs = model(x)             _, predict = torch.max(outputs.data, 1)             predict = predict.cpu().detach().numpy().tolist()             predicted += predict             labels += y         score = f1_score(labels, predicted, average='binary')     return labels, predicted, score  def train_distill(model, teacher, optimizer, loss_function, distill_loader, train_loader, test_loader, device, epochs=30, alpha=0.5):     best_score = 0     score_list = []     for epoch in range(epochs):         model.train()         for inputs, inputs_teacher, labels in distill_loader:             inputs, labels = inputs.to(device), labels.to(device)             optimizer.zero_grad()             predict = model(inputs)             teacher_predict = teacher.predict(inputs_teacher)             loss = loss_function(predict, teacher_predict, labels, alpha)             loss.backward()             optimizer.step()         score_train = round(eval_nn(model, train_loader)[2], 3)         score_test = round(eval_nn(model, test_loader)[2], 3)         score_list.append((score_train, score_test))         print(epoch, score_train, score_test)         if score_test > best_score:                  best_score = score_test             best_model = model     torch.save(best_model.state_dict(), f'./results/lstm_{best_score}.pt')     return best_model, best_score, score_list  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  vocab_size = tokenizer.get_vocab_size()  vocab_size  score_alpha = [] for alpha in [0, 0.25, 0.5, 0.75, 1]:     model = LSTM_classifier(hidden_dim=256, vocab_size=5000, embedding_dim=300, linear_dim=128, dropout=0.1)         model.apply(init_weights)     model.to(device)     optimizer = optim.AdamW(model.parameters())     _, _, score_list = train_distill(model, teacher, optimizer, loss_function, train_loader_distill, train_loader, test_loader, device, 30, alpha)     score_alpha.append(score_list)  import matplotlib.pyplot as plt import numpy as np  a_list = [1, 0.75, 0.5, 0.25, 0]  for i, score in enumerate(score_alpha):     _, score_test = list(zip(*score))     plt.plot(score_test, label=f'{a_list[i]}') plt.grid(True) plt.legend() plt.show()

Дистилляция BERT > tinyBERT

Основная идея, как и в прошлом пункте — приближать учеником поведение учителя. Есть много вариантов что и как приближать, я взял всего два:

  1. Приближать [CLS]-токен по MSE.

  2. Дистилляция распределения токенов по дивергенции Кульбака-Лейблера.

Дополнительно в процессе обучения решаем задачу MLM — предсказание замаскированных токенов. Уменьшение размера модели осуществляется за счет сокращения словаря и уменьшения количества голов внимания, а также количества и размерности скрытых слоев.

Обучение итогового классификатора в итоге делится на 2 этапа:

  1. Обучение языковой модели.

  2. Обучение головы для классификации.

Я применял дистилляцию только для первого этапа, голову для классификации учил уже непосредственно на дистиллированной модели. Думаю можно было накинуть и вариант с MSE как в примере с BiLSTM, но оставил эти эксперименты на потом.

Ключевые моменты реализации:

Сокращение словаря:
from transformers import BertTokenizerFast, BertForPreTraining, BertModel, BertConfig from collections import Counter from tqdm.auto import tqdm, trange import pandas as pd  train = pd.read_csv('train.csv') X_train=list(train['text'])  tokenizer = BertTokenizerFast.from_pretrained('./rubert-tiny')  cnt = Counter() for text in tqdm(X_train):     cnt.update(tokenizer(str(text))['input_ids'])  resulting_vocab = {     tokenizer.vocab[k] for k in tokenizer.special_tokens_map.values() }  for k, v in cnt.items():     if v > 5:         resulting_vocab.add(k)  resulting_vocab = sorted(resulting_vocab)  tokenizer.save_pretrained('./bert_distill');  inv_voc = {idx: word for word, idx in tokenizer.vocab.items()}  with open('./bert_distill/vocab.txt', 'w', encoding='utf-8') as f:     for idx in resulting_vocab:         f.write(inv_voc[idx] + '\n')

Инициализация весов
config = BertConfig(     emb_size=256,     hidden_size=256,     intermediate_size=256,     max_position_embeddings=512,     num_attention_heads=8,     num_hidden_layers=3,     vocab_size=tokenizer_distill.vocab_size )  model = BertForPreTraining(config)  model.save_pretrained('./bert_distill')  from transformers import BertModel # load model without CLS-head teacher = BertForPreTraining.from_pretrained('./rubert-tiny')  tokenizer_teacher = BertTokenizerFast.from_pretrained('./rubert-tiny')  # copy input embeddings accordingly with resulting_vocab model.bert.embeddings.word_embeddings.weight.data = teacher.bert.embeddings.word_embeddings.weight.data[resulting_vocab, :256].clone() model.bert.embeddings.position_embeddings.weight.data = teacher.bert.embeddings.position_embeddings.weight.data[:, :256].clone()  # copy output embeddings model.cls.predictions.decoder.weight.data = teacher.cls.predictions.decoder.weight.data[resulting_vocab, :256].clone()

MLM-loss
inputs = tokenizer_distill(texts, return_tensors='pt', padding=True, truncation=True, max_length=16) inputs = preprocess_inputs(inputs, tokenizer_distill, data_collator) outputs = model(**inputs, output_hidden_states=True) loss += nn.CrossEntropyLoss(         outputs.prediction_logits.view(-1, model.config.vocab_size),         inputs['labels'].view(-1)     )

KL-loss
def loss_kl(inputs, outputs, model, teacher, vocab_mapping, temperature=1.0):     new_inputs = torch.tensor(         [[vocab_mapping[i] for i in row] for row in inputs['input_ids']]     ).to(inputs['input_ids'].device)     with torch.no_grad():         teacher_out = teacher(             input_ids=new_inputs,              token_type_ids=inputs['token_type_ids'],             attention_mask=inputs['attention_mask']         )     # the whole batch, all tokens after the [cls], the whole dimension     kd_loss = torch.nn.KLDivLoss(reduction='batchmean')(         F.log_softmax(outputs.prediction_logits[:, 1:, :] / temperature, dim=1),          F.softmax(teacher_out.prediction_logits[:, 1:, vocab_mapping] / temperature, dim=1)     ) / outputs.prediction_logits.shape[-1]     return kd_loss

MSE-loss
input_teacher = {k: v for k, v in tokenizer_teacher(         texts,         return_tensors='pt',         padding=True,         max_length=16,         truncation=True     ).items()}  with torch.no_grad():     out_teacher = teacher_mse(**input_teacher)  embeddings_teacher_norm = torch.nn.functional.normalize(out_teacher.pooler_output)  input_distill = {k: v for k, v in tokenizer_distill(         texts,         return_tensors='pt',         padding=True,         max_length=16,         truncation=True     ).items()}  out = model(**input_distill, output_hidden_states=True) embeddings = model.bert.pooler(out.hidden_states[-1]) embeddings_norm = torch.nn.functional.normalize(adapter_emb(embeddings)) loss = torch.nn.MSELoss(embeddings_norm, embeddings_teacher_norm)

Размер итоговой модели составил 16 Мб, метрика F1 0.86. Учил модель я 12 часов на макбук эйр 19 года с i5 и 8 Гб оперативной памяти. Думаю, если погонять подольше, то и результат будет получше.

Код и данные для обучения представлены на github, замечания, дополнения и исправления приветствуются.


ссылка на оригинал статьи https://habr.com/ru/post/692570/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *