Большие языковые модели — это, конечно, хорошо, но иногда требуется использовать что-то маленькое и быстрое.
Постановка задачи
Дистилляция будет проводиться для модели BERT, обученной на задачу бинарной классификации. В качестве данных был выбран открытый корпус русскоязычных твитов. Вдохновлялся двумя статьями: о дистилляции знаний из BERT в BiLSTM и собственно о дистилляции BERT. Ничего нового я не добавлю — хочется всё причесать и сделать пошаговый туториал для простого использования. Весь код выложен на GitHub.
План работ
-
Baseline 1: TF-IDF + RandomForest
-
Baseline 2: BiLSTM
-
Дистилляция BERT > BiLSTM
-
Дистилляция BERT > tinyBERT
TF-IDF + RandomForest
Всё стандартно: приведение к нижнему регистру, лемматизация, удаление стоп-слов. Полученные TF-IDF-векторы классифицируем с помощью RandomForest. Получаем F1 чуть больше 0.75.
Как обучить TF-IDF + RF
# Baseline 1: TF-IDF features + RandomForest classifier.
import re

import pandas as pd
from pymystem3 import Mystem
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# --- get data ---
data = pd.read_csv('data.csv')
texts = list(data['comment'])
labels = list(map(int, data['toxic'].values))

# --- clean texts: lowercase, keep only Cyrillic letters and spaces ---
texts = [re.sub('[^а-яё ]', ' ', str(t).lower()) for t in texts]
texts = [re.sub(r" +", " ", t).strip() for t in texts]

# --- lemmatize ---
# The last element of Mystem.lemmatize() output is dropped
# (NOTE(review): presumably the trailing '\n' Mystem appends — confirm).
mstm = Mystem()
normalized = [''.join(mstm.lemmatize(t)[:-1]) for t in texts]

# --- remove stopwords ---
# Explicit UTF-8 because the list is Russian; a set gives O(1) lookups per token.
with open('./stopwords.txt', encoding='utf-8') as f:
    stopwords = {line.rstrip('\n') for line in f}


def drop_stop(text):
    """Return *text* with every space-separated token that is a stopword removed."""
    return ' '.join(token for token in text.split(' ') if token not in stopwords)


normalized = [drop_stop(text) for text in normalized]

# --- new dataset ---
df = pd.DataFrame()
df['text'] = texts
df['norm'] = normalized
df['label'] = labels

# --- train / valid / test split: 70 / 15 / 15 ---
train, test = train_test_split(df, test_size=0.3, random_state=42)
valid, test = train_test_split(test, test_size=0.5, random_state=42)

# --- tf-idf: fit on train only, then transform the other splits ---
model_tfidf = TfidfVectorizer(max_features=5000)
train_tfidf = model_tfidf.fit_transform(train['norm'].values)
valid_tfidf = model_tfidf.transform(valid['norm'].values)
test_tfidf = model_tfidf.transform(test['norm'].values)

# --- RF ---
cls = RandomForestClassifier(random_state=42)
cls.fit(train_tfidf, train['label'].values)

# --- prediction ---
predictions = cls.predict(test_tfidf)

# --- score ---
# sklearn's signature is f1_score(y_true, y_pred).
score = f1_score(test['label'].values, predictions)
print(score)
BiLSTM
Попробуем улучшить бейзлайн с помощью нейросетевого подхода. Всё стандартно: обучаем токенизатор, обучаем сеть. В качестве базовой архитектуры берём BiLSTM. Получаем F1 чуть больше 0.79 — небольшой, но прирост есть.
Как обучить BiLSTM
# Baseline 2: byte-level BPE tokenizer + BiLSTM classifier.
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score
from tokenizers import ByteLevelBPETokenizer
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import Whitespace
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# --- get data ---
train = pd.read_csv('train.csv')
valid = pd.read_csv('valid.csv')
test = pd.read_csv('test.csv')

# --- create tokenizer, trained on the training texts only ---
tokenizer = ByteLevelBPETokenizer()
tokenizer.pre_tokenizer = Whitespace()
tokenizer.enable_padding(pad_id=0, pad_token='<pad>')

