
Привет, чемпионы!
Представьте, что у вас есть большой и сложный проект, и вы наняли двух управленцев: Кабан-Кабаныча и Руководителева. Вы даете им одинаковую задачу: набрать штат сотрудников и выполнить ваш проект. Вся прибыль вместе с начальным бюджетом останется у них.
Кабан-Кабаныч решил, что нет смысла платить отдельным специалистам по DevOps, backend, ML и другим направлениям, и нанял всего одного сотрудника за 80 монеток. Этот бедняга работал в стиле «один за всех» и, естественно, быстро выгорел и «умер». Кабан-Кабаныч, не долго думая, нанял еще одного такого же сотрудника. В итоге вы вернулись и увидели печальную картину: задачу никто не решил, остался лишь Кабан-Кабаныч и кладбище несчастных сотрудников.

А вот Руководителев поступил иначе: он распределил бюджет на несколько похожих сотрудников, но сначала не понимал, кто из них в чём лучше. Тогда он стал давать им небольшие задачи и внимательно наблюдать за результатами. Через некоторое время он понял, что сотрудник №1 на 70% лучше справляется с задачами по ML, сотрудник №2 на 80% эффективнее в backend-разработке и так далее. Так Руководителев постепенно сформировал команду экспертов, сам став управляющим (или «gating»-узлом), который распределяет задачи на основе знаний о возможностях каждого сотрудника. Сотрудники углубляли экспертизу в своих направлениях, а Руководителев становился всё эффективнее в распределении задач.
Внезапно мы пришли к интересному решению:
-
Руководителев — это
gating network, который распределяет задачи, исходя из предыдущих успехов сотрудников. -
Сотрудники — это
local experts, каждый из которых специализируется на своей части задач.

Таким образом, мы экономим ресурсы, получаем сильных специалистов и достигаем отличных результатов за короткое время.
Именно так в 1991 году и появилось решение Adaptive Mixtures of local Experts
Этот подход доказал эффективность, сокращая время обучения моделей почти вдвое.
Как работает MoE?
Представьте модель, у которой есть входные и выходные данные, а между ними набор экспертов. Этих экспертов организует управляющая сеть (gating network), определяющая, какие эксперты могут лучше справиться с конкретной задачей. Gating-сеть, которая присваивает веса результату каждого эксперта, объединяя их в итоговый ответ.
Звучит красиво, но не всё так просто… Во время обучения возникают интересные и даже «ломающие мозг» ситуации, особенно когда осознаёшь, что созданная тобой модель может «вынести» тебя самого.
Conditional Computation одна из фишек MoE: возможность отключать или частично использовать экспертов. Это позволяет комбинировать разные архитектуры, каждая из которых выявляет уникальные паттерны в данных. Модель становится гибкой: сама решает, каких экспертов задействовать активно, кого игнорировать, а кого подключить чуть-чуть.
Ключевая особенность — разреженность. С помощью MoE можно масштабировать модель без пропорционального увеличения вычислительной нагрузки. Это очень важно, ведь позволяет обучать огромное количество экспертов, используя при этом только нужных. В этом нам помогает важный гиперпараметр — top_k, определяющий, сколько лучших экспертов будет выбрано для каждого входа.

Но основные сложности начинаются с настройки гиперпараметров и архитектурных решений. Самая большая проблема MoE — это «прилипание гейта», когда маршрутизатор начинает постоянно выбирать одних и тех же экспертов. Эти избранные эксперты получают больше данных и быстрее обучаются, в то время как остальные «скучают и пьют кофе».
Возникает закономерный вопрос: зачем тогда вообще нужны остальные эксперты?
Как с этим бороться? В своём коде я добавил трекер распределения данных по экспертам, чтобы контролировать, не «залип» ли гейт. Также я внедрил несколько хитрых решений, подсмотренных на профессиональных форумах.
Давайте кратко резюмируем:
Технология MoE выгодна за счёт разреженности и гибкости использования экспертов. Однако это «сделка с дьяволом», поскольку возникают сложности:
-
Сложная балансировка работы экспертов.
-
Функция потерь должна учитывать как производительность экспертов, так и маршрутизатора.
-
Количество гиперпараметров (количество экспертов, архитектура gating-сети) усложняет настройку модели.
Где сейчас используют MoE?
Почти все современные LLM используют MoE. Например, недавно вышедшая модель Llama4 Scout с 16x17B параметрами — это 16 экспертов по 17 миллиардов параметров каждый. То есть на инференсе вы используете не все 272 млрд параметров, а только top_k выбранных. Впечатляющее снижение вычислительных затрат, правда?
Также технология активно применяется в компьютерном зрении, и сейчас мы её протестируем на простом примере V-MoEs.
Тест драйв технологии
Итак, для обучения возьмем простенький датасет CIFAR100 и обучим на нем нашу кастомную V-MoEs для классификации изображений.
Сама по себе архитектура будет состоять из следующего:
Классический VIT, но ее часть классификатора мы обернем в decoder блок, где у нас будет применена MOE

Начнем с маршрутизатора, в нашем случае он был реализован следующим образом
import torch import torch.nn as nn import torch.nn.functional as F class GatingNetwork(nn.Module): def __init__(self, input_dim = 151296, num_experts=4, top_k=2, use_noise=True, noise_std=1e-2, temperature=1.0): super().__init__() self.num_experts = num_experts self.top_k = top_k self.use_noise = use_noise self.noise_std = noise_std self.temperature = temperature self.gate = nn.Sequential( nn.Linear(input_dim, 512), nn.ReLU(), nn.Dropout(0.2), nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, num_experts) ) def forward(self, x): logits = self.gate(x) # (B, num_experts) if self.use_noise and self.training: scale = logits.std(dim=1, keepdim=True).clamp(min=1e-3) noise = torch.randn_like(logits) * self.noise_std * scale logits = logits + noise topk_vals, topk_indices = torch.topk(logits, self.top_k, dim=1) gates = F.softmax(topk_vals / self.temperature, dim=1) # (B, top_k) return topk_indices, gates
Он берёт входной вектор x, оценивает, какие эксперты из num_experts лучше подойдут для каждого примера в батче, и возвращает top_k лучших экспертов с их весами.
То есть это — классическая Gating Network, которая решает, каким экспертам дать поработать с входом.
Обратите внимание, что тут есть noisy gating — это один из способов избежать «залипания гейта» на одном и том же эксперте. Во время тренировки шум масштабируется и в зависимости от поставленной нами пропорции влияет на решение о том какого эксперта повыбирать. Иными словами мы влияем на «результатова», чтобы он давал шансы большему числу экспертов, а не выбирал любимчиков.
Создадим экспертов
import torch.nn as nn class FFNExpert(nn.Module): def __init__(self, input_dim, hidden_dims, output_dim, dropout_prob=0.5): super(FFNExpert, self).__init__() layers = [] self.linears = nn.ModuleList() prev_dim = input_dim for hidden_dim in hidden_dims: linear = nn.Linear(prev_dim, hidden_dim) self.linears.append(linear) layers.append(linear) layers.append(nn.LayerNorm(hidden_dim)) layers.append(nn.ReLU()) layers.append(nn.Dropout(dropout_prob)) prev_dim = hidden_dim final_linear = nn.Linear(prev_dim, output_dim) self.linears.append(final_linear) layers.append(final_linear) self.network = nn.Sequential(*layers) self._initialize_weights() def _initialize_weights(self): for linear in self.linears: nn.init.xavier_uniform_(linear.weight) if linear.bias is not None: nn.init.zeros_(linear.bias) def forward(self, x): return self.network(x) class FFNExpertSmall(FFNExpert): def __init__(self, input_dim, output_dim): super(FFNExpertSmall, self).__init__(input_dim, hidden_dims=[256, 128], output_dim=output_dim, dropout_prob=0.3) class FFNExpertMedium(FFNExpert): def __init__(self, input_dim, output_dim): super(FFNExpertMedium, self).__init__(input_dim, hidden_dims=[512, 256, 128], output_dim=output_dim, dropout_prob=0.4) class FFNExpertLarge(FFNExpert): def __init__(self, input_dim, output_dim): super(FFNExpertLarge, self).__init__(input_dim, hidden_dims=[1024, 512, 256, 128], output_dim=output_dim, dropout_prob=0.5) class FFNExpertVeryLarge(FFNExpert): def __init__(self, input_dim, output_dim): super(FFNExpertVeryLarge, self).__init__(input_dim, hidden_dims=[2048, 1024, 512, 256, 128], output_dim=output_dim, dropout_prob=0.6)
Тут в целом все просто, мы набросали 4 эксперта с разными параметрами и посмотрим на то как они будут обучаться.
Начнем собирать модель
import torch.nn as nn import timm class ViT_backbone(nn.Module): def __init__(self): super().__init__() self.backbone = timm.create_model('vit_base_patch16_224', pretrained=True) for param in self.backbone.parameters(): param.requires_grad = False self.embed_dim = self.backbone.head.in_features self.backbone.reset_classifier(0) self.ln = nn.LayerNorm(self.embed_dim) self.ln2 = nn.LayerNorm(self.embed_dim) self.attn = nn.MultiheadAttention(embed_dim=self.embed_dim, num_heads=8, batch_first=True) def forward(self, x): skip = self.backbone.forward_features(x) # [B, N, D] x_ln = self.ln(skip) attn_out, _ = self.attn(x_ln, x_ln, x_ln) x_attn = attn_out + skip x_final = self.ln2(x_attn).flatten(1) # [B, N*D] return x_final
Тут все просто возьмем классическую модель VIT и добавим к ней слои нормализации после Multihead Attention и skip connection.
После сделаем наше объединение и наконец-то MOE
import torch import torch.nn as nn import torch.nn.functional as F from model.gating_network import GatingNetwork from model.Vit_model import ViT_backbone class MoECNN(nn.Module): def __init__(self, experts, input_for_gating = 151296, top_k=2, output_dim=100, use_aux_loss=True, aux_loss_weight=0.01, warmup_iters=500, noise_std = 0.5): super().__init__() self.num_experts = len(experts) self.top_k = top_k self.output_dim = output_dim self.use_aux_loss = use_aux_loss self.aux_loss_weight = aux_loss_weight self.warmup_iters = warmup_iters self.iter = 0 self.backbone = ViT_backbone() self.experts = nn.ModuleList(experts) self.gating = GatingNetwork( input_dim = input_for_gating, num_experts=self.num_experts, top_k = top_k, noise_std=noise_std) self.register_buffer("expert_usage", torch.zeros(self.num_experts)) def forward(self, x): batch_size = x.size(0) device = x.device x = self.backbone(x) if self.training and self.iter < self.warmup_iters: random_indices = torch.randint(0, self.num_experts, (batch_size, self.top_k), device=device) gates = torch.full((batch_size, self.top_k), 1.0 / self.top_k, device=device) topk_indices = random_indices self.iter += 1 else: topk_indices, gates = self.gating(x) output = torch.zeros(batch_size, self.output_dim, device=device) self.expert_usage.zero_() for i in range(self.top_k): idx = topk_indices[:, i] for expert_idx in torch.unique(idx): expert_mask = (idx == expert_idx) if expert_mask.sum() == 0: continue x_sel = x[expert_mask] y_sel = self.experts[expert_idx](x_sel) gate_weight = gates[expert_mask, i].unsqueeze(1) output[expert_mask] += gate_weight * y_sel self.expert_usage[expert_idx] += expert_mask.sum() aux_loss = None if self.use_aux_loss and self.training: usage = self.expert_usage / batch_size aux_loss = ((usage - usage.mean()) ** 2).mean() * self.aux_loss_weight return output, aux_loss
Первое на что обратим внимание это это warmup_iters. Тут у нас это число итераций где мы как-бы отключаем gating-сеть , чтобы избежать коллапса распределения (один эксперт выбирается чаще остальных до того, как сеть обучится разумно маршрутизировать входы). Это дает нам «разогреть» экспертов передавая им равномерно данные и далее мы уже начинаем более тонко избирать экспертов за счет gating network.
Второй момент это добавление use_aux_loss. Данный параметр позволяет нам учитывать в общем лоссе неравномерное распределение по экспертам в общий loss.
Как итог модель выбирает tok_k экспертов и на основе их предсказаний делает взвешенную сумму, после чего выдает результат и loss по распределению.
Что в итоге?
При простом «на коленке» мы смогли получить f1 на тесте 89.%. Более явно поиграв с гипперпараметрами, типами экспертов и некоторыми изощренностями думаю, что можно получить результат лучше. Самое главное, что

Давайте теперь проведем модель через один батч и посмотрим, что там на одном батче, что произошло по графику и посмотрим на первые 10 сэмплов батча.

Как можем увидеть, у нас 2 эксперт оказался в данной итерации не востребован, а использовали мы с 0,1,3 эксперта в разной пропорции.
Можно сказать, что вот: «второй эксперт переобучился или обучился плохо». Однако давайте глянем глубже! Мы ведь отслеживаем все через clearml 🙂

На тестовом датасете в среднем можно заметить, что все эксперты примерно вышли на какую-то свою зону ответственности. Хотя конечно от первого эксперта хотелось ожидать побольше!
Теперь давайте посмотрим на визуализацию результатов:

Несмотря на шакальность(мы работаем с CIFAR100 напоминаю) мы получили весьма неплохие результаты. И теперь вишенка на торте — это отслеживание по экспертам. Их собственно говоря мы итак логируем и сейчас можем провести на маленьком сэмпле аналитику. Если у нас есть очень большой эксперт и он не пригодился в использовании в вычислениях, то мы можем сэкономить очень много памяти.

Подводя итоги основным концептом было показать какие проблемы бывают и сложности при работе с технологией, а также ее возможности и потенциал, который уже сейчас очень успешно реализуется!
🔥 Ставьте лайк и напишите какие темы было бы интересно разобрать дальше! Самое главное — пробуйте и экспериментируйте!
✔️ Присоединяйтесь к нашему Telegram-сообществу @datafeeling, чтобы первыми применять на практике передовые технологии!
ссылка на оригинал статьи https://habr.com/ru/articles/902728/
Добавить комментарий