Тернарный KAN: не баг, а фича — почему дискретные веса работают лучше

от автора


Вступление: зачем лезть в KAN

Это продолжение поста “Две нейросети по 15 КБ” — там были базовые цифры. А тут уже личная история: как делалось, что пошло не так, и что выяснилось по пути.

Май 2024 года. Выходит статья “KAN: Kolmogorov-Arnold Networks”. И происходит то, что бывает раз в несколько лет — кто-то предлагает альтернативу MLP.

Не модификацию и не лайфхак — альтернативу.

В MLP каждый нейрон делает weight × input + bias, и все 80 лет развития — это вариации на тему “как сделать этот вес точнее, быстрее, разреженнее”. KAN предлагает другое: заменить линейный вес на обучаемую функцию. Вроде мелкий трюк, а на практике — меньше параметров при той же точности и встроенная интерпретируемость.

К 2026 году уже появились QuantKAN (4-битное квантование), KANtize (2-3-битные B-spline таблицы), BiKA (аппаратный акселератор, вдохновленный KAN). И все они, по сути, про одно — сделать KAN меньше, чтоб работал не только на GPU.

Граница в три бита — она же психологическая. Ниже 4 бит у всех начинается «а вдруг всё сломается». И знаете что? Обычно так и есть. Любой, кто квантовал нейросети в 2 бита, знает: точность падает. Не чуть-чуть — катастрофически.

Но {-1, 0, +1} — это даже не два бита, это log₂(3) ≈ 1.58 бита. Формально — между binary и ternary, а по ощущениям — чистое безумие.

Ну я и решил попробовать.


Как я сломал правило “квантование убивает точность”

Эксперимент выглядел так: берем GraphKAN — мою графовую реализацию KAN, где нейроны соединены не слоями, а произвольными направленными связями. Обучаем в float, потом переводим веса в {-1, 0, +1} через Straight-Through Estimator. Потом — hard clamp. Потом — дообучение только scale и bias.

Четыре фазы, никакой магии.

Ожидал увидеть падение точности на 5-10 пунктов. До 85-88%. Ну норм: 15 КБ модель, которая хоть что-то распознаёт на MNIST, — уже результат.

Но на пятой эпохе float я получил 94.77%. Ничего выдающегося для полносвязной сети, но KAN есть KAN.

Дальше — STE ternary, и тут счётчик показал 95.78%.

Я перепроверил — и ещё раз. Hard clamp дал 96.09%, finetune — 96.15%.

Четыре фазы, точность растет на каждом шагу. Float → STE → clamp → finetune — и вместо потери точности я получил прирост в +1.38 процентных пункта.

Я полез гуглить, есть ли у кого-то такое. QuantKAN, KANtize, BiKA — у всех «negligible accuracy loss». То есть просто «ну почти не потеряли». А у меня ведь — выигрыш. Среди статей по KAN я такого не встречал.

Я перепроверил три раза на разных seed. Эффект стабильный.


Почему так вышло — гипотеза

Тернарный вес — это вентиль: пропустить сигнал как есть (+1), инвертировать (-1) или заблокировать (0). Никаких тонких настроек, никаких 0.0037, которые вносят шум, но формально считаются “информацией”.

Когда режешь вес до трех значений, ты режешь не только точность веса — ты режешь пространство гипотез. Модель не может выучить шумовые корреляции, потому что для этого нужны точные значения. В тернарном формате шум фильтруется округлением на этапе backward — мелкие градиенты просто не пробивают порог.

Короче говоря, эффект такой: веса, зажатые тремя значениями, сами работают как регуляризатор. Дискретность не даёт модели переобучаться — банально не хватает «разрешения» веса, чтобы запомнить шум. В литературе по KAN я такого не встречал.

На практике процесс split на четыре шага. Сначала тупо ограничиваем веса диапазоном [-1, 1] — чтобы они не разлетелись. Потом подключаем STE: на прямом проходе веса уже тернарные, а градиенты считаем как будто ничего не изменилось. Дальше — принудительная фиксация: каждый вес становится ровно -1, 0 или +1. И финальный шаг — дообучаем только масштаб и смещение, веса заморожены.

Каждая фаза агрессивнее предыдущей — и каждая даёт прирост. Если бы мне кто-то рассказал такой результат до эксперимента, я бы не поверил.


Что получилось: таблица

Модель

Веса

Размер

MNIST

Fashion-MNIST

GraphKAN 256→100→10

float

~15 КБ

94.77%

84.1%

GraphKAN 256→100→10

{-1,0,+1}

~15 КБ

96.15%

86.68%

MLP 256→100→10

float

~107 КБ

~93%

MLP в 7 раз больше по размеру — и на 3 пункта хуже по точности. Я не про то, что мы победили SOTA — просто KAN с тернарными весами оказался эффективнее обычного MLP при тех же нейронах и в 7 раз меньшем объеме.


Пять доменов

Проверял на всем, что подворачивалось под руку:

  • MNIST — 96.15%. Цифры, классика.

  • Fashion-MNIST — 86.68%. Одежда, сложнее цифр — но тернарная версия все равно бьет float.

  • HAR (Human Activity Recognition) — акселерометры с телефона. Тернарный KAN работает на временных рядах без RNN.

  • FSDD (Free Spoken Digit Dataset) — аудио, цифры голосом. Распознает без DSP-предобработки.

  • CIFAR-10 — тут честно: 47.83% на 8×8 входе. KAN с такой архитектурой не тянет сложные картинки как CNN — для 32×32 входов нужно больше параметров или сверточная структура.

На всех доменах, кроме CIFAR, тернарная версия по точности либо совпадает с float, либо бьет его.


ELM-режим: 99.3% от full BP accuracy без обучения скрытого слоя

Еще штука, которую я не упомянул в прошлый раз — GraphKAN умеет работать в режиме Extreme Learning Machine.

Случайная инициализация скрытого слоя → заморозка → решение только выходного веса через Least Squares. Никакого backprop через скрытые нейроны. На Fashion-MNIST это дало 78.7% при скрытом слое 500 — то есть 99.3% от accuracy полного backprop (79.2%).

Зачем это нужно? Ну типа ELM-режим означает, что piecewise-linear структура KAN генерирует естественно дискриминативные признаки даже из случайных весов. Для TinyML это критично: можно загрузить модель, заморозить 99% параметров и дообучать только последний слой — прямо на устройстве.

Размер скрытого слоя

Accuracy

% от BP

Размер модели

H=50

74.5%

94.1%

3.7 КБ

H=100

77.0%

97.3%

6.8 КБ

H=200

77.7%

98.1%

14 КБ

H=500

78.7%

99.3%

31.6 КБ


Это работает не только на KAN

Я проверил QAT-пайплайн на обычной CNN (Conv2d + ReLU + MaxPool + FC). Те же 4 фазы, та же gamma absorption. Результат:

Конфигурация

Float

STE ternary

True ternary

CNN (Fashion-MNIST)

91.57%

91.83%

92.02%

График сжатия

1x

8x

16x

Тернарная CNN не просто сохранила точность — она ее повысила на 0.45 п.п. То есть эффект регуляризации квантованием работает и на других архитектурах, не только на KAN. Не сошлось на одной конфигурации.


Сравнение с тем, что было до

До нас KAN квантовали вот как:

Работа

Биты

Метод

Результат

QuantKAN

4-bit

QAT для KAN

loss, не gain

KANtize

2-3 bit

B-spline таблицы

loss, не веса

BiKA

аппаратный

компараторы

не KAN, HW

BitNet b1.58

1.58 bit

ternary Transformer

lossless, но Transformer

Мы — GraphKAN

1.58 bit

STE + gamma abs + clamp

+1.4% к float

И бонус: тот же pipeline дает +0.45% на CNN. Похоже, эффект регуляризации квантованием — это скорее не особенность KAN, а общее свойство ternary QAT с gamma absorption.


Куда это помещается: поход по железу

15 КБ — и проблема выбора чипа просто исчезает.

Устройство

Flash

L1 кэш

FPU

GraphKAN влезает?

ARM Cortex-M0+ (STM32G0, $0.50)

16-64 КБ

нет

✅ целиком

ARM Cortex-M4 (STM32F4)

512 КБ

16 КБ

есть

✅ в L1

ARM Cortex-M7 (STM32H7)

2 МБ

64 КБ

есть

✅ в L1

RISC-V (GD32V)

32 КБ

нет

ESP32-S3

384 КБ

16 КБ

нет

✅ в L1

Arduino Nano RP2040

264 КБ

нет

И главное: тернарные веса заменяют умножение на условное сложение/вычитание/пропуск — два-три такта против 15-20 на float32. Никакого FPU, никаких DSP-инструкций — нафиг не нужно.


Где это применить

TinyML. Датчик вибрации на STM32G0 за $0.50 с батарейкой CR2032 на 5 лет. GraphKAN предсказывает поломку подшипника. Без Wi-Fi, без облака — всё на месте.

Умные часы. Анализ ЭКГ прямо на чипе, без Bluetooth на телефон. Медицинские данные не покидают устройство.

LoRa-сенсоры. 15 КБ — размер, который можно передать по радио за пару секунд. Ардуино с камерой — снимок, GraphKAN, результат. Все локально.

Edge AI без GPU. Графовая архитектура позволяет добавить нейрон на новый сенсор без перестройки сети — для IoT с разными конфигурациями датчиков это gold.

Образование. 15 КБ — модель, которую можно разобрать побайтово. Понять, как работает нейросеть, не имея GPU.


Ограничения (честно)

Не хочу, чтобы пост выглядел как “мы всех победили, patent pending”. Давайте честно.

CIFAR-10: 47.83%. KAN в своей базовой форме не тянет сложные изображения. Для 32×32 цветных картинок нужна сверточная архитектура. Вот тут фундаментальная проблема: piecewise-linear функции хуже масштабируются на высокоразмерные входы, чем свертки.

MNIST 96.15% — это не SOTA. SOTA на MNIST — 99.8%+ (Ensemble CNN). Но SOTA требует ~10+ МБ модели и GPU для инференса. 96.15% при 15 КБ — это другой класс задач: TinyML, где нет мегабайт и нет GPU. Сравнивать эти цифры напрямую — как сравнивать скорость Formula-1 и вездехода. Разные трассы.

Граф собирается вручную. Топология (кто с кем соединён) не обучается — фиксируется до обучения. Тут нет обучения adjacency matrix, как в graph neural network. Граф, но структура фиксирована до обучения.

Умножение все равно нужно на этапе training. Тернарность — только для inference. Training идет в float с STE.


Эпилог: почему 15 КБ важнее, чем 99% accuracy

Логичный вопрос: а почему бы просто не взять CNN и не сжать ее до 15 КБ?

Ответ: попробуйте.

Возьмите ResNet-18 (44 MB), квантизуйте до 2 бит — получите ~2.75 MB. Не 15 КБ. Mobilenet (4 MB) — ~500 КБ. Все равно не 15 КБ.

Чтобы получить 15 КБ, нужно либо резать архитектуру до неприличия (один слой, 10 нейронов — и 60% accuracy), либо менять саму парадигму.

В общем, GraphKAN меняет парадигму: вес — это не число для умножения, а один из трех символов. Пропустить, инвертировать или заблокировать — и всё.

Это уже не совсем нейросеть в привычном смысле. По духу оно ближе к конечному автомату с обучаемыми переходами, где каждый переход — {-1, 0, +1}. И этот автомат размером с три JPEG-фотографии угадывает цифры с точностью 96%.

Дело не в рекордах. SOTA будет обновляться каждые полгода — да и пофиг. А 15 КБ — это порог, после которого модель уже не «большая нейросеть», а библиотечная функция.

Нейросеть не загружается на микроконтроллер отдельно — она компилируется вместе с прошивкой. Одна функция на одну структуру — и никаких тебе DMA, внешних flash или драйверов.

15 КБ — это размер не модели.

Это размер решения.


Ссылки:

ссылка на оригинал статьи https://habr.com/ru/articles/1049822/