texts_path = 'texts.txt'
# Explicit UTF-8: the corpus is Russian and the platform default may differ.
with open(texts_path, 'w', encoding='utf-8') as f:
    for text in train['text'].values:
        f.write("%s\n" % text)

tokenizer.train(
    files=[texts_path],
    vocab_size=5_000,
    min_frequency=2,
    special_tokens=['<pad>', '<unk>'],
)


# --- create dataset ---
class CustomDataset(Dataset):
    """Token-id sequences zero-padded / truncated to ``max_len``, paired with labels."""

    def __init__(self, tokens, labels, max_len):
        self.tokens = tokens
        self.labels = labels
        self.max_len = max_len

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        label = torch.tensor(self.labels[idx])
        tokens = self.tokens[idx][:self.max_len]
        out = torch.zeros(self.max_len, dtype=torch.long)  # 0 is the '<pad>' id
        out[:len(tokens)] = torch.tensor(tokens, dtype=torch.long)
        return out, label


max_len = 64
BATCH_SIZE = 16


def make_loader(df, shuffle=False, drop_last=False):
    """Tokenize ``df['text']`` (NaN-safe via str) and wrap it with ``df['label']``."""
    tokens = [tokenizer.encode(str(text)).ids for text in df['text']]
    dataset = CustomDataset(tokens, list(df['label']), max_len)
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, drop_last=drop_last)


# Shuffle training batches; drop_last avoids a size-1 final batch, which
# BatchNorm1d rejects in training mode.
train_loader = make_loader(train, shuffle=True, drop_last=True)
valid_loader = make_loader(valid)
test_loader = make_loader(test)


# --- create BiLSTM ---
class LSTM_classifier(nn.Module):
    """BiLSTM over token embeddings followed by an MLP head; returns raw class logits."""

    def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300,
                 linear_dim=128, dropout=0.3, n_classes=2):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)
        self.batchnorm = nn.BatchNorm1d(linear_dim)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(linear_dim, n_classes)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        embeddings = self.embedding_layer(inputs)
        _, (ht, _) = self.lstm_layer(embeddings)
        # ht: (2 directions, batch, hidden) -> concatenate both final states per sample.
        out = ht.transpose(0, 1).reshape(batch_size, -1)
        out = self.fc_layer(out)
        out = self.batchnorm(out)
        out = self.relu(out)
        out = self.dropout_layer(out)
        # Raw logits: CrossEntropyLoss applies log-softmax itself. A sigmoid here
        # would squash logits into [0, 1] and cap the achievable confidence.
        return self.out_layer(out)


