После написания статьи про NormIs-1 я решил углубиться в тему оптимизации модели. Архитектура была неплохой и показала адекватные результаты на метриках интеллекта, но скорость сильно проседала. Проблема была в Depthwise Conv. Между блоком внимания и FFN стоял небольшой блок сверток и тормозил все вычисления. Именно его я и решил ускорить.
Делать целую языковую модель с полноценным вниманием возможности нет. Проблемы нестабильного обучения, взрывающийся лосс или сломавшийся DataLoader — это все не сегодня.
Нет, все будет ‘проще’ — мелкая CNN + кастомное MLX-ядро для инференса + бенчмарки скорости
Пациента на стол
В качестве подопытного решил взять небольшую CNN на 103К параметров и обучить ее на MNIST с нуля. Ничего сложного нет, с подобной задачи начинается любой курс по DL.
Обучил, сохранил, обрадовался — а руки-то помнят. Красивые графики лосса и точности прилагаются.
Сама модель, как я сказал выше, ничего сложного из себя не представляет. Несколько слоев свертки и MLP-голова для финальной классификации.
А вот и базовый класс сетки:
class DepthwiseSeparableConv(nn.Module): def __init__(self, in_c, out_c, stride=1): super().__init__() self.use_skip = stride != 1 or in_c != out_c if self.use_skip: self.skip = nn.Sequential( nn.Conv2d(in_c, out_c, 1, stride, bias=False), nn.BatchNorm2d(out_c), ) self.dw = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False) self.bn1 = nn.BatchNorm2d(in_c) self.pw = nn.Conv2d(in_c, out_c, 1, bias=False) self.bn2 = nn.BatchNorm2d(out_c) self.relu = nn.ReLU(inplace=True) def forward(self, x): identity = x x = self.relu(self.bn1(self.dw(x))) x = self.bn2(self.pw(x)) if self.use_skip: identity = self.skip(identity) x = self.relu(x + identity) return x
Все целиком:
class DepthwiseMNIST(nn.Module): def __init__(self, num_classes=10): super().__init__() self.conv1 = nn.Sequential( nn.Conv2d(1, 48, 3, padding=1, bias=False), nn.BatchNorm2d(48), nn.ReLU(inplace=True), ) self.stages = nn.Sequential( DepthwiseSeparableConv(48, 96, stride=2), DepthwiseSeparableConv(96, 96, stride=1), DepthwiseSeparableConv(96, 192, stride=2), DepthwiseSeparableConv(192, 192, stride=1), ) self.avgpool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(192, num_classes) def forward(self, x): x = self.conv1(x) x = self.stages(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x
Подробнее модель можно изучить в scr/train_mnist.py — тут и архитектура, и загрузка/обработка данных, и основная функция обучения, и логгирование с графиками.
Первый этап пройден, теперь задача сложнее: надо написать кастомное ядро на MLX для ускорения работы сверток.
Быстрый ликбез
Обычно для обучения нейросетей на маке используют PyTorch с бэкендом MPS (torch.device('mps')). Внутри MPS лежат готовые функции от Apple: свертки, линейные слои, реализация разных типов внимания и прочая база. Но шаг вправо, шаг влево — и PyTorch начинает собирать операцию из медленных кусков, гонять данные туда-сюда и жутко тормозить.
Чтобы обойти эти ограничения, Apple выкатила MLX — фреймворк, который дает прямой доступ к движку Metal.
Кастомные ядра в MLX пишутся на MSL (Metal Shading Language). Базируется он на стандартах С++14/С++17, но со специфическими фишками для GPU. MLX берет твой код на MSL, на лету компилирует его под видеочип и выполняет операцию с Unified Memory вообще без задержек на лишнее копирование.
Вот как выглядит простейший шаблон такого ядра на MSL, который считает операцию для каждого элемента (Elementwise):
«`
#include <metal_stdlib>using namespace metal;kernel void my_custom_op( device const float* in [[buffer(0)]], device float* out [[buffer(1)]], uint index [[thread_position_in_grid]], uint grid_size [[threads_per_grid]]) // Metal сам знает общий размер сетки{ // Простая и лаконичная проверка границ if (index >= grid_size) return; out[index] = in[index] * in[index];}
На выходе получаем прирост в скорости — убрали лишние абстракции и заставили железо Apple Silicon крутить нашу кастомную математику напрямую.
Ускоряем скорость
Вернемся к нашей модели. Веса лежат в output/best_model.pth, а логика ядра более-менее понятна.
Сразу стоит сказать о том, что
torch.compile(model, mode="max-autotune")— возможно, но для меня это скучно. Для Depthwise Conv написаны глубокие производительные ядра, и компилятор без проблем может сделать оптимизированную версию модели. Однако у меня чисто спортивный интерес — а смогу ли я? Поэтому легкий путей искать не буду.
Однако тут-же возникает сразу несколько проблем:
-
Разница в формате данных. В .pth-файле тензоры активаций лежат в формате NCHW (Number/Batch, Channels, Height, Width), а веса свёрток — OIHW (Output channels, Input channels, Height, Width). Но библиотека MLX ожидает данные в формате NHWC, а веса — в OHWI. Если неправильно перенести веса или невнимательно транспонировать — будет плохо. Будет очень плохо.
-
Особенность ленивых вычислений Компилятор MLX ленив (как и я). Пока явно не вызвать
mlx.core.eval()— он ничего делать не будет. Замеры времени могут показать 0.0001 секунд, однако это не реальные данные. Сама свертка еще даже не запустилась, а компилятор понял, что результат нигде не используется — значит и считать этот результат не надо. -
Проблема Memory-Bound и latency запуска ядер (Launch Overhead). Наша маленькая CNN-ка на 103К параметров весит копейки, а её Arithmetic Intensity на Depthwise-слоях болтается в районе 0.87-2.2 FLOPs/byte. При Ridge point чипа Apple M1 около 39 FLOPs/byte это значит, что модель намертво упирается в пропускную способность памяти (Memory bandwidth). Если считать каждый чих отдельно (Conv -> BN -> ReLU), GPU будет дольше ждать ленивую выгрузку промежуточных тензоров в глобальную память и страдать от оверхеда на запуск Metal-команд, чем реально вычислять математику.
Arithmetic Intensity (арифметическая интенсивность) — это отношение количества вычислительных операций (FLOP) к объёму прочитанных или записанных данных (байтам) в программе или алгоритме
Ridge point (точка перегиба, критическая точка) — это аппаратная характеристика вычислительной системы, определяющая минимальную арифметическую интенсивность (количество операций FLOP на 1 байт данных), которая необходима алгоритму для достижения пиковой вычислительной мощности процессора (CPU, GPU или NPU).
Итак, погнали! Самый важный файл тут — scr/mlx_model.py. Внутри него прописана логика ядра, которое ускорит (наверное) мою сетку. Давайте разбираться.
Внимательно транспонируем
def perm_conv(arr): """PyTorch (C_out, C_in, H, W) -> MLX (C_out, H, W, C_in).""" return mx.array(arr.transpose(0, 2, 3, 1))
Аккуратная работа с памятью — наше все. Не забываем про правильно транспонирование, но долго на этом моменте не задерживаем — дальше больше.
Найди меня, если сможешь
def load_pt_weights(pt_path, mlx_model): pt = torch.load(pt_path, map_location="cpu", weights_only=True) flat = {} def add(key, arr, do_perm=False): flat[key] = perm_conv(arr) if do_perm else mx.array(arr)
Мало просто поменять формат тензоров, нужно ещё заставить MLX понять, какой тензор из файла .pth к какому слою новой модели относится. В PyTorch веса лежат в плоском словаре (state_dict), а MLX строит из них красивое вложенное дерево (tree_unflatten).
Разработчики PyTorch и MLX видят мир по-разному. Из-за этого архитектура, написанная на двух фреймворках, имеет абсолютно разные имена ключей:
-
Индексы против имен
В PyTorch первый блок упакован в
nn.Sequential. Поэтому свёртка там называетсяconv1.0.weight, а батч-норм —conv1.1.weight. В MLX это два раздельных явных слоя:conv1.weightиbn1.weight. -
Структурный хаос в циклах
Внутри цикла по стадиям (
for i in range(4)) начинается сущий ад. В PyTorch слои лежат в плоском спискеstages.0.dw.weight, а в MLX они вложены в.layers, превращаясь вstages.layers.0.dw.weight. -
Спрятанный Skip Connection
Самое весёлое — остаточные связи, они же скипы. В моей модели они есть только на 0 и 2 стадиях. В PyTorch это опять
Sequential(skip.0.weight), а в MLX — кастомное имяskip_conv.weight.
Функция load_pt_weights решает эту проблему, перебирая названия весов и сопоставляя их.
Последние штрихи
Классы DepthwiseSeparableConv и DepthwiseMNIST создают в памяти пустую болванку модели.
Через count_params считаем параметры — должно получиться столько-же, сколько и при тренировке.
А get_model собирает весь наш конструктор воедино — читает файл с весами, правильно его разворачивает и записывает в болванку.
А теперь — хардкор на Metal
Именно ради этого мы здесь все и собрались. Всего 38 строк, но зато каких! Ладно, нагнал пафоса, теперь можно продолжать.
Шаг 1. Кто я и где я? (Индексы потоков)
uint nw = thread_position_in_grid.x;uint row = thread_position_in_grid.y;uint col = thread_position_in_grid.z;
Каждому отдельному потоку (треду) на GPU выдаются свои координаты в трехмерной сетке. Из координаты nw мы вытягиваем индекс картинки в батче n и индекс конкретного канала c (ведь у нас Depthwise-свёртка, канал обрабатывается изолированно). Координаты row и col — это пиксель (строка и столбец) на нашей выходной картинке.
Шаг 2. Виртуальное окно (Свёртка 3х3)
int h_in = (int)row * stride - 1;int w_in = (int)col * stride - 1;
Здесь мы вычисляем, куда на оригинальной входной картинке падает левый верхний угол нашего ядра свёртки 3х3. Минус один (-1) берется из-за того, что у нас padding=1, из-за чего свертка неявно заступает за границы изображения.
Шаг 3. Перемножаем и складываем (Математика)
for (int kh = 0; kh < 3; kh++) { for (int kw = 0; kw < 3; kw++) { ... }}
Два цикла бегут по окну 3х3. Внутри стоит жесткая проверка: если мы вылезли за границы картинки (в те самые области паддинга), мы просто ничего не делаем (считаем, что там виртуальный ноль). Если мы внутри картинки, то считаем плоский индекс в памяти inp_idx (привет, channel-last формат NHWC), забираем значение пикселя, умножаем на вес из w и докидываем в общую копилку sum_val.
Шаг 4. Бесплатный BatchNorm и ReLU
float normed = (sum_val - running_mean[c]) / metal::sqrt(running_var[c] + eps);float activated = normed * gamma[c] + beta[c];activated = metal::max(0.0f, activated);
Значение sum_val всё еще лежит в сверхбыстром регистре процессора. Мы не отправляем его в медленную общую память. Мы прямо здесь вычитаем среднее, делим на корень из дисперсии, умножаем на веса батч-норма (gamma, beta) и сразу же обнуляем всё, что меньше нуля (эффект ReLU).
Шаг 5. Выгрузка
out[out_idx] = activated;
И только теперь, когда получился финальный, отмытый батч-нормом и активированный пиксель, мы один-единственный раз лезем в глобальную память устройства и записываем его туда.
Теперь, когда большая часть оптимизирована, можно запускать тесты. А что вобще измерять?
Тесты, результаты и выводы
Итак, у нас есть 3 кандидата:
-
PyTorch MPS (эталон). Модель написана на PyTorch, работает через MPS. Это бейзлайн, без каких-либо изменений.
-
MLX (baseline) — та же архитектура, портированная на MLX вручную (channel-last). Работает через штатные
mlx.nn.Conv2d/mlx.nn.BatchNorm. Показывает, сколько MLX выдаёт без оптимизаций. -
MLX (fused) — то же самое, но
depthwise 3x3+BatchNorm+ReLUзаменены на одно кастомное Metal-ядро (один launch вместо трёх, промежуточные тензоры не пишутся в память).
Измерял 2 основные метрики: Latency (время ожидания ответа модели на один запрос) и Throughput (кол-во запросов, которое получается обработать за определенно время).
Так-же дополнительно записал метрики и построил графики:
-
P50 (медиана) latency — основная, робастная к выбросам
-
P95 latency — tail latency, критична для SLA/real-time
-
CV (Coefficient of Variation = σ/μ) — стабильность/разброс замеров
-
Speedup factor = latency_baseline / latency_fused — во сколько раз быстрее
-
Throughput scaling efficiency = throughput(BS) / (BS × throughput(1)) — насколько хорошо масштабируется
-
FLOPs per pass — теоретическая вычислительная ёмкость (16.9M FLOPs)
-
Arithmetic intensity = FLOPs/byte — определяет характер узкого места
-
Timeline (scatter) — все 200 trials по порядку со скользящим средним. Наглядно ловит throttle
-
CDF latency — кумулятивное распределение. Сразу видно P50/P95/P99
-
Violin plot — полная форма распределения latency на каждом BS
-
Bar chart — P50 ± P95 (error bar), сравнение бэкендов рядом
-
Speedup chart — коэффициент ускорения fused vs baseline/PT по BS
-
Roofline model — achieved TFLOPS vs peak (2.66 TFLOPS / 68 GB/s). Наглядно доказывает, что модель memory-bound
Тесты проводил на разных батчах: 1, 16, 64, 128, 512.
Всего провел 3600 измерений. Анализировать каждую строчку всех таблиц смысла нет. Чистые данные лежат в output/per_trial.npz, а обработанные (сведены в удобные таблицы и сделаны выводы/предположения) — в conclusions/benchmark_result.md. В этой статье выделю самые интересные наблюдения.
Есть ощутимый прирост

Да, это самая главная новость. На BS=1 MLX-Fused быстрее конкурентов в полтора раза. Это на 50% лучше! Throughput поднялся с 339p/s до 522p/s, а latency упал с 2.947ms до 1.917ms.
Проблемы с тротлингом на большом батче

На других размерах батча цифры разняться — где-то скорость выше, где-то ниже. Отчасти это можно объяснить значениями метрики стабильности замеров.
|
BS |
PT CV |
MLX CV |
Fused CV |
|---|---|---|---|
|
1 |
4.95% |
68.31% ⚠ |
32.34% ⚠ |
|
16 |
32.02% ⚠ |
38.13% ⚠ |
16.95% ⚠ |
|
64 |
30.84% ⚠ |
14.60% ⚠ |
33.77% ⚠ |
|
128 |
41.22% ⚠ |
8.72% |
22.74% ⚠ |
|
256 |
19.99% ⚠ |
3.90% |
10.10% ⚠ |
|
512 |
17.22% ⚠ |
1.87% |
19.39% ⚠ |
MLX baseline стабилен на BS ≥ 128 (CV < 10%). Скорее всего жесткий контроль температур заставляет проц сбрасывать частоты при подозрении на аномальную активность. Если проводить измерения с остановками можно получить немного другие цифры. Пока идея в бэклоге.
Идеи и следующие шаги
Делал я это все не просто так. Если получится оптимизировать CNN, то можно переходить на более сложный уровень — LLM со сверточными слоями. Там тоже непаханное поле и много идеи — все надо проверить и применить.
Пока я продолжаю работу со сверточной сеткой, желая добиться ЕЩЕ более хороших результатов. Идеи для этого такие:
-
Использовать
float16как основной тип данных (надо будет подумать над правильной реализацией BN) -
Экспортировать веса в
.safetensorsи не мучиться с питонячьим оверхедом. -
Зафьюзить всю сеть в один kernel launch — это может дать еще больший прирост.
Актуальный код проекта можно посмотреть ТУТ.
Если есть идеи как улучшить/ускорить мои методы — пишите комментарии или создавайте issues на гите. Буду рад любой обратной связи.
Спасибо за внимание,
morginalium8
ссылка на оригинал статьи https://habr.com/ru/articles/1052790/