В мире глубокого обучения существует наивный миф: «Если твоя модель недостаточно умная, просто накинь еще пару десятков слоев».
На бумаге residual связи (те самые плюсики в коде: x=x+f(x) ) должны позволять нам строить сети бесконечной глубины, спасая градиенты от затухания. Но любой, кто пытался с нуля обучить трансформер слоев на 80, знает жестокую правду: сеть просто отказывается сходиться. Loss взрывается в первые же эпохи, или модель навсегда застревает на субоптимальном плато.
Долгие годы эту проблему лечили «костылями»: сложными схемами разогрева learning rate, танцами с инициализацией весов и стохастической глубиной. Пока исследователи из Meta AI не предложили изящный хак под названием LayerScale.
Давайте залезем под капот и посмотрим, почему глубокие Трансформеры умирают, и как LayerScale их воскрешает.
Проблема: Токсичность residual stream
Посмотрим на классический блок трансформера:
xout=xin+Block(LayerNorm(xin))
Магистраль X, которая проходит через всю сеть от первого до последнего слоя, часто называют residual stream (Остаточный поток). Это главная информационная шина модели.
В чем проблема? Каждый новый блок (Attention или FFN) берет данные из шины, как-то их обрабатывает и вливает результат обратно в шину через операцию сложения.
На ранних этапах обучения веса блоков инициализированы случайным образом. Это значит, что каждый блок вливает в магистраль чистейший математический шум.
-
Слой 1 добавил шум. Дисперсия сигнала выросла.
-
Слой 2 получил зашумленный сигнал, умножил его на свои случайные матрицы и добавил еще больше шума.
-
К 50-му слою изначальный сигнал (эмбеддинги токенов) полностью тонет в хаосе дисперсии.
LayerNorm пытается спасти ситуацию на входе в каждый следующий блок, но он не спасает саму магистраль. Когда loss функция на самом верху сети пытается прокинуть градиенты вниз, она видит перед собой бушующую реку дисперсии. Оптимизатор сходит с ума.
Решение: LayerScale и концепция мьюта
Идея LayerScale поражает своей простотой. А что, если на старте обучения мы выключим звук у всех слоев, кроме самых первых?
Мы добавляем один обучаемый вектор λ (диагональную матрицу) той же размерности, что и наш вектор x. Мы умножаем выход блока на этот вектор до того, как прибавить его к шине:
xout=xin+λ⊙Block(LayerNorm(xin))
А теперь главная магия: мы инициализируем λ микроскопическими значениями, например 10−4 или 10−5 (10−6 для очень глубоких сетей).
Как мы обманываем оптимизатор
Посмотрите, что происходит на первой итерации обучения.
Поскольку λ≈0 , выход любого блока умножается на ноль. Уравнение схлопывается до:
xout≈xin+0
Математически, ваша 100-слойная махина в начале обучения ведет себя как сеть из нуля слоев (чистая функция идентичности). Сигнал пролетает от входа до выхода без единого искажения. Градиенты текут идеально гладко, как по автобану.
Дальше в дело вступает градиентный спуск. Оптимизатор видит параметр λ и понимает: «Ага, если я начну понемногу увеличивать λ в определенных слоях, loss начнет падать».
Сеть начинает постепенно «пробуждать» слои. Она сама решает, какому блоку внимания дать «громкость», а какой оставить заглушенным. Мы больше не вливаем хаос в магистраль. Мы приоткрываем кран ровно на ту величину, которую сеть способна переварить без взрыва градиентов.
Реализация на Pytorch
Внедрить это в свою архитектуру проще простого. Вот как выглядит этот слой:
import torchimport torch.nn as nnclass LayerScale(nn.Module): def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace # Создаем обучаемый параметр: вектор размерности dim self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): # Умножаем каждый канал на свой вес return x.mul_(self.gamma) if self.inplace else x * self.gamma# Пример использования внутри блока:class TransformerBlock(nn.Module): def __init__(self, dim): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = Attention(dim) self.ls1 = LayerScale(dim, init_values=1e-5) # <-- layer scale self.norm2 = nn.LayerNorm(dim) self.mlp = MLP(dim) self.ls2 = LayerScale(dim, init_values=1e-5) # <-- layer scale def forward(self, x): # Сигнал проходит через блок, глушится LayerScale, и только потом плюсуется x = x + self.ls1(self.attn(self.norm1(x))) x = x + self.ls2(self.mlp(self.norm2(x))) return x
Разница с LayerNorm
Часто возникает вопрос: «Зачем нам LayerScale, если у нас уже есть LayerNorm?».
Это разные инструменты для разных задач:
-
LayerNorm работает внутри пайплайна блока. Он нормирует сигнал, чтобы матрицам (Wq, Wk, Wv ) было комфортно с ним работать. Он не защищает внешнюю магистраль.
-
LayerScale работает на выходе из пайплайна. Это гейт (вентиль), который контролирует, насколько сильно этот блок имеет право изменить глобальный residual stream.
Кстати, λ это по-канальный вектор. Сеть может выучить, что для 5-го канала громкость должна быть 1.0, а для 128-го канала остаться около нуля. Это дает колоссальную гибкость в маршрутизации признаков.
Заключение
LayerScale это идеальный пример того, как глубокое понимание градиентной механики побеждает грубую вычислительную силу. Вместо того чтобы придумывать зубодробительные схемы оптимизаторов или сжигать мегаватты на подбор Learning Rate, мы добавляем один вектор на слой и позволяем топологии сети собирать саму себя в процессе обучения.
Если вы пишете архитектуру с нуля, или пытаетесь обучить что-то глубокое (особенно в компьютерном зрении, где эта проблема стоит острее всего), добавьте LayerScale. Вы удивитесь, насколько стабильнее станет ваш loss.
ссылка на оригинал статьи https://habr.com/ru/articles/1036484/