def init_weights(m):
    """Xavier-uniform weights and a 0.01 bias for every nn.Linear layer."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


def eval_nn(model, data_loader, device=None):
    """Return the binary F1 of the model's argmax predictions over ``data_loader``.

    ``device`` defaults to the device of the model's parameters.
    """
    if device is None:
        device = next(model.parameters()).device
    predicted, labels = [], []
    model.eval()
    with torch.no_grad():
        for x, y in data_loader:
            outputs = model(x.to(device))
            predicted += outputs.argmax(dim=1).cpu().tolist()
            labels += y.tolist()
    return f1_score(labels, predicted, average='binary')


def train_nn(model, optimizer, loss_function, train_loader, test_loader, device,
             epochs=20, save_path='lstm.pt'):
    """Train for ``epochs`` epochs and checkpoint the best F1 on ``test_loader``.

    Returns the best F1 seen. The first epoch is always checkpointed, so
    ``save_path`` exists after any run with ``epochs >= 1``.
    """
    best_score = 0
    for epoch in range(epochs):
        model.train()
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            predict = model(inputs)
            loss = loss_function(predict, labels)
            loss.backward()
            optimizer.step()
        score = eval_nn(model, test_loader, device)
        print(epoch, 'valid:', score)
        if epoch == 0 or score > best_score:
            torch.save(model.state_dict(), save_path)
            best_score = score
    return best_score


# --- fit NN ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = LSTM_classifier(hidden_dim=256, vocab_size=tokenizer.get_vocab_size(),
                        embedding_dim=300, linear_dim=128, dropout=0.1)
model.apply(init_weights)
model.to(device)
optimizer = optim.AdamW(model.parameters())
loss_function = nn.CrossEntropyLoss().to(device)
train_nn(model, optimizer, loss_function, train_loader, valid_loader, device, epochs=20)

# Score the checkpoint with the best validation F1, not the last epoch.
model.load_state_dict(torch.load('lstm.pt', map_location=device))
print('test:', eval_nn(model, test_loader, device))
Учим BERT
Обучим модель-учитель. В качестве учителя я выбрал героя вышеупомянутой статьи по дистилляции — rubert-tiny от @cointegrated. Получаем F1 чуть больше 0.91. Я особо не экспериментировал с обучением; думаю, можно было получить метрику и получше, особенно с большим BERT, но и так достаточно показательно. Как обучить BERT на бинарную классификацию, можно посмотреть в моей прошлой статье или прямо тут:
Как обучить BERT
# Fine-tune a BERT checkpoint (the teacher) for binary text classification.
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import precision_recall_fscore_support
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertForSequenceClassification, BertTokenizer
from transformers import get_linear_schedule_with_warmup


class BertDataset(Dataset):
    """Tokenizes texts on the fly into fixed-length BERT inputs plus targets."""

    def __init__(self, texts, targets, tokenizer, max_len=512):
        self.texts = texts
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        target = self.targets[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True,
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(target, dtype=torch.long),
        }


class BertClassifier:
    """Wraps BertForSequenceClassification with a training loop and single-text predict."""

    def __init__(self, path, n_classes=2):
        self.path = path
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.tokenizer = BertTokenizer.from_pretrained(path)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = 512
        # Replace the classification head with a fresh one sized for n_classes.
        self.out_features = self.model.bert.encoder.layer[1].output.dense.out_features
        self.model.classifier = torch.nn.Linear(self.out_features, n_classes)
        self.model.to(self.device)

    def preparation(self, X_train, y_train, epochs):
        """Build the training loader, optimizer, LR scheduler and loss."""
        self.train_set = BertDataset(X_train, y_train, self.tokenizer, max_len=self.max_len)
        self.train_loader = DataLoader(self.train_set, batch_size=2, shuffle=True)
        # torch.optim.AdamW replaces the deprecated transformers.AdamW (no longer
        # importable in newer releases). eps=1e-6 keeps the transformers default;
        # torch always applies bias correction (== correct_bias=True).
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=2e-5, weight_decay=0.005, eps=1e-6
        )
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=500,
            num_training_steps=len(self.train_loader) * epochs,
        )
        self.loss_fn = torch.nn.CrossEntropyLoss().to(self.device)

    def fit(self):
        """Run one training epoch; return (train accuracy, mean train loss)."""
        self.model = self.model.train()
        losses = []
        correct_predictions = 0

        for data in tqdm(self.train_loader):
            input_ids = data["input_ids"].to(self.device)
            attention_mask = data["attention_mask"].to(self.device)
            targets = data["targets"].to(self.device)

            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)
            loss = self.loss_fn(outputs.logits, targets)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        train_acc = correct_predictions.double() / len(self.train_set)
        train_loss = np.mean(losses)
        return train_acc, train_loss

    def _report(self, name, X, y):
        """Print macro precision / recall / F1 of ``predict`` on (X, y) under heading *name*."""
        predictions = [self.predict(x) for x in X]
        precision, recall, f1score = precision_recall_fscore_support(
            y, predictions, average='macro'
        )[:3]
        print(f'{name}:')
        print(f'precision: {precision}, recall: {recall}, f1score: {f1score}')

    def train(self, X_train, y_train, X_valid, y_valid, X_test, y_test, epochs=1):
        """Train for ``epochs`` epochs, reporting validation metrics each epoch and test at the end."""
        print('*' * 10)
        print(f'Model: {self.path}')
        self.preparation(X_train, y_train, epochs)
        for epoch in range(epochs):
            print(f'Epoch {epoch + 1}/{epochs}')
            train_acc, train_loss = self.fit()
            print(f'Train loss {train_loss} accuracy {train_acc}')
            self._report('Valid', X_valid, y_valid)

        self._report('Test', X_test, y_test)
        print('*' * 10)

    def predict(self, text):
        """Return the predicted class id for a single text."""
        self.model = self.model.eval()
        encoding = self.tokenizer.encode_plus(
            str(text),  # same NaN-safe conversion as BertDataset
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' already yields a batch of one: (1, max_len).
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        # Inference only: no autograd graph, far less memory.
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return torch.argmax(outputs.logits, dim=1).cpu().numpy()[0]


train = pd.read_csv('train.csv')
valid = pd.read_csv('valid.csv')
test = pd.read_csv('test.csv')

classifier = BertClassifier(path='cointegrated/rubert-tiny', n_classes=2)

classifier.train(
    X_train=list(train['text']),
    y_train=list(train['label']),
    X_valid=list(valid['text']),
    y_valid=list(valid['label']),
    X_test=list(test['text']),
    y_test=list(test['label']),
    epochs=1,
)

path = './trainer'
classifier.model.save_pretrained(path)
classifier.tokenizer.save_pretrained(path)
Дистилляция BERT > BiLSTM
Основная идея — приблизить выход BERT-учителя выходом BiLSTM-ученика. Для этого при обучении используем функцию ошибки MSE между логитами ученика и учителя; её можно сочетать с обучением на истинных метках через CrossEntropyLoss. Подробнее можно почитать в статье по ссылке. На моих тестовых данных дистилляция добавила всего пару процентов: F1 чуть больше 0.82.
Код дистилляции
# Distill a fine-tuned BERT teacher into a BiLSTM student (MSE on logits + CE on labels).
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_fscore_support
from tokenizers import ByteLevelBPETokenizer
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import Whitespace
from torch.utils.data import DataLoader, Dataset
from transformers import BertForSequenceClassification, BertTokenizer
from transformers import get_linear_schedule_with_warmup

# Fine-tuned teacher written by the BERT training step
# (classifier.model.save_pretrained('./trainer')). The base './rubert-tiny'
# checkpoint has no trained classification head, so its logits would be noise.
TEACHER_PATH = './trainer'

### data
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')

### student tokenizer: byte-level BPE trained on the training texts
tokenizer = ByteLevelBPETokenizer()
tokenizer.pre_tokenizer = Whitespace()
tokenizer.enable_padding(pad_id=0, pad_token='<pad>')

texts_path = 'texts.txt'
with open(texts_path, 'w', encoding='utf-8') as f:
    for text in train['text'].values:
        f.write("%s\n" % text)

tokenizer.train(
    files=[texts_path],
    vocab_size=5_000,
    min_frequency=2,
    special_tokens=['<pad>', '<unk>'],
)

### teacher tokenizer (saved alongside the fine-tuned teacher)
tokenizer_bert = BertTokenizer.from_pretrained(TEACHER_PATH)


### dataset
class CustomDataset(Dataset):
    """Token-id sequences zero-padded / truncated to ``max_len``, paired with labels."""

    def __init__(self, tokens, labels, max_len):
        self.tokens = tokens
        self.labels = labels
        self.max_len = max_len

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        label = torch.tensor(self.labels[idx])
        tokens = self.tokens[idx][:self.max_len]
        out = torch.zeros(self.max_len, dtype=torch.long)  # 0 is the '<pad>' id
        out[:len(tokens)] = torch.tensor(tokens, dtype=torch.long)
        return out, label


max_len = 64
BATCH_SIZE = 16

train_labels = list(train['label'])
train_tokens = [tokenizer.encode(str(text)).ids for text in train['text']]
train_dataset = CustomDataset(train_tokens, train_labels, max_len)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)

test_labels = list(test['label'])
test_tokens = [tokenizer.encode(str(text)).ids for text in test['text']]
test_dataset = CustomDataset(test_tokens, test_labels, max_len)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


### BiLSTM student (also used for the label-only baseline)
class CustomLSTM(nn.Module):
    """BiLSTM over token embeddings followed by an MLP head; returns raw class logits."""

    def __init__(self, hidden_dim=128, vocab_size=5000, embedding_dim=300,
                 linear_dim=128, dropout=0.3, n_classes=2):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.lstm_layer = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc_layer = nn.Linear(hidden_dim * 2, linear_dim)
        self.batchnorm = nn.BatchNorm1d(linear_dim)
        self.relu = nn.ReLU()
        self.out_layer = nn.Linear(linear_dim, n_classes)

    def forward(self, inputs):
        batch_size = inputs.size(0)
        embeddings = self.embedding_layer(inputs)
        _, (ht, _) = self.lstm_layer(embeddings)
        # ht: (2 directions, batch, hidden) -> concatenate both final states per sample.
        out = ht.transpose(0, 1).reshape(batch_size, -1)
        out = self.fc_layer(out)
        out = self.batchnorm(out)
        out = self.relu(out)
        out = self.dropout_layer(out)
        # Raw logits: required both by CrossEntropyLoss and by the MSE match
        # against the teacher's logits.
        return self.out_layer(out)


# The baseline shares the student architecture. A sigmoid-output variant would
# both cripple CrossEntropyLoss and make MSE-to-teacher-logits meaningless.
LSTM_classifier = CustomLSTM


def init_weights(m):
    """Xavier-uniform weights and a 0.01 bias for every nn.Linear layer."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


