Как я свертки ускорял

от автора

Вовка плохого не посоветует

Вовка плохого не посоветует

После написания статьи про NormIs-1 я решил углубиться в тему оптимизации модели. Архитектура была неплохой и показала адекватные результаты на метриках интеллекта, но скорость сильно проседала. Проблема была в Depthwise Conv. Между блоком внимания и FFN стоял небольшой блок сверток и тормозил все вычисления. Именно его я и решил ускорить.

Делать целую языковую модель с полноценным вниманием возможности нет. Проблемы нестабильного обучения, взрывающийся лосс или сломавшийся DataLoader — это все не сегодня.

Нет, все будет ‘проще’ — мелкая CNN + кастомное MLX-ядро для инференса + бенчмарки скорости

Пациента на стол

В качестве подопытного решил взять небольшую CNN на 103К параметров и обучить ее на MNIST с нуля. Ничего сложного нет, с подобной задачи начинается любой курс по DL.

Обучил, сохранил, обрадовался — а руки-то помнят. Красивые графики лосса и точности прилагаются.

Красота

Красота

Сама модель, как я сказал выше, ничего сложного из себя не представляет. Несколько слоев свертки и MLP-голова для финальной классификации.

А вот и базовый класс сетки:

class DepthwiseSeparableConv(nn.Module):    def __init__(self, in_c, out_c, stride=1):        super().__init__()        self.use_skip = stride != 1 or in_c != out_c        if self.use_skip:            self.skip = nn.Sequential(                nn.Conv2d(in_c, out_c, 1, stride, bias=False),                nn.BatchNorm2d(out_c),            )        self.dw = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False)        self.bn1 = nn.BatchNorm2d(in_c)        self.pw = nn.Conv2d(in_c, out_c, 1, bias=False)        self.bn2 = nn.BatchNorm2d(out_c)        self.relu = nn.ReLU(inplace=True)    def forward(self, x):        identity = x        x = self.relu(self.bn1(self.dw(x)))        x = self.bn2(self.pw(x))        if self.use_skip:            identity = self.skip(identity)        x = self.relu(x + identity)        return x

Все целиком:

class DepthwiseMNIST(nn.Module):    def __init__(self, num_classes=10):        super().__init__()        self.conv1 = nn.Sequential(            nn.Conv2d(1, 48, 3, padding=1, bias=False),            nn.BatchNorm2d(48),            nn.ReLU(inplace=True),        )        self.stages = nn.Sequential(            DepthwiseSeparableConv(48, 96, stride=2),            DepthwiseSeparableConv(96, 96, stride=1),            DepthwiseSeparableConv(96, 192, stride=2),            DepthwiseSeparableConv(192, 192, stride=1),        )        self.avgpool = nn.AdaptiveAvgPool2d(1)        self.fc = nn.Linear(192, num_classes)    def forward(self, x):        x = self.conv1(x)        x = self.stages(x)        x = self.avgpool(x)        x = torch.flatten(x, 1)        x = self.fc(x)        return x

Подробнее модель можно изучить в scr/train_mnist.py — тут и архитектура, и загрузка/обработка данных, и основная функция обучения, и логгирование с графиками.

Первый этап пройден, теперь задача сложнее: надо написать кастомное ядро на MLX для ускорения работы сверток.

Быстрый ликбез

Обычно для обучения нейросетей на маке используют PyTorch с бэкендом MPS (torch.device('mps')). Внутри MPS лежат готовые функции от Apple: свертки, линейные слои, реализация разных типов внимания и прочая база. Но шаг вправо, шаг влево — и PyTorch начинает собирать операцию из медленных кусков, гонять данные туда-сюда и жутко тормозить.

Чтобы обойти эти ограничения, Apple выкатила MLX — фреймворк, который дает прямой доступ к движку Metal.

Кастомные ядра в MLX пишутся на MSL (Metal Shading Language). Базируется он на стандартах С++14/С++17, но со специфическими фишками для GPU. MLX берет твой код на MSL, на лету компилирует его под видеочип и выполняет операцию с Unified Memory вообще без задержек на лишнее копирование.

Вот как выглядит простейший шаблон такого ядра на MSL, который считает операцию для каждого элемента (Elementwise):

«`

#include <metal_stdlib>using namespace metal;kernel void my_custom_op(    device const float* in [[buffer(0)]],    device float* out [[buffer(1)]],    uint index [[thread_position_in_grid]],    uint grid_size [[threads_per_grid]]) // Metal сам знает общий размер сетки{    // Простая и лаконичная проверка границ    if (index >= grid_size) return;    out[index] = in[index] * in[index];}

На выходе получаем прирост в скорости — убрали лишние абстракции и заставили железо Apple Silicon крутить нашу кастомную математику напрямую.

Ускоряем скорость

Вернемся к нашей модели. Веса лежат в output/best_model.pth, а логика ядра более-менее понятна.

Сразу стоит сказать о том, что torch.compile(model, mode="max-autotune") — возможно, но для меня это скучно. Для Depthwise Conv написаны глубокие производительные ядра, и компилятор без проблем может сделать оптимизированную версию модели. Однако у меня чисто спортивный интерес — а смогу ли я? Поэтому легкий путей искать не буду.

Однако тут-же возникает сразу несколько проблем:

  1. Разница в формате данных. В .pth-файле тензоры активаций лежат в формате NCHW (Number/Batch, Channels, Height, Width), а веса свёрток — OIHW (Output channels, Input channels, Height, Width). Но библиотека MLX ожидает данные в формате NHWC, а веса — в OHWI. Если неправильно перенести веса или невнимательно транспонировать — будет плохо. Будет очень плохо.

  2. Особенность ленивых вычислений Компилятор MLX ленив (как и я). Пока явно не вызвать mlx.core.eval() — он ничего делать не будет. Замеры времени могут показать 0.0001 секунд, однако это не реальные данные. Сама свертка еще даже не запустилась, а компилятор понял, что результат нигде не используется — значит и считать этот результат не надо.

  3. Проблема Memory-Bound и latency запуска ядер (Launch Overhead). Наша маленькая CNN-ка на 103К параметров весит копейки, а её Arithmetic Intensity на Depthwise-слоях болтается в районе 0.87-2.2 FLOPs/byte. При Ridge point чипа Apple M1 около 39 FLOPs/byte это значит, что модель намертво упирается в пропускную способность памяти (Memory bandwidth). Если считать каждый чих отдельно (Conv -> BN -> ReLU), GPU будет дольше ждать ленивую выгрузку промежуточных тензоров в глобальную память и страдать от оверхеда на запуск Metal-команд, чем реально вычислять математику.

Arithmetic Intensity (арифметическая интенсивность) — это отношение количества вычислительных операций (FLOP) к объёму прочитанных или записанных данных (байтам) в программе или алгоритме

Ridge point (точка перегиба, критическая точка) — это аппаратная характеристика вычислительной системы, определяющая минимальную арифметическую интенсивность (количество операций FLOP на 1 байт данных), которая необходима алгоритму для достижения пиковой вычислительной мощности процессора (CPU, GPU или NPU).

Присылайте такую мотивацию коллегам 3 раза в неделю и задачи будут выполняться вовремя. При необходимости провести курс мотивации еще раз.

Присылайте такую мотивацию коллегам 3 раза в неделю и задачи будут выполняться вовремя. При необходимости провести курс мотивации еще раз.

Итак, погнали! Самый важный файл тут — scr/mlx_model.py. Внутри него прописана логика ядра, которое ускорит (наверное) мою сетку. Давайте разбираться.

Внимательно транспонируем

def perm_conv(arr):    """PyTorch (C_out, C_in, H, W) -> MLX (C_out, H, W, C_in)."""    return mx.array(arr.transpose(0, 2, 3, 1))

Аккуратная работа с памятью — наше все. Не забываем про правильно транспонирование, но долго на этом моменте не задерживаем — дальше больше.

Найди меня, если сможешь

def load_pt_weights(pt_path, mlx_model):    pt = torch.load(pt_path, map_location="cpu", weights_only=True)    flat = {}    def add(key, arr, do_perm=False):        flat[key] = perm_conv(arr) if do_perm else mx.array(arr)

Мало просто поменять формат тензоров, нужно ещё заставить MLX понять, какой тензор из файла .pth к какому слою новой модели относится. В PyTorch веса лежат в плоском словаре (state_dict), а MLX строит из них красивое вложенное дерево (tree_unflatten).

Разработчики PyTorch и MLX видят мир по-разному. Из-за этого архитектура, написанная на двух фреймворках, имеет абсолютно разные имена ключей:

  1. Индексы против имен

    В PyTorch первый блок упакован в nn.Sequential. Поэтому свёртка там называется conv1.0.weight, а батч-норм — conv1.1.weight. В MLX это два раздельных явных слоя: conv1.weight и bn1.weight.

  2. Структурный хаос в циклах

    Внутри цикла по стадиям (for i in range(4)) начинается сущий ад. В PyTorch слои лежат в плоском списке stages.0.dw.weight, а в MLX они вложены в .layers, превращаясь в stages.layers.0.dw.weight.

  3. Спрятанный Skip Connection

    Самое весёлое — остаточные связи, они же скипы. В моей модели они есть только на 0 и 2 стадиях. В PyTorch это опять Sequential (skip.0.weight), а в MLX — кастомное имя skip_conv.weight.

Функция load_pt_weights решает эту проблему, перебирая названия весов и сопоставляя их.

Последние штрихи

Классы DepthwiseSeparableConv и DepthwiseMNIST создают в памяти пустую болванку модели.

Через count_params считаем параметры — должно получиться столько-же, сколько и при тренировке.

А get_model собирает весь наш конструктор воедино — читает файл с весами, правильно его разворачивает и записывает в болванку.

А теперь — хардкор на Metal

Именно ради этого мы здесь все и собрались. Всего 38 строк, но зато каких! Ладно, нагнал пафоса, теперь можно продолжать.

Шаг 1. Кто я и где я? (Индексы потоков)

uint nw = thread_position_in_grid.x;uint row = thread_position_in_grid.y;uint col = thread_position_in_grid.z;

Каждому отдельному потоку (треду) на GPU выдаются свои координаты в трехмерной сетке. Из координаты nw мы вытягиваем индекс картинки в батче n и индекс конкретного канала c (ведь у нас Depthwise-свёртка, канал обрабатывается изолированно). Координаты row и col — это пиксель (строка и столбец) на нашей выходной картинке.

Шаг 2. Виртуальное окно (Свёртка 3х3)

int h_in = (int)row * stride - 1;int w_in = (int)col * stride - 1;

Здесь мы вычисляем, куда на оригинальной входной картинке падает левый верхний угол нашего ядра свёртки 3х3. Минус один (-1) берется из-за того, что у нас padding=1, из-за чего свертка неявно заступает за границы изображения.

Шаг 3. Перемножаем и складываем (Математика)

for (int kh = 0; kh < 3; kh++) {    for (int kw = 0; kw < 3; kw++) { ... }}

Два цикла бегут по окну 3х3. Внутри стоит жесткая проверка: если мы вылезли за границы картинки (в те самые области паддинга), мы просто ничего не делаем (считаем, что там виртуальный ноль). Если мы внутри картинки, то считаем плоский индекс в памяти inp_idx (привет, channel-last формат NHWC), забираем значение пикселя, умножаем на вес из w и докидываем в общую копилку sum_val.

Шаг 4. Бесплатный BatchNorm и ReLU

float normed = (sum_val - running_mean[c]) / metal::sqrt(running_var[c] + eps);float activated = normed * gamma[c] + beta[c];activated = metal::max(0.0f, activated);

Значение sum_val всё еще лежит в сверхбыстром регистре процессора. Мы не отправляем его в медленную общую память. Мы прямо здесь вычитаем среднее, делим на корень из дисперсии, умножаем на веса батч-норма (gamma, beta) и сразу же обнуляем всё, что меньше нуля (эффект ReLU).

Шаг 5. Выгрузка

out[out_idx] = activated;

И только теперь, когда получился финальный, отмытый батч-нормом и активированный пиксель, мы один-единственный раз лезем в глобальную память устройства и записываем его туда.

Теперь, когда большая часть оптимизирована, можно запускать тесты. А что вобще измерять?

Тесты, результаты и выводы

Итак, у нас есть 3 кандидата:

  1. PyTorch MPS (эталон). Модель написана на PyTorch, работает через MPS. Это бейзлайн, без каких-либо изменений.

  2. MLX (baseline) — та же архитектура, портированная на MLX вручную (channel-last). Работает через штатные mlx.nn.Conv2d / mlx.nn.BatchNorm. Показывает, сколько MLX выдаёт без оптимизаций.

  3. MLX (fused) — то же самое, но depthwise 3x3 + BatchNorm + ReLU заменены на одно кастомное Metal-ядро (один launch вместо трёх, промежуточные тензоры не пишутся в память).

Измерял 2 основные метрики: Latency (время ожидания ответа модели на один запрос) и Throughput (кол-во запросов, которое получается обработать за определенно время).

Так-же дополнительно записал метрики и построил графики:

  • P50 (медиана) latency — основная, робастная к выбросам

  • P95 latency — tail latency, критична для SLA/real-time

  • CV (Coefficient of Variation = σ/μ) — стабильность/разброс замеров

  • Speedup factor = latency_baseline / latency_fused — во сколько раз быстрее

  • Throughput scaling efficiency = throughput(BS) / (BS × throughput(1)) — насколько хорошо масштабируется

  • FLOPs per pass — теоретическая вычислительная ёмкость (16.9M FLOPs)

  • Arithmetic intensity = FLOPs/byte — определяет характер узкого места

  • Timeline (scatter) — все 200 trials по порядку со скользящим средним. Наглядно ловит throttle

  • CDF latency — кумулятивное распределение. Сразу видно P50/P95/P99

  • Violin plot — полная форма распределения latency на каждом BS

  • Bar chart — P50 ± P95 (error bar), сравнение бэкендов рядом

  • Speedup chart — коэффициент ускорения fused vs baseline/PT по BS

  • Roofline model — achieved TFLOPS vs peak (2.66 TFLOPS / 68 GB/s). Наглядно доказывает, что модель memory-bound

Тесты проводил на разных батчах: 1, 16, 64, 128, 512.

Всего провел 3600 измерений. Анализировать каждую строчку всех таблиц смысла нет. Чистые данные лежат в output/per_trial.npz, а обработанные (сведены в удобные таблицы и сделаны выводы/предположения) — в conclusions/benchmark_result.md. В этой статье выделю самые интересные наблюдения.

Есть ощутимый прирост

Да, это самая главная новость. На BS=1 MLX-Fused быстрее конкурентов в полтора раза. Это на 50% лучше! Throughput поднялся с 339p/s до 522p/s, а latency упал с 2.947ms до 1.917ms.

Проблемы с тротлингом на большом батче

На других размерах батча цифры разняться — где-то скорость выше, где-то ниже. Отчасти это можно объяснить значениями метрики стабильности замеров.

BS

PT CV

MLX CV

Fused CV

1

4.95%

68.31%

32.34%

16

32.02%

38.13%

16.95% ⚠

64

30.84%

14.60% ⚠

33.77%

128

41.22%

8.72%

22.74%

256

19.99%

3.90%

10.10%

512

17.22%

1.87%

19.39%

MLX baseline стабилен на BS ≥ 128 (CV < 10%). Скорее всего жесткий контроль температур заставляет проц сбрасывать частоты при подозрении на аномальную активность. Если проводить измерения с остановками можно получить немного другие цифры. Пока идея в бэклоге.

Идеи и следующие шаги

Делал я это все не просто так. Если получится оптимизировать CNN, то можно переходить на более сложный уровень — LLM со сверточными слоями. Там тоже непаханное поле и много идеи — все надо проверить и применить.

Пока я продолжаю работу со сверточной сеткой, желая добиться ЕЩЕ более хороших результатов. Идеи для этого такие:

  1. Использовать float16 как основной тип данных (надо будет подумать над правильной реализацией BN)

  2. Экспортировать веса в .safetensors и не мучиться с питонячьим оверхедом.

  3. Зафьюзить всю сеть в один kernel launch — это может дать еще больший прирост.

Актуальный код проекта можно посмотреть ТУТ.
Если есть идеи как улучшить/ускорить мои методы — пишите комментарии или создавайте issues на гите. Буду рад любой обратной связи.

Спасибо за внимание,
morginalium8

ссылка на оригинал статьи https://habr.com/ru/articles/1052790/