Logit Lens & ViT model: туториал

от автора

Привет, Хабр!

В этом туториале разобран метод для анализа внутренних представлений «логит-линза» (Logit Lens).

В результате практики по туториалу, вы:

  1. Изучите подход и концепцию Logit Lens;

  2. Реализуете Logit Lens для Visual Transformer;

  3. Познакомитесь с анализом результатов применения логит-линзы.

Приступим! Как всегда, весь код будет на гитхаб — step by step.

Logit Lens: о методе

Метод Logit Lens был предложен на Lessworng в 2020 году на примере модели GPT-2.

Сама по себе линза является достаточно сложным методов в реализации — под каждую модель необходимо проектировать идею её построения заново, основываясь на анализе архитектуры. В то же время, это дает методу преимущество:

Метод является model-specific (то есть точечно концентрируется на одной модели), что позволяет изучать характеры поведения скрытых представлений в разных моделях.

Как они соотносятся? Быть может, именно вы сможете найти что-то интересное и я надеюсь, этот туториал вас вдохновит!

Теория

Прежде чем начать, рассмотрим базу метода. Он основан на логитах и скрытых состояниях — важно знать, что это. Если уже знаете, можно листать эти разделы.

Логиты

Вспомним задачу бинарной классификации при помощи логистической регрессии модели. Кроме вероятностной постановки, можно посмотреть так:

Хотим: прогнозировать вероятность конкретного класса (отрезок $[0,1]$). \

Проблема: Линейная модель не ограничена.

Для решения используется использовать логит-преобразование — сначала мы учим модель с диапазоном значений (−∞,∞), и потом преобразовываем ответы в вероятность, при помощи сигмоиды.

В этом случае говорят, что мы пронозируем логиты — числа, равные  z = \ln(\frac{p}{1-p}), где p — искомая вероятность. А потом преобразуем их при помощи сигмоидальной функции:

p(y=1 | x) = \frac{1}{1+e^{-z}}

Аналогичная идея применяется и в задаче многоклассовой классификации. Там в роли преобразования выступает softmax функция. Формально, если z_i— логиты для всех токенов i, то мы вычисляем вектор вероятностей p, где на месте i-й координаты, вероятность вычисляется как:

p(y_i \mid x) = \frac{e^{z_i}}{\sum_j e^{z_j}}

Таким образом, обобщая, логиты можно определить так:

Логиты — это входные значения в функцию softmax/sigmoid перед преобразованием в вероятности. Для каждого токена они представляют собой сырые «пре-«вероятности (logits), используемые для выбора следующего слова в последовательности.

Это несколько расходится с формулировкой из статистики, но в данном случае необходимо для понимания задачи.

Скрытое состояние

Понятие скрытого состояния связано со скрытыми слоями в модели (ваш кэп). Вообще, термин «скрытый» слой значит следующее:

Скрытым слоем модели Net(x) назовем любой слой архитектуры, имеющий выход, кроме первого (входного) и последнего (выходного) слоев.

Из этого определения, еслм Net(x) — глубокая нейронная сеть произвольной архитектуры с количеством слоев от 1 до N, то скрытое состояние модели, есть представление h^{(l)} — выход, после l-го слоя, где l \neq N, l \neq 1.

Logit Lens

Вооружившись определениями, теперь мы можем рассмотреть идею и построение логит-линз.

Любая большая модель может быть рассмотрена как последовательность скрытых состояний h^{(l)}(x_{l-1}), где x_{l-1} выход предыдущего слоя.

Расписывая детально, получим что для модели с Net(x):

input \to Net(input) \to output

при условии N слоев, справедливо абстрактное представление:

input \to h^{(0)} \to h^{(1)} \to h^{(2)} \to ... \to h^{(k)} \to .... \to \to h^{(N-1)} \to output

Основываясь на этом, если преобразования h^{(l)} согласованы друг с другом (то есть имеют один и тот же размер), то мы можем исследовать, как модель формирует предсказания на разных слоях проецируя представление скрытого слоя в вектор выходных значений.

В этом случае, логит-линзой скрытого слоя h^{(l)} называется проекция:

h^{(l')} = W_{\text{out}} h_l

где W_{\text{out}} — веса выходного слоя.

Метод подходит не для всех моделей. Главное условие применимости — описанная согласованность архитектуры.

Ура, на теорию посмотрели — перейдем к практике!

Практика.

В ноутбуке, который вы можете найти здесь, реализована загрузка набора данных и модели. Посмотрим на то, с чем будем работать:

  1. Картинки

# Display the files num_columns = 4 num_rows = 2  fig, axes = plt.subplots(num_rows, num_columns, figsize=(20, 8))  for i, ax in enumerate(axes.flatten()):     ax.axis('off')     if i < len(images_list):         ax.imshow(images_list[i])         ax.set_title(f"Image {image_files[i]}")  plt.show()
Да, у нас снова практика с красивыми животными!

Да, у нас снова практика с красивыми животными!

2. Модель

# import the image processor image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")  # import the model model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") model.eval();

Анализ модели

Очень важный шаг при анализе скрытых представлений — анализ модели. И мы не опустим его здесь. Чтобы понять, откуда будут браться логиты, рассмотрим архитектуру модели и информацию о выходах скрытых слоев. Для этого нам понадобится какой-то входной пример и функция summary из библиотеки torchinfo.

input0 = image_processor(images_list[0], return_tensors="pt",) summary(model, input_data=input0['pixel_values'], depth=6)
Первые слои модели (дальше повтор слоёв энкодера)

Первые слои модели (дальше повтор слоёв энкодера)
Последние слои модели

Последние слои модели

Заметим, что энкодер и эмбеддинг слой уточняют признаки одинаковых размеров. Но эмбеддинги (ViTEmbeddings) и энкодер (ViTEncoder) выполняют разные функции:

  1. Эмбеддинги (ViTEmbeddings)
    Отвечают за преобразование входного изображения в последовательность токенов (векторов признаков);

  2. Энкодер (ViTEncoder)
    Отвечает за обработку последовательности эмбеддингов с помощью, извлекая из них сложные пространственные и контекстные зависимости.

Исходя из этого будем рассматривать для проекции в линзу выходные слои энкодера (ViTEncoder).

Извлечение скрытых представлений

В библиотеке transformers возвращение скрытых состояний можно получить при помощи простого атрибута при инференсе.

Синтаксис примерно такой:

model = Net()
outputs = model(**inputs, output_hidden_states=True)

Извлечем их из модели.

outputs = model(**input0, output_hidden_states=True)  # Extracting the predicted class. predicted_class_idx = torch.sigmoid(outputs.logits).argmax(-1).item() labels = model.config.id2label  # indexes and lanels dict  print(f"Predicted: {labels[predicted_class_idx]}") print(f'Количество извлеченных скрытых состояний: {len(outputs.hidden_states)}')  # OUT #Predicted: brown bear, bruin, Ursus arctos #Количество извлеченных скрытых состояний: 13

И так, мы получили верный прогноз — модель видит медведя и 13 скрытых состояний. Здесь важно:

  1. Первое (‘outputs.hidden_states[0]’) — выход части эмбеддинга model.vit.embeddings

  2. Состояния 2-13 — выходы частей энкодера — все до применения к ним самого последнего слоя компоненты model.vitlayernorm

Нулевое отбросим. Нас будут интересновать с 1-го по 13е. И посмотрим, как извлекать проекции.

Посмотрим, как ивзлекать проекции. С одной стороны, можно создавать дочернюю модель с меньшим числом слоёв энкодера. Такой подход описан в статье здесь (там тоже есть код).

Мы же пойдем иначе, скрытые состояния мы уже извлекли. Теперь, по опредлению, нам нужно построить:

h^{(l')} = W_{\text{out}} h_l

То есть, пропустить каждое скрытое состояние через классификационные веса. Они находятся в модуле модели classifier. Извлечем его

classification_layer = model.classifier

И сейчас мы готовы построить линзы. Однако, если посмотреть на размерность скрытых состояний, каждое скрытое состояние имеет размерность [1, 197, 768], применяя линейный слой получится вектор (1, 197) — один класс на каждый патч (если вы не знаете термин пачта — это, грубо говоря, компонента, на которую модель разбивает изображение для извлечения признаков. Подробно тут).

С чем это связано?

В модели ViT для классификации используется CLS токен. Он соответствует 0-му из патчей. Мы можем взять либо его, либо медиану/среднее по патчам.

Реализуем оба способа и посмотрим на результат.

cls_hidden = hidden_states[-1][:, 0, :] # extracting CLS token median_hidden = torch.median(hidden_states[-1], dim=1)[0] # extracting median  # extracting classification result (logits) classification_result_cls = classification_layer(cls_hidden) classification_result_median = classification_layer(median_hidden)  # extracting classification result (probas) predicted_class_cls = torch.softmax(classification_result_cls, dim=-1) predicted_class_median = torch.softmax(classification_result_cls, dim=-1)  print(f'Predicted class on CLS token: {predicted_class_cls.argmax()}') # 294 print(f'Predicted class on median: {predicted_class_median.argmax()}')  # 294

Для последнего скрытого состояния результаты одинаковы. И на этом моменте у нас есть всё, чтобы пробежать по всем скрытым состояниям — от понимания теории, до понимания практических тонкостей! Напишем для этого функцию.

def get_hidden_state_result(hidden_state, true_class_idx, mode='CLS'):    if mode == 'CLS':     hidden_state = hidden_state[:, 0, :] # Extract the hidden state corresponding to the [CLS] token (first token in the sequence)   if mode == 'median':  # Take the median across the patches (dim=1 refers to the second dimension which is the patches)     hidden_state = torch.median(hidden_state, dim=1)[0]    classification_layer_result = classification_layer(hidden_state)  # Apply classification layer   predicted_class_probas = torch.softmax(classification_layer_result, dim=-1) # Apply softmax to get class probabilities  # Get the predicted class index and probability   predicted_class_idx = predicted_class_probas.argmax().item()   predicted_class_proba = predicted_class_probas.max().item()  # Get the rank of the true class  #Sort the predicted probabilities in descending order and find the rank of the true class # by looking for the true class index in the sorted tensor.    _, indexes_sorted = torch.sort(predicted_class_probas, descending = True)   final_label_rank = torch.where(indexes_sorted == true_class_idx)[-1].item()    return predicted_class_proba, predicted_class_idx, final_label_rank

Функция возвращает вероятность прогнозированного класса на представлении h^{(l)}, индекс этого класса и ранг — порядковый номер в отсортированном векторе вероятностей. Пример запуска функции на медведе:

true_class_idx = 294  bear_image_probas = [] bear_image_predictions = [] bear_image_true_ranks = []  for hidden_state in hidden_states:   predicted_class_proba, predicted_class_idx, final_label_rank = get_hidden_state_result(hidden_state, true_class_idx, mode='CLS')    bear_image_probas.append(predicted_class_proba)   bear_image_predictions.append(predicted_class_idx)   bear_image_true_ranks.append(final_label_rank)     bear_table = [bear_image_probas, bear_image_predictions, bear_image_true_ranks]   

Результат будет выглядеть так.

fig, ax = plt.subplots(figsize=(20, 6))  sns.heatmap(bear_table, annot=True, ax=ax, cbar=False, fmt='g')  ax.set_xticklabels([f'layer {i+1}' for i in range(12)]); ax.set_yticklabels(['Predicted class probas', 'Predicted index', 'True label rank']); ax.set_title('Results for bear image');

Так как на одном изображении не совсем корректно делать выводы, то я построила для всех (это есть в полной тетради). Получается, что:

  • На некоторых примерах прогноз может стабилизироваться на ранних слоях (4-5 для horses, bird).

  • Истинный (по прогнозу) класс оказываются в топ-3, на последних слоях, но не для всех изображений это верно (контрпример — sheep)

  • Уверенность (выраженная вероятностью) в пронозе резко возрастает на двух последних слоях (третий с конца может быть как стабильным по уверенности, так и нет (например, cow c 0.15 и bird с 0.9)).

На этом нехитрый анализ методом Logit Lens завершается. Надеюсь, мне удалось познакомить вас с методом и донести его идею. Благодарю вас за время, уделенное туториалу! Надеюсь, также он вас на что-то вдохновил 😌

Полный ноутбук здесь (на русском и английском). В репозитория вы также можете найти другие туториалы (почти по всем есть статьи на Хабре от меня же).

Хороших и красивых проектов!
Ваш Дата-автор!


ссылка на оригинал статьи https://habr.com/ru/articles/891352/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *