Проклятие адаптивности: почему живучесть нейросетей ваш главный враг (и как я случайно ампутировал трансформер)

от автора

Мы привыкли восхищаться тем, как нейронные сети умеют адаптироваться. Они находят паттерны в шуме, обходят локальные минимумы и выжимают максимум из грязных данных. Но у этой сверхспособности есть темная сторона, о которой редко говорят в туториалах.

Сверх адаптивность нейросетей это худший кошмар инженера.

Если вы пишете архитектуру с нуля или оптимизируете SOTA алгоритмы, сеть будет изо всех сил пытаться скрыть от вас ваши же баги. Оптимизатору плевать на вашу изящную задумку. Его цель минимизировать Loss любой ценой. Если вы перекроете ему главный мост, он найдет брод, переплывет реку, построит плот из мусора и всё равно доставит Loss к нулю. А вы будете смотреть на красивый график и думать, что вы гений.

Позвольте рассказать вам историю моего главного архитектурного фиаско, которое доказало мне, что верить нельзя даже падающему лоссу.

Идеальное преступление: Triton, RoPE и потерянный градиент

Я работал над оптимизацией инференса и обучения для кастомного Трансформера. Все мы знаем, что стандартные реализации слоев в Pytorch могут быть медленными, поэтому я решил написать собственное ядро на triton для RoPE

Я написал forward пас. Он работал молниеносно, тесты на размерности проходили, профилировщик GPU показывал великолепную утилизацию памяти. Затем я написал backward пас для вычисления градиентов.

Запустил обучение. Loss начал падать. Графики выглядели поразительно хорошо. Модель сходилась быстро, тексты генерировались осмысленно.

А потом, спустя время, при дебаге я решил проверить градиентные графы. И у меня похолодела кровь.

Анатомия катастрофы

Из-за глупой ошибки в backward функции triton ядра я забыл пропустить градиент обратно к матрицам проекций Wq и Wk

Давайте осмыслим масштаб трагедии. Матрицы Wq и Wk это суть механизма внимания. Они решают, какое слово должно смотреть на какое. Из-за моего бага они вообще не получали градиентов. На протяжении всего обучения Wq и Wk оставались замороженными в состоянии случайной инициализации!

Как должна была повести себя нормальная программа? Она должна была выдать мусор. Она не должна была ничему научиться. Внимание, использующее случайные проекции, это просто рандомный генератор шума.

Почему она сошлась? Эффект «Bag of Words» на максималках

Так почему же Loss падал, а модель генерировала адекватный текст? В этом и заключается ужасающая сила нейросетей.

Оптимизатор понял: «Так, позиции слов я больше не вижу. Механизм Внимания выдает мне случайную матрицу связей. Что у меня осталось?» У него остались матрицы Value и могучий FFN

Нейросеть мгновенно перестроилась. Раз она не могла понимать синтаксис и порядок слов (потому что Q и K были сломаны, а RoPE мертв), она решила стать самым продвинутым в мире мешком слов.

Она начала зверски оптимизировать матрицы V и слои FFN, чтобы предсказывать следующее слово исключительно на основе наличия других слов в контексте, полностью игнорируя их порядок

  • Видит токены «Король», «Мужчина», «Женщина» -> выдает «Королева». Ей было неважно, в каком порядке они стоят.

  • Случайные веса просто работали как фиксированный случайный роутер, а сеть адаптировала свои слои так, чтобы расшифровывать этот случайный шум.

Модель сходилась поразительно хорошо. Она выучила грамматику, факты, стиль. Но это был инвалид, который научился бегать на руках, потому что я случайно ампутировал ему ноги. И он бегал так быстро, что я даже не замечал отсутствия ног.

Выводы, написанные кровью (и часами на GPU)

  1. Loss это лжец. Падающий loss означает только то, что оптимизатор нашел градиентный спуск. Он не означает, что ваша архитектура работает так, как вы задумали.

  2. Нейросеть скроет ваши баги. Если вы сломали skip connection, сеть увеличит веса основного пути. Если вы сломали position embeddings, сеть выучит статистику частотности. Если вы перепутали view и reshape, разрушив структуру тензора, сеть просто выучит новую, исковерканную топологию данных.

  3. Проверяйте градиенты, а не только loss. Если вы пишете кастомные слои (особенно на triton), первый ваш тест это не forward, это проверка того, что .grad не содержит нулей и имеет адекватную норму.

Заключение

Мы любим смеяться над классическими программистами с их тестами и строгой типизацией. Но в мире ML цена бага гораздо выше.

В следующий раз, когда ваша новая гениальная архитектура «поразительно хорошо сойдется» с первой попытки, не спешите открывать шампанское. Возможно, где-то в недрах autograd’а ваша модель прямо сейчас изобретает костыль, чтобы обойти вашу ошибку.

Будьте параноиками. Проверяйте градиенты.

ссылка на оригинал статьи https://habr.com/ru/articles/1036416/