Mixture of Experts: когда нейросеть учится делегировать

от автора

Привет, чемпионы!

Представьте, что у вас есть большой и сложный проект, и вы наняли двух управленцев: Кабан-Кабаныча и Руководителева. Вы даете им одинаковую задачу: набрать штат сотрудников и выполнить ваш проект. Вся прибыль вместе с начальным бюджетом останется у них.

Кабан-Кабаныч решил, что нет смысла платить отдельным специалистам по DevOps, backend, ML и другим направлениям, и нанял всего одного сотрудника за 80 монеток. Этот бедняга работал в стиле «один за всех» и, естественно, быстро выгорел и «умер». Кабан-Кабаныч, не долго думая, нанял еще одного такого же сотрудника. В итоге вы вернулись и увидели печальную картину: задачу никто не решил, остался лишь Кабан-Кабаныч и кладбище несчастных сотрудников.

А вот Руководителев поступил иначе: он распределил бюджет на несколько похожих сотрудников, но сначала не понимал, кто из них в чём лучше. Тогда он стал давать им небольшие задачи и внимательно наблюдать за результатами. Через некоторое время он понял, что сотрудник №1 на 70% лучше справляется с задачами по ML, сотрудник №2 на 80% эффективнее в backend-разработке и так далее. Так Руководителев постепенно сформировал команду экспертов, сам став управляющим (или «gating»-узлом), который распределяет задачи на основе знаний о возможностях каждого сотрудника. Сотрудники углубляли экспертизу в своих направлениях, а Руководителев становился всё эффективнее в распределении задач.

Внезапно мы пришли к интересному решению:

  • Руководителев — это gating network, который распределяет задачи, исходя из предыдущих успехов сотрудников.

  • Сотрудники — это local experts, каждый из которых специализируется на своей части задач.

Таким образом, мы экономим ресурсы, получаем сильных специалистов и достигаем отличных результатов за короткое время.

Именно так в 1991 году и появилось решение Adaptive Mixtures of local Experts

Этот подход доказал эффективность, сокращая время обучения моделей почти вдвое.

Как работает MoE?

Представьте модель, у которой есть входные и выходные данные, а между ними набор экспертов. Этих экспертов организует управляющая сеть (gating network), определяющая, какие эксперты могут лучше справиться с конкретной задачей. Gating-сеть, которая присваивает веса результату каждого эксперта, объединяя их в итоговый ответ.

Звучит красиво, но не всё так просто… Во время обучения возникают интересные и даже «ломающие мозг» ситуации, особенно когда осознаёшь, что созданная тобой модель может «вынести» тебя самого.

Conditional Computation одна из фишек MoE: возможность отключать или частично использовать экспертов. Это позволяет комбинировать разные архитектуры, каждая из которых выявляет уникальные паттерны в данных. Модель становится гибкой: сама решает, каких экспертов задействовать активно, кого игнорировать, а кого подключить чуть-чуть.

Ключевая особенность — разреженность. С помощью MoE можно масштабировать модель без пропорционального увеличения вычислительной нагрузки. Это очень важно, ведь позволяет обучать огромное количество экспертов, используя при этом только нужных. В этом нам помогает важный гиперпараметр — top_k, определяющий, сколько лучших экспертов будет выбрано для каждого входа.

Но основные сложности начинаются с настройки гиперпараметров и архитектурных решений. Самая большая проблема MoE — это «прилипание гейта», когда маршрутизатор начинает постоянно выбирать одних и тех же экспертов. Эти избранные эксперты получают больше данных и быстрее обучаются, в то время как остальные «скучают и пьют кофе».

Возникает закономерный вопрос: зачем тогда вообще нужны остальные эксперты?

Как с этим бороться? В своём коде я добавил трекер распределения данных по экспертам, чтобы контролировать, не «залип» ли гейт. Также я внедрил несколько хитрых решений, подсмотренных на профессиональных форумах.

Давайте кратко резюмируем:

Технология MoE выгодна за счёт разреженности и гибкости использования экспертов. Однако это «сделка с дьяволом», поскольку возникают сложности:

  • Сложная балансировка работы экспертов.

  • Функция потерь должна учитывать как производительность экспертов, так и маршрутизатора.

  • Количество гиперпараметров (количество экспертов, архитектура gating-сети) усложняет настройку модели.

Где сейчас используют MoE?

Почти все современные LLM используют MoE. Например, недавно вышедшая модель Llama4 Scout с 16x17B параметрами — это 16 экспертов по 17 миллиардов параметров каждый. То есть на инференсе вы используете не все 272 млрд параметров, а только top_k выбранных. Впечатляющее снижение вычислительных затрат, правда?

Также технология активно применяется в компьютерном зрении, и сейчас мы её протестируем на простом примере V-MoEs.

Тест драйв технологии

Итак, для обучения возьмем простенький датасет CIFAR100 и обучим на нем нашу кастомную V-MoEs для классификации изображений.

Сама по себе архитектура будет состоять из следующего:

Классический VIT, но ее часть классификатора мы обернем в decoder блок, где у нас будет применена MOE

Начнем с маршрутизатора, в нашем случае он был реализован следующим образом

import torch import torch.nn as nn import torch.nn.functional as F  class GatingNetwork(nn.Module):     def __init__(self,                  input_dim = 151296,                  num_experts=4,                  top_k=2,                  use_noise=True,                  noise_std=1e-2,                  temperature=1.0):         super().__init__()          self.num_experts = num_experts         self.top_k = top_k         self.use_noise = use_noise         self.noise_std = noise_std         self.temperature = temperature          self.gate = nn.Sequential(             nn.Linear(input_dim, 512),             nn.ReLU(),             nn.Dropout(0.2),             nn.Linear(512, 256),             nn.ReLU(),             nn.Dropout(0.2),             nn.Linear(256, num_experts)         )      def forward(self, x):         logits = self.gate(x)  # (B, num_experts)          if self.use_noise and self.training:             scale = logits.std(dim=1, keepdim=True).clamp(min=1e-3)               noise = torch.randn_like(logits) * self.noise_std * scale             logits = logits + noise          topk_vals, topk_indices = torch.topk(logits, self.top_k, dim=1)          gates = F.softmax(topk_vals / self.temperature, dim=1)  # (B, top_k)          return topk_indices, gates 

Он берёт входной вектор x, оценивает, какие эксперты из num_experts лучше подойдут для каждого примера в батче, и возвращает top_k лучших экспертов с их весами.

То есть это — классическая Gating Network, которая решает, каким экспертам дать поработать с входом.

Обратите внимание, что тут есть noisy gating — это один из способов избежать «залипания гейта» на одном и том же эксперте. Во время тренировки шум масштабируется и в зависимости от поставленной нами пропорции влияет на решение о том какого эксперта повыбирать. Иными словами мы влияем на «результатова», чтобы он давал шансы большему числу экспертов, а не выбирал любимчиков.

Создадим экспертов

import torch.nn as nn   class FFNExpert(nn.Module):     def __init__(self, input_dim, hidden_dims, output_dim, dropout_prob=0.5):         super(FFNExpert, self).__init__()          layers = []         self.linears = nn.ModuleList()          prev_dim = input_dim         for hidden_dim in hidden_dims:             linear = nn.Linear(prev_dim, hidden_dim)             self.linears.append(linear)             layers.append(linear)             layers.append(nn.LayerNorm(hidden_dim))             layers.append(nn.ReLU())             layers.append(nn.Dropout(dropout_prob))             prev_dim = hidden_dim          final_linear = nn.Linear(prev_dim, output_dim)         self.linears.append(final_linear)         layers.append(final_linear)          self.network = nn.Sequential(*layers)         self._initialize_weights()      def _initialize_weights(self):         for linear in self.linears:             nn.init.xavier_uniform_(linear.weight)             if linear.bias is not None:                 nn.init.zeros_(linear.bias)      def forward(self, x):         return self.network(x)   class FFNExpertSmall(FFNExpert):     def __init__(self, input_dim, output_dim):         super(FFNExpertSmall, self).__init__(input_dim, hidden_dims=[256, 128], output_dim=output_dim, dropout_prob=0.3)   class FFNExpertMedium(FFNExpert):     def __init__(self, input_dim, output_dim):         super(FFNExpertMedium, self).__init__(input_dim, hidden_dims=[512, 256, 128], output_dim=output_dim,                                               dropout_prob=0.4)   class FFNExpertLarge(FFNExpert):     def __init__(self, input_dim, output_dim):         super(FFNExpertLarge, self).__init__(input_dim, hidden_dims=[1024, 512, 256, 128], output_dim=output_dim,                                              dropout_prob=0.5)   class FFNExpertVeryLarge(FFNExpert):     def __init__(self, input_dim, output_dim):         super(FFNExpertVeryLarge, self).__init__(input_dim, hidden_dims=[2048, 1024, 512, 256, 128],                                                  output_dim=output_dim, dropout_prob=0.6) 

Тут в целом все просто, мы набросали 4 эксперта с разными параметрами и посмотрим на то как они будут обучаться.

Начнем собирать модель

import torch.nn as nn import timm  class ViT_backbone(nn.Module):     def __init__(self):         super().__init__()         self.backbone = timm.create_model('vit_base_patch16_224',                                           pretrained=True)          for param in self.backbone.parameters():             param.requires_grad = False          self.embed_dim = self.backbone.head.in_features          self.backbone.reset_classifier(0)          self.ln = nn.LayerNorm(self.embed_dim)         self.ln2 = nn.LayerNorm(self.embed_dim)         self.attn = nn.MultiheadAttention(embed_dim=self.embed_dim,                                           num_heads=8,                                           batch_first=True)      def forward(self, x):         skip = self.backbone.forward_features(x)  # [B, N, D]         x_ln = self.ln(skip)         attn_out, _ = self.attn(x_ln, x_ln, x_ln)         x_attn = attn_out + skip         x_final = self.ln2(x_attn).flatten(1)  # [B, N*D]         return x_final

Тут все просто возьмем классическую модель VIT и добавим к ней слои нормализации после Multihead Attention и skip connection.

После сделаем наше объединение и наконец-то MOE

import torch import torch.nn as nn import torch.nn.functional as F from model.gating_network import GatingNetwork from model.Vit_model import ViT_backbone  class MoECNN(nn.Module):     def __init__(self,                  experts,                  input_for_gating = 151296,                  top_k=2,                  output_dim=100,                  use_aux_loss=True,                  aux_loss_weight=0.01,                  warmup_iters=500,                  noise_std = 0.5):         super().__init__()         self.num_experts = len(experts)         self.top_k = top_k         self.output_dim = output_dim         self.use_aux_loss = use_aux_loss         self.aux_loss_weight = aux_loss_weight         self.warmup_iters = warmup_iters         self.iter = 0          self.backbone = ViT_backbone()         self.experts = nn.ModuleList(experts)         self.gating = GatingNetwork( input_dim = input_for_gating,                                      num_experts=self.num_experts,                                      top_k = top_k,                                      noise_std=noise_std)          self.register_buffer("expert_usage",                              torch.zeros(self.num_experts))      def forward(self, x):         batch_size = x.size(0)         device = x.device         x = self.backbone(x)          if self.training and self.iter < self.warmup_iters:             random_indices = torch.randint(0,                                            self.num_experts,                                            (batch_size, self.top_k),                                            device=device)             gates = torch.full((batch_size, self.top_k),                                1.0 / self.top_k,                                device=device)             topk_indices = random_indices             self.iter += 1         else:             topk_indices, gates = self.gating(x)          output = torch.zeros(batch_size, self.output_dim, device=device)         self.expert_usage.zero_()          for i in range(self.top_k):             idx = topk_indices[:, i]             for expert_idx in torch.unique(idx):                 expert_mask = (idx == expert_idx)                 if expert_mask.sum() == 0:                     continue                 x_sel = x[expert_mask]                 y_sel = self.experts[expert_idx](x_sel)                 gate_weight = gates[expert_mask, i].unsqueeze(1)                 output[expert_mask] += gate_weight * y_sel                  self.expert_usage[expert_idx] += expert_mask.sum()          aux_loss = None         if self.use_aux_loss and self.training:             usage = self.expert_usage / batch_size             aux_loss = ((usage - usage.mean()) ** 2).mean() * self.aux_loss_weight          return output, aux_loss 

Первое на что обратим внимание это это warmup_iters. Тут у нас это число итераций где мы как-бы отключаем gating-сеть , чтобы избежать коллапса распределения (один эксперт выбирается чаще остальных до того, как сеть обучится разумно маршрутизировать входы). Это дает нам «разогреть» экспертов передавая им равномерно данные и далее мы уже начинаем более тонко избирать экспертов за счет gating network.

Второй момент это добавление use_aux_loss. Данный параметр позволяет нам учитывать в общем лоссе неравномерное распределение по экспертам в общий loss.

Как итог модель выбирает tok_k экспертов и на основе их предсказаний делает взвешенную сумму, после чего выдает результат и loss по распределению.

Что в итоге?

При простом «на коленке» мы смогли получить f1 на тесте 89.%. Более явно поиграв с гипперпараметрами, типами экспертов и некоторыми изощренностями думаю, что можно получить результат лучше. Самое главное, что

Давайте теперь проведем модель через один батч и посмотрим, что там на одном батче, что произошло по графику и посмотрим на первые 10 сэмплов батча.

Как можем увидеть, у нас 2 эксперт оказался в данной итерации не востребован, а использовали мы с 0,1,3 эксперта в разной пропорции.

Можно сказать, что вот: «второй эксперт переобучился или обучился плохо». Однако давайте глянем глубже! Мы ведь отслеживаем все через clearml 🙂

На тестовом датасете в среднем можно заметить, что все эксперты примерно вышли на какую-то свою зону ответственности. Хотя конечно от первого эксперта хотелось ожидать побольше!

Теперь давайте посмотрим на визуализацию результатов:

Несмотря на шакальность(мы работаем с CIFAR100 напоминаю) мы получили весьма неплохие результаты. И теперь вишенка на торте — это отслеживание по экспертам. Их собственно говоря мы итак логируем и сейчас можем провести на маленьком сэмпле аналитику. Если у нас есть очень большой эксперт и он не пригодился в использовании в вычислениях, то мы можем сэкономить очень много памяти.

Подводя итоги основным концептом было показать какие проблемы бывают и сложности при работе с технологией, а также ее возможности и потенциал, который уже сейчас очень успешно реализуется!

🔥 Ставьте лайк и напишите какие темы было бы интересно разобрать дальше! Самое главное — пробуйте и экспериментируйте!

✔️ Присоединяйтесь к нашему Telegram-сообществу @datafeeling, чтобы первыми применять на практике передовые технологии!


ссылка на оригинал статьи https://habr.com/ru/articles/902728/