Привет, Хабр!
В этом туториале разобран метод для анализа внутренних представлений «логит-линза» (Logit Lens).
В результате практики по туториалу, вы:
-
Изучите подход и концепцию Logit Lens;
-
Реализуете Logit Lens для Visual Transformer;
-
Познакомитесь с анализом результатов применения логит-линзы.
Приступим! Как всегда, весь код будет на гитхаб — step by step.
Logit Lens: о методе
Метод Logit Lens был предложен на Lessworng в 2020 году на примере модели GPT-2.
Сама по себе линза является достаточно сложным методов в реализации — под каждую модель необходимо проектировать идею её построения заново, основываясь на анализе архитектуры. В то же время, это дает методу преимущество:
— Метод является model-specific (то есть точечно концентрируется на одной модели), что позволяет изучать характеры поведения скрытых представлений в разных моделях.
Как они соотносятся? Быть может, именно вы сможете найти что-то интересное и я надеюсь, этот туториал вас вдохновит!
Теория
Прежде чем начать, рассмотрим базу метода. Он основан на логитах и скрытых состояниях — важно знать, что это. Если уже знаете, можно листать эти разделы.
Логиты
Вспомним задачу бинарной классификации при помощи логистической регрессии модели. Кроме вероятностной постановки, можно посмотреть так:
Хотим: прогнозировать вероятность конкретного класса (отрезок $[0,1]$). \
Проблема: Линейная модель не ограничена.
Для решения используется использовать логит-преобразование — сначала мы учим модель с диапазоном значений , и потом преобразовываем ответы в вероятность, при помощи сигмоиды.
В этом случае говорят, что мы пронозируем логиты — числа, равные , где
— искомая вероятность. А потом преобразуем их при помощи сигмоидальной функции:
Аналогичная идея применяется и в задаче многоклассовой классификации. Там в роли преобразования выступает softmax функция. Формально, если — логиты для всех токенов
, то мы вычисляем вектор вероятностей
, где на месте
-й координаты, вероятность вычисляется как:
Таким образом, обобщая, логиты можно определить так:
Логиты — это входные значения в функцию softmax/sigmoid перед преобразованием в вероятности. Для каждого токена они представляют собой сырые «пре-«вероятности (logits), используемые для выбора следующего слова в последовательности.
Это несколько расходится с формулировкой из статистики, но в данном случае необходимо для понимания задачи.
Скрытое состояние
Понятие скрытого состояния связано со скрытыми слоями в модели (ваш кэп). Вообще, термин «скрытый» слой значит следующее:
Скрытым слоем модели назовем любой слой архитектуры, имеющий выход, кроме первого (входного) и последнего (выходного) слоев.
Из этого определения, еслм — глубокая нейронная сеть произвольной архитектуры с количеством слоев от 1 до N, то скрытое состояние модели, есть представление
— выход, после
-го слоя, где
.
Logit Lens
Вооружившись определениями, теперь мы можем рассмотреть идею и построение логит-линз.
Любая большая модель может быть рассмотрена как последовательность скрытых состояний , где
выход предыдущего слоя.
Расписывая детально, получим что для модели с :
при условии слоев, справедливо абстрактное представление:
Основываясь на этом, если преобразования согласованы друг с другом (то есть имеют один и тот же размер), то мы можем исследовать, как модель формирует предсказания на разных слоях проецируя представление скрытого слоя в вектор выходных значений.
В этом случае, логит-линзой скрытого слоя называется проекция:
где — веса выходного слоя.
Метод подходит не для всех моделей. Главное условие применимости — описанная согласованность архитектуры.
Ура, на теорию посмотрели — перейдем к практике!
Практика.
В ноутбуке, который вы можете найти здесь, реализована загрузка набора данных и модели. Посмотрим на то, с чем будем работать:
-
Картинки
# Display the files num_columns = 4 num_rows = 2 fig, axes = plt.subplots(num_rows, num_columns, figsize=(20, 8)) for i, ax in enumerate(axes.flatten()): ax.axis('off') if i < len(images_list): ax.imshow(images_list[i]) ax.set_title(f"Image {image_files[i]}") plt.show()

2. Модель
# import the image processor image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") # import the model model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") model.eval();
Анализ модели
Очень важный шаг при анализе скрытых представлений — анализ модели. И мы не опустим его здесь. Чтобы понять, откуда будут браться логиты, рассмотрим архитектуру модели и информацию о выходах скрытых слоев. Для этого нам понадобится какой-то входной пример и функция summary
из библиотеки torchinfo
.
input0 = image_processor(images_list[0], return_tensors="pt",) summary(model, input_data=input0['pixel_values'], depth=6)


Заметим, что энкодер и эмбеддинг слой уточняют признаки одинаковых размеров. Но эмбеддинги (ViTEmbeddings) и энкодер (ViTEncoder) выполняют разные функции:
-
Эмбеддинги (ViTEmbeddings)
Отвечают за преобразование входного изображения в последовательность токенов (векторов признаков); -
Энкодер (ViTEncoder)
Отвечает за обработку последовательности эмбеддингов с помощью, извлекая из них сложные пространственные и контекстные зависимости.
Исходя из этого будем рассматривать для проекции в линзу выходные слои энкодера (ViTEncoder).
Извлечение скрытых представлений
В библиотеке transformers возвращение скрытых состояний можно получить при помощи простого атрибута при инференсе.
Синтаксис примерно такой:
model = Net()
outputs = model(**inputs, output_hidden_states=True)
Извлечем их из модели.
outputs = model(**input0, output_hidden_states=True) # Extracting the predicted class. predicted_class_idx = torch.sigmoid(outputs.logits).argmax(-1).item() labels = model.config.id2label # indexes and lanels dict print(f"Predicted: {labels[predicted_class_idx]}") print(f'Количество извлеченных скрытых состояний: {len(outputs.hidden_states)}') # OUT #Predicted: brown bear, bruin, Ursus arctos #Количество извлеченных скрытых состояний: 13
И так, мы получили верный прогноз — модель видит медведя и 13 скрытых состояний. Здесь важно:
-
Первое (‘outputs.hidden_states[0]’) — выход части эмбеддинга
model.vit.embeddings
-
Состояния 2-13 — выходы частей энкодера — все до применения к ним самого последнего слоя компоненты
model.vit
—layernorm
Нулевое отбросим. Нас будут интересновать с 1-го по 13е. И посмотрим, как извлекать проекции.
Посмотрим, как ивзлекать проекции. С одной стороны, можно создавать дочернюю модель с меньшим числом слоёв энкодера. Такой подход описан в статье здесь (там тоже есть код).
Мы же пойдем иначе, скрытые состояния мы уже извлекли. Теперь, по опредлению, нам нужно построить:
То есть, пропустить каждое скрытое состояние через классификационные веса. Они находятся в модуле модели classifier
. Извлечем его
classification_layer = model.classifier
И сейчас мы готовы построить линзы. Однако, если посмотреть на размерность скрытых состояний, каждое скрытое состояние имеет размерность [1, 197, 768]
, применяя линейный слой получится вектор (1, 197) — один класс на каждый патч (если вы не знаете термин пачта — это, грубо говоря, компонента, на которую модель разбивает изображение для извлечения признаков. Подробно тут).
С чем это связано?
В модели ViT для классификации используется CLS токен. Он соответствует 0-му из патчей. Мы можем взять либо его, либо медиану/среднее по патчам.
Реализуем оба способа и посмотрим на результат.
cls_hidden = hidden_states[-1][:, 0, :] # extracting CLS token median_hidden = torch.median(hidden_states[-1], dim=1)[0] # extracting median # extracting classification result (logits) classification_result_cls = classification_layer(cls_hidden) classification_result_median = classification_layer(median_hidden) # extracting classification result (probas) predicted_class_cls = torch.softmax(classification_result_cls, dim=-1) predicted_class_median = torch.softmax(classification_result_cls, dim=-1) print(f'Predicted class on CLS token: {predicted_class_cls.argmax()}') # 294 print(f'Predicted class on median: {predicted_class_median.argmax()}') # 294
Для последнего скрытого состояния результаты одинаковы. И на этом моменте у нас есть всё, чтобы пробежать по всем скрытым состояниям — от понимания теории, до понимания практических тонкостей! Напишем для этого функцию.
def get_hidden_state_result(hidden_state, true_class_idx, mode='CLS'): if mode == 'CLS': hidden_state = hidden_state[:, 0, :] # Extract the hidden state corresponding to the [CLS] token (first token in the sequence) if mode == 'median': # Take the median across the patches (dim=1 refers to the second dimension which is the patches) hidden_state = torch.median(hidden_state, dim=1)[0] classification_layer_result = classification_layer(hidden_state) # Apply classification layer predicted_class_probas = torch.softmax(classification_layer_result, dim=-1) # Apply softmax to get class probabilities # Get the predicted class index and probability predicted_class_idx = predicted_class_probas.argmax().item() predicted_class_proba = predicted_class_probas.max().item() # Get the rank of the true class #Sort the predicted probabilities in descending order and find the rank of the true class # by looking for the true class index in the sorted tensor. _, indexes_sorted = torch.sort(predicted_class_probas, descending = True) final_label_rank = torch.where(indexes_sorted == true_class_idx)[-1].item() return predicted_class_proba, predicted_class_idx, final_label_rank
Функция возвращает вероятность прогнозированного класса на представлении , индекс этого класса и ранг — порядковый номер в отсортированном векторе вероятностей. Пример запуска функции на медведе:
true_class_idx = 294 bear_image_probas = [] bear_image_predictions = [] bear_image_true_ranks = [] for hidden_state in hidden_states: predicted_class_proba, predicted_class_idx, final_label_rank = get_hidden_state_result(hidden_state, true_class_idx, mode='CLS') bear_image_probas.append(predicted_class_proba) bear_image_predictions.append(predicted_class_idx) bear_image_true_ranks.append(final_label_rank) bear_table = [bear_image_probas, bear_image_predictions, bear_image_true_ranks]
Результат будет выглядеть так.
fig, ax = plt.subplots(figsize=(20, 6)) sns.heatmap(bear_table, annot=True, ax=ax, cbar=False, fmt='g') ax.set_xticklabels([f'layer {i+1}' for i in range(12)]); ax.set_yticklabels(['Predicted class probas', 'Predicted index', 'True label rank']); ax.set_title('Results for bear image');

Так как на одном изображении не совсем корректно делать выводы, то я построила для всех (это есть в полной тетради). Получается, что:
-
На некоторых примерах прогноз может стабилизироваться на ранних слоях (4-5 для horses, bird).
-
Истинный (по прогнозу) класс оказываются в топ-3, на последних слоях, но не для всех изображений это верно (контрпример — sheep)
-
Уверенность (выраженная вероятностью) в пронозе резко возрастает на двух последних слоях (третий с конца может быть как стабильным по уверенности, так и нет (например, cow c 0.15 и bird с 0.9)).
На этом нехитрый анализ методом Logit Lens завершается. Надеюсь, мне удалось познакомить вас с методом и донести его идею. Благодарю вас за время, уделенное туториалу! Надеюсь, также он вас на что-то вдохновил
Полный ноутбук здесь (на русском и английском). В репозитория вы также можете найти другие туториалы (почти по всем есть статьи на Хабре от меня же).
Хороших и красивых проектов!
Ваш Дата-автор!
ссылка на оригинал статьи https://habr.com/ru/articles/891352/
Добавить комментарий