def eval_nn(model, data_loader, device=None):
    """Return (labels, predictions, binary F1) of the model's argmax over ``data_loader``."""
    if device is None:
        device = next(model.parameters()).device
    predicted, labels = [], []
    model.eval()
    with torch.no_grad():
        for x, y in data_loader:
            outputs = model(x.to(device))
            predicted += outputs.argmax(dim=1).cpu().tolist()
            labels += y.tolist()
    score = f1_score(labels, predicted, average='binary')
    return labels, predicted, score


def train_nn(model, optimizer, loss_function, train_loader, test_loader, device, epochs=20):
    """Label-only training; checkpoints the best test F1 to 'lstm.pt' and returns it."""
    best_score = 0
    for epoch in range(epochs):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            predict = model(inputs)
            loss = loss_function(predict, labels)
            loss.backward()
            optimizer.step()
        score = eval_nn(model, test_loader, device)[2]
        print(epoch, 'valid:', score)
        if score > best_score:
            torch.save(model.state_dict(), 'lstm.pt')
            best_score = score
    return best_score


vocab_size = tokenizer.get_vocab_size()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### baseline: BiLSTM trained on labels only
model = LSTM_classifier(hidden_dim=256, vocab_size=vocab_size, embedding_dim=300,
                        linear_dim=128, dropout=0.1)
model.apply(init_weights)
model.to(device)
optimizer = optim.AdamW(model.parameters())
baseline_loss = nn.CrossEntropyLoss().to(device)
train_nn(model, optimizer, baseline_loss, train_loader, test_loader, device, epochs=3)


class DistillDataset(Dataset):
    """Yields (student token ids, teacher BERT inputs, label) for each text."""

    def __init__(self, texts, labels, tokenizer_bert, tokenizer_lstm, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer_bert = tokenizer_bert
        self.tokenizer_lstm = tokenizer_lstm
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = torch.tensor(self.labels[idx])
        # student (lstm) input
        tokens_lstm = self.tokenizer_lstm.encode(text).ids[:self.max_len]
        out_lstm = torch.zeros(self.max_len, dtype=torch.long)
        out_lstm[:len(tokens_lstm)] = torch.tensor(tokens_lstm, dtype=torch.long)
        # teacher (bert) input
        encoding = self.tokenizer_bert.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        out_bert = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
        }
        return out_lstm, out_bert, label


train_dataset_distill = DistillDataset(
    list(train['text']), list(train['label']), tokenizer_bert, tokenizer, max_len
)
# drop_last: a size-1 final batch would make BatchNorm1d fail in training mode.
train_loader_distill = DataLoader(train_dataset_distill, batch_size=BATCH_SIZE,
                                  shuffle=True, drop_last=True)


### BERT teacher
class BertTrainer:
    """Frozen (eval-mode) BERT classifier whose logits serve as distillation targets."""

    def __init__(self, path_model, n_classes=2):
        self.model = BertForSequenceClassification.from_pretrained(path_model, num_labels=n_classes)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = 512
        self.model.to(self.device)
        self.model = self.model.eval()

    def predict(self, inputs):
        """Return teacher logits for a batch dict with 'input_ids' and 'attention_mask'."""
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits


teacher = BertTrainer(TEACHER_PATH)


def loss_function(output, teacher_prob, real_label, a=0.5):
    """``a`` * CE(student logits, labels) + (1 - ``a``) * MSE(student logits, teacher logits)."""
    criterion_mse = torch.nn.MSELoss()
    criterion_ce = torch.nn.CrossEntropyLoss()
    return a * criterion_ce(output, real_label) + (1 - a) * criterion_mse(output, teacher_prob)


def train_distill(model, teacher, optimizer, loss_function, distill_loader, train_loader,
                  test_loader, device, epochs=30, alpha=0.5, results_dir='./results'):
    """Distill ``teacher`` into ``model``; return (best model, best test F1, per-epoch scores).

    Each improvement on test F1 is checkpointed to ``results_dir``. On return,
    ``model`` holds the weights of the best epoch.
    """
    os.makedirs(results_dir, exist_ok=True)
    best_score = 0
    best_state = None
    score_list = []
    for epoch in range(epochs):
        model.train()
        for inputs, inputs_teacher, labels in distill_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            predict = model(inputs)
            teacher_predict = teacher.predict(inputs_teacher).to(device)
            loss = loss_function(predict, teacher_predict, labels, alpha)
            loss.backward()
            optimizer.step()
        score_train = round(eval_nn(model, train_loader, device)[2], 3)
        score_test = round(eval_nn(model, test_loader, device)[2], 3)
        score_list.append((score_train, score_test))
        print(epoch, score_train, score_test)
        if best_state is None or score_test > best_score:
            best_score = score_test
            # Snapshot a copy: holding a reference to `model` would only track the latest epoch.
            best_state = copy.deepcopy(model.state_dict())
            torch.save(best_state, os.path.join(results_dir, f'lstm_{best_score}.pt'))
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_score, score_list


### sweep the CE / MSE mixing weight
alphas = [0, 0.25, 0.5, 0.75, 1]
score_alpha = []
for alpha in alphas:
    model = CustomLSTM(hidden_dim=256, vocab_size=vocab_size, embedding_dim=300,
                       linear_dim=128, dropout=0.1)
    model.apply(init_weights)
    model.to(device)
    optimizer = optim.AdamW(model.parameters())
    _, _, score_list = train_distill(model, teacher, optimizer, loss_function,
                                     train_loader_distill, train_loader, test_loader,
                                     device, 30, alpha)
    score_alpha.append(score_list)

# Label each curve with the alpha it was actually trained with.
for alpha, score in zip(alphas, score_alpha):
    _, score_test = zip(*score)
    plt.plot(score_test, label=f'{alpha}')
plt.grid(True)
plt.legend()
plt.show()
Дистилляция BERT > tinyBERT
Основная идея, как и в прошлом пункте, — приближать учеником поведение учителя. Есть много вариантов того, что и как приближать; я взял всего два:
-
Приближать [CLS]-токен по MSE.
-
Дистилляция распределения токенов по дивергенции Кульбака-Лейблера.
Дополнительно в процессе обучения решаем задачу MLM — предсказания замаскированных токенов. Размер модели уменьшается за счёт сокращения словаря и уменьшения количества голов внимания, а также количества и размерности скрытых слоёв.
Обучение итогового классификатора делится на два этапа:
-
Обучение языковой модели.
-
Обучение головы для классификации.
Я применял дистилляцию только на первом этапе, а голову для классификации учил уже непосредственно на дистиллированной модели. Думаю, можно было добавить и вариант с MSE, как в примере с BiLSTM, но эти эксперименты я оставил на потом.
Ключевые моменты реализации:
Сокращение словаря:
# Shrink the teacher vocabulary to tokens that actually occur in the training texts.
import os
from collections import Counter

import pandas as pd
from tqdm.auto import tqdm, trange
from transformers import BertTokenizerFast, BertForPreTraining, BertModel, BertConfig

MIN_TOKEN_COUNT = 5  # keep tokens seen strictly more than this many times
DISTILL_DIR = './bert_distill'

train = pd.read_csv('train.csv')
X_train = list(train['text'])
tokenizer = BertTokenizerFast.from_pretrained('./rubert-tiny')

cnt = Counter()
for text in tqdm(X_train):
    cnt.update(tokenizer(str(text))['input_ids'])

# special_tokens_map values may be a single token or a list of tokens
# (e.g. additional_special_tokens); flatten before looking up ids.
special_tokens = []
for value in tokenizer.special_tokens_map.values():
    special_tokens.extend(value if isinstance(value, (list, tuple)) else [value])

resulting_vocab = {tokenizer.vocab[token] for token in special_tokens}
resulting_vocab.update(idx for idx, count in cnt.items() if count > MIN_TOKEN_COUNT)
# Sorted teacher ids: position i in this list is the student's token id i.
resulting_vocab = sorted(resulting_vocab)

tokenizer.save_pretrained(DISTILL_DIR)
inv_voc = {idx: word for word, idx in tokenizer.vocab.items()}
with open(os.path.join(DISTILL_DIR, 'vocab.txt'), 'w', encoding='utf-8') as f:
    for idx in resulting_vocab:
        f.write(inv_voc[idx] + '\n')

# save_pretrained on a fast tokenizer also writes tokenizer.json, which embeds the
# full teacher vocabulary. BertTokenizerFast.from_pretrained prefers that file over
# vocab.txt, so it must go, or the reduced vocabulary is silently ignored.
tokenizer_json = os.path.join(DISTILL_DIR, 'tokenizer.json')
if os.path.exists(tokenizer_json):
    os.remove(tokenizer_json)
Инициализация весов
# Build a small BERT student and initialize its embeddings from the teacher.
# Relies on `tokenizer_distill` (reduced-vocab tokenizer) and `resulting_vocab`
# (sorted teacher ids kept in the student vocab) defined by earlier steps.
HIDDEN = 256  # student hidden size; teacher vectors are truncated to this many dims

config = BertConfig(
    hidden_size=HIDDEN,  # BERT embeddings always have hidden_size dims
    intermediate_size=256,
    max_position_embeddings=512,
    num_attention_heads=8,
    num_hidden_layers=3,
    vocab_size=tokenizer_distill.vocab_size,
)
model = BertForPreTraining(config)

from transformers import BertModel

# Teacher with its pretraining heads (MLM + NSP); the MLM decoder is needed below.
teacher = BertForPreTraining.from_pretrained('./rubert-tiny')
tokenizer_teacher = BertTokenizerFast.from_pretrained('./rubert-tiny')

# Input embeddings: rows of the kept vocabulary, first HIDDEN dims.
model.bert.embeddings.word_embeddings.weight.data = (
    teacher.bert.embeddings.word_embeddings.weight.data[resulting_vocab, :HIDDEN].clone()
)
# Position embeddings: only as many positions as the student config declares,
# whatever the teacher's own count is.
model.bert.embeddings.position_embeddings.weight.data = (
    teacher.bert.embeddings.position_embeddings.weight.data[:config.max_position_embeddings, :HIDDEN].clone()
)
# Output embeddings (MLM decoder) and its per-token bias, remapped the same way.
model.cls.predictions.decoder.weight.data = (
    teacher.cls.predictions.decoder.weight.data[resulting_vocab, :HIDDEN].clone()
)
model.cls.predictions.bias.data = teacher.cls.predictions.bias.data[resulting_vocab].clone()

# Save only after copying, so the checkpoint on disk holds the teacher-initialized weights.
model.save_pretrained('./bert_distill')
MLM-loss
# MLM loss for the student: tokenize the batch, apply masking via the data
# collator, and score the vocabulary logits against the masked-token labels.
# `texts`, `preprocess_inputs`, `data_collator` and `loss` come from the training loop.
inputs = tokenizer_distill(texts, return_tensors='pt', padding=True, truncation=True, max_length=16)
inputs = preprocess_inputs(inputs, tokenizer_distill, data_collator)
outputs = model(**inputs, output_hidden_states=True)

# nn.CrossEntropyLoss is a module: instantiate it, then call it on
# (logits, targets). Passing the tensors to the constructor computes nothing.
# Its default ignore_index=-100 skips positions labelled -100
# (NOTE(review): presumably how the collator marks unmasked tokens — confirm).
mlm_criterion = nn.CrossEntropyLoss()
loss += mlm_criterion(
    outputs.prediction_logits.view(-1, model.config.vocab_size),
    inputs['labels'].view(-1),
)
KL-loss
def loss_kl(inputs, outputs, model, teacher, vocab_mapping, temperature=1.0):
    """Token-level KL distillation loss between teacher and student MLM distributions.

    Args:
        inputs: student tokenizer output with 'input_ids' and 'attention_mask'
            (optionally 'token_type_ids'); ids are in the reduced student vocab.
        outputs: student forward output exposing ``prediction_logits``
            of shape (batch, seq, student_vocab).
        model: unused; kept for call-site compatibility.
        teacher: teacher model, called with teacher-vocab input ids; its output
            must expose ``prediction_logits`` of shape (batch, seq, teacher_vocab).
        vocab_mapping: sequence where ``vocab_mapping[student_id] == teacher_id``.
        temperature: softmax temperature applied to both distributions.

    Returns:
        Scalar tensor: KL(teacher || student) over the vocabulary, averaged over
        non-padding positions after the first token, scaled by ``temperature ** 2``.
    """
    import torch.nn.functional as F

    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    mapping = torch.as_tensor(vocab_mapping, dtype=torch.long, device=input_ids.device)
    # Translate student-vocab ids into teacher-vocab ids with one gather.
    teacher_ids = mapping[input_ids]

    with torch.no_grad():
        teacher_out = teacher(
            input_ids=teacher_ids,
            token_type_ids=inputs.get('token_type_ids'),
            attention_mask=attention_mask,
        )

    # Skip position 0 ([CLS]); compare every later position.
    student_logits = outputs.prediction_logits[:, 1:, :] / temperature
    teacher_logits = teacher_out.prediction_logits[:, 1:, :]
    # Restrict the teacher's vocabulary to the tokens the student kept.
    teacher_logits = teacher_logits.index_select(-1, mapping.to(teacher_logits.device)) / temperature

    # Normalize over the vocabulary (last) axis, not the sequence axis.
    per_token_kl = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.log_softmax(teacher_logits, dim=-1).to(student_logits.device),
        reduction='none',
        log_target=True,
    ).sum(dim=-1)

    # Padding positions carry no signal; average over real tokens only.
    mask = attention_mask[:, 1:].to(per_token_kl.device, per_token_kl.dtype)
    kd_loss = (per_token_kl * mask).sum() / mask.sum().clamp(min=1.0)
    return kd_loss * temperature ** 2
MSE-loss
# MSE between L2-normalized sentence embeddings of teacher and student.
# `texts`, `teacher_mse`, `tokenizer_distill`, `adapter_emb` and `model` come from the training loop.

# Teacher embedding: normalized pooler output, computed without gradients.
input_teacher = dict(tokenizer_teacher(
    texts, return_tensors='pt', padding=True, max_length=16, truncation=True
))
with torch.no_grad():
    out_teacher = teacher_mse(**input_teacher)
embeddings_teacher_norm = torch.nn.functional.normalize(out_teacher.pooler_output)

# Student embedding: its pooler applied to the last hidden state, passed through
# adapter_emb (NOTE(review): presumably projects to the teacher's hidden size —
# confirm), then normalized.
input_distill = dict(tokenizer_distill(
    texts, return_tensors='pt', padding=True, max_length=16, truncation=True
))
out = model(**input_distill, output_hidden_states=True)
embeddings = model.bert.pooler(out.hidden_states[-1])
embeddings_norm = torch.nn.functional.normalize(adapter_emb(embeddings))

# MSELoss is a module: instantiate it, then call it on (input, target).
# Passing the tensors to the constructor computes no loss.
loss = torch.nn.MSELoss()(embeddings_norm, embeddings_teacher_norm)
Размер итоговой модели составил 16 МБ, метрика F1 — 0.86. Модель я учил 12 часов на MacBook Air 2019 года с i5 и 8 ГБ оперативной памяти. Думаю, если погонять подольше, результат будет получше.
Код и данные для обучения выложены на GitHub; замечания, дополнения и исправления приветствуются.
Ссылка на оригинал статьи: https://habr.com/ru/post/692570/
Добавить комментарий