Дело было вечером, делать было нечего. Я сидел за ноутом и разбирал новую идею Deepseek Engram: Лян Ванфень собрал вместе хеш-таблицы и почти-линейный трансформер — получилось дешево и сердито.
Однако есть в Engram один недостаток — он требует много RAM (каламбурчик, хаха). А хотелось архитектуру, на инференс которой не придется скидываться всем поселком.
Небольшой ликбез
Engram, по сути, перешивает токены и добавляет к ним факты. Реализовано это довольно хитро, через хеш-функцию, O(1) по сложности. Благодаря такой пристройке трансформер уделяет больше внимания на грамматику и связь слов в предложении.
Основная идея
А что если вместо дорогого по вычислениям Engram взять простые свертки? Они дешевые, быстрые и могут запомнить базовые факты.
Именно об этом я и подумал. И тут же решил проводить тесты.
К сожалению у меня нет в гараже кластера на 8xH200 (да и гаража у меня нет), поэтому обучить что-то большое не получится. Однако для быстрого эксперимента хватит Colab и его Т4 16Гб.
Архитектура модели
За пару минут набросал схему в Obsidian. Теперь про каждый блок отдельно
RMSNorm
Базовый слой нормализации, в современный трансформерах без него будет тяжко.
Conv1D
Ключевое нововведение. Depthwise и kernel = 3 обогащают токены и перемешивают их. Чтобы сетка не ‘поглядывала’ реализовал каузальные свертки.
MQA
Довольно быстрая и дешевая реализация классического Self-Attention, но все еще не линейная или реккурентная архитектура.
FFN + SwiGLU
Два главных компонента: новая функция активации и необычное расширение в линейном слоев — x8/3 на 3 слоя вместо устоявшегося х4 на 2 слоя (позволяет сохранить то же кол-во параметров при большем кол-ве операций).
Эта комбинация отлично показала себя в моделях Llama, где была применена впервые.
Все это решил обозвать NormIs-1. Логики в названии нет абсолютно никакой.
Меньше слов — больше кода
Не стал что-то менять в нормализации и сделал самую простую версию.
сlass RMSNorm(nn.Module): def __init__(self, dim): super().__init__() self.scale = dim ** 0.5 self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): return F.normalize(x, dim=-1) * self.scale * self.g
Так-же сделал с FFN — просто и понятно
class SwiGLU(nn.Module): def __init__(self, dim): super().__init__() hidden_dim = int(dim * 4 * 2 / 3) self.w_gate = nn.Linear(dim, hidden_dim, bias=False) self.w_val = nn.Linear(dim, hidden_dim, bias=False) self.w_out = nn.Linear(hidden_dim, dim, bias=False) def forward(self, x): gate = F.silu(self.w_gate(x)) val = self.w_val(x) return self.w_out(gate * val)
Наивная реализация сверток. Спойлер — простой forward() потом вышел мне боком из-за медленной памяти.
class CausalConv1D(nn.Module): def __init__(self, dim, kernel_size=3): super().__init__() self.pad = kernel_size - 1 self.conv = nn.Conv1d(dim, dim, kernel_size, groups=dim) def forward(self, x): x = x.transpose(1, 2) x = F.pad(x, (self.pad, 0)) x = self.conv(x) x = x.transpose(1, 2) return x
А вот и все ноутбуки с обучением (ссылки на Colab):
Кастомная архитектура
MHA + MQA
Метрики
Один из самых важных вопросов — а как вообще оценить NormIs-1? С чем его сравнивать? Какие метрики измерять?
Введем двух дополнительных кандидатов — трансформер на MQA и на MHA без сверток.
MHA считается лучшим по качеству, но он-же медленнее всего. Это Topline
MQA — топ по скорости, но может терять в качестве. Это Baseline.
Метрики ‘интеллекта’ модели — Loss (Cross-Entropy) и Perplexity. Метрики скорости — время обучения и TPS (tokens per second).
Моя цель — усидеть на двух стульях: получить интеллект уровня MHA, не потеряв при этом в скорости генерации MQA. Если NormIs-1 догонит Topline по качеству, оставшись таким же быстрым — это победа.
Сравнение
Чтобы эксперимент был честным, я зафиксировал все гиперпараметры. Изменялась только архитектура внутреннего блока.
Конфигурация:
-
Датасет: TinyStories. Идеален для микро-моделей: в нем простая лексика, но строгие требования к грамматике и логике.
-
Токенизатор: Свой собственный, обученный на 8К токенов. Это позволило не раздувать матрицу эмбеддингов и сфокусировать ‘мозги’ модели на смысле, а не на хранении словаря.
-
Геометрия:
model_dim = 128,context = 256. Компактно, но достаточно для коротких рассказов. -
Обучение:
steps = 5000,batch = 64.
Итого на претрейн — токенов.
Запустил обучение и ушел пить чай. По моим расчетам каждая модель училась бы не более получаса.
И вот наступил момент Х, пора сравнивать.
|
Сравнение |
MHA Topline |
MQA Baseline |
NormIs-1 |
|---|---|---|---|
|
Параметры |
1.84M |
1.75M |
1.75M |
|
Время обучения |
24:04 |
24:15 |
25:03 |
|
Val. Perplexity |
7.9 |
8.24 |
7.94 |
|
Val. Loss |
2.0668 |
2.1095 |
2.0713 |
|
Tokens/sec |
362 |
339 |
202 |
Качество довольно хорошее. NormIs остался на уровне MHA, имея меньше параметров.
Но вот скорость обучения и инференса выглядит печально. А все из-за наивной реализации сверток. Граф вычислений на PyTorch должен создавать новый CUDA Kernel для каждой свертки.
Из-за этого модель значительно медленнее при инференсе, а при обучении это не так заметно. Думаю, если написать нормальный движек, то NormIs получит свои 300т/с+
Вот ссылка на папку — там графики падения лосса и примеры генерации модели.
Выводы и идеи
Результат хороший, но не прорывной. Дальше я хочу попробовать эту же конфигурации, но на большем масштабе (20М+ параметров) и на сложной задаче (например, Fineweb-Edu).
Спасибо что дочитали статью. Это мой первый опыт написания подобных текстов.
Буду рад если получится дать фидбек на мои решения. Я в ML недавно, только учусь. Будет интересно послушать людей с опытом.
ссылка на оригинал статьи https://habr.com/ru/articles/1030492/