Часть 3: Diffusion Transformer (DiT) — Stable Diffusion 3 как она есть

от автора

Обо мне

Привет, меня зовут Василий Техин. В первой статье мы разобрали ResNet, во второй — ViT. Теперь погрузимся в мир генерации изображений с Diffusion Transformer (DiT) — сердцем Stable Diffusion 3.


Пролог: От распознавания к созданию

Представьте нейросеть как художника. Раньше она только анализировала картины («Это Ван Гог!»). Теперь она создаёт шедевры в стиле Ван Гога и не только!

Изображения из статьи

Изображения из статьи

Ключевые этапы работы DiT:

  1. Обучение:

    • Сжимаем изображение в латентное пространство через VAE (256х256х3 → 32х32х4)

    • Добавляем шум за 1000 шагов (чтобы модель училась удалять шум постепенно)

    • DiT учится предсказывать шум на каждом шаге

  2. Генерация (инференс):

    • Начинаем с чистого шума

    • Постепенно удаляем шум за 1000 шагов

    • Декодируем результат через VAE


Пайплайн обучения и генерации

1. Подготовка данных (VAE)

VAE (Variational Autoencoder) сжимает изображение:

# Для изображения 256x256: original = (3, 256, 256) → latent = (4, 32, 32)  # Сжатие в 64 раза 

Зачем? DiT работает с 32×32×4 латентными векторами — экономия вычислений!

2. Прямой процесс (добавление шума)

Процесс зашумления

Процесс зашумления

1000 шагов постепенного зашумления по формуле:

def forward_diffusion(z0, t, T=1000):     alpha_t = cos((t/T + 0.008) / 1.008 * π/2)**2       noise = torch.randn_like(z0)  # Случайный шум     z_t = sqrt(alpha_t) * z0 + sqrt(1-alpha_t) * noise  # Зашумленная версия     return z_t, noise 

Где:

  • z0 — исходный латентный вектор изображения

  • t — текущий шаг (1-1000)

  • noise — добавленный шум

3. Обратный процесс (обучение DiT)

Ключевые шаги обучения:

  1. Выбираем случайный шаг t (1-1000)

  2. Зашумляем латентный вектор: z_t, real_noise = forward_diffusion(z0, t)

  3. Подаем в DiT: pred_noise = DiT(z_t, t, text_embed) и получаем предсказанный шум

  4. Считаем MSE-лосс: loss = (real_noise - pred_noise).square().mean()

  5. Обновляем веса через backpropagation

Обратите внимание: DiT учится предсказывать оригинальный шум, а не изображение!

4. Генерация изображений (инференс)

Пошаговый процесс для Stable Diffusion 3(отличается от DiT из оригинальной статьи тем, что подается эмбеддинг текста вместо метки класса):

def generate(prompt, steps=1000):     # 1. Текстовый эмбеддинг     text_embed = text_encoder(prompt)  # [1, 768]          # 2. Начальный шум     z = torch.randn(1, 4, 32, 32)  # z_T          # 3. Итеративное удаление шума     for t in range(steps, 0, -1):         # a) Предсказание шума DiT         pred_noise = DiT(z, t, text_embed)                  # b) Classifier-Free Guidance (CFG) - усиление текстового влияния         if cfg_scale > 1.0:             uncond_embed = text_encoder("")  # Пустой промпт             uncond_noise = DiT(z, t, uncond_embed)             pred_noise = uncond_noise + cfg_scale * (pred_noise - uncond_noise)                  # c) Формула обратного шага (DDIM)         alpha_t = cos((t/steps + 0.008)/1.008 * π/2)**2         alpha_prev = cos(((t-1)/steps + 0.008)/1.008 * π/2)**2         z = (z - (1 - alpha_t)/sqrt(1 - alpha_t) * pred_noise) / sqrt(alpha_t)         z += sqrt(1 - alpha_prev) * torch.randn_like(z)  # Стохастичность              # 4. Декодирование через VAE     return VAE.decode(z)  # [1, 3, 256, 256] 

DiT в деталях: Отличия от ViT

1. Patchify: Работа с латентами

Мы нарезаем на патчи не оригинальное изображение, а латентный вектор

# Для латента 32x32x4 с патчами 2x2: self.patch_embed = nn.Conv2d(4, dim, kernel_size=2, stride=2) # → [batch, 256, dim]  (16*16=256 патчей) 

Сравнение с ViT: ViT работает с пикселями, DiT — с латентными векторами.

2. Classifier-Free Guidance (CFG)

Механизм усиления текста мы хотим, чтобы изображение из шума соответсвовало тексту, который мы передали:

pred_noise = uncond_noise + guidance_scale * (text_noise - uncond_noise) 

Где:

  • uncond_noise — предсказание для пустого промпта

  • text_noise — предсказание для целевого промпта

  • guidance_scale (7-10) — сила влияния текста

3. Cross-Attention Block

В SD3 (не в оригинальном DiT):

class CrossAttentionBlock(nn.Module):     def forward(self, x, text_emb):         # Проекция текста         q = self.wq(x)  # [batch, tokens, dim]         k = self.wk(text_emb)  # [batch, text_tokens, dim]         v = self.wv(text_emb)                  # Attention         attn = softmax(q @ k.transpose(-2,-1) / sqrt(dim))         return attn @ v  # Текст-условные признаки 

Зачем? Точнее связывает текст и визуальные патчи.

4. In-Context Conditioning

Механизм в DiT-XL:

  • Ввод текстовых токенов как патчей

  • Пример: [IMAGE_PATCH1, TEXT_TOKEN1, IMAGE_PATCH2, ...]

  • Позволяет смешивать текст и изображение на входе

5. AdaLN-Zero

Улучшение в DiT-2:

  • Инициализация параметров γ в AdaLN нулями

  • Первые шаги обучения: AdaLN = Identity Function

  • Стабилизирует раннее обучение


Разберём на простом примере

Оригинальный DiT (класс-условный)

Uncurated 512 × 512 DiT-XL/2 samples. Classifier-free guidance scale = 2.0 Class label = “panda” (388)

Uncurated 512 × 512 DiT-XL/2 samples. Classifier-free guidance scale = 2.0 Class label = “panda” (388)
# Генерация "собаки" (класс 207) class_label = 207 z = torch.randn(1, 4, 32, 32)  # Начальный шум  for t in range(1000, 0, -1):     pred_noise = DiT(z, t, class_label)  # Прямой вызов     z = update_step(z, pred_noise, t)  # Обновление латента 

Особенности:

  • Простой ввод класса вместо текста

  • Нет CFG и Cross-Attention

Stable Diffusion 3 (текст-условный)

prompt = "пудель в розовой шапочке" text_embed = text_encoder(prompt)  # [1, 768]  for t in range(1000, 0, -1):     # 1. Основное предсказание     pred_noise = DiT(z, t, text_embed)          # 2. Classifier-Free Guidance (усиление текста)     uncond_noise = DiT(z, t, text_encoder(""))     pred_noise = uncond_noise + 7.5 * (pred_noise - uncond_noise)          # 3. Cross-Attention в некоторых блоках     # (см. архитектуру ниже) 

Нововведения SD3(относительно DiT):

  • Текст через T5 вместо классов

  • CFG с масштабом 7.5 для точного следования промпту


Оценка качества: Метрики

1. FID (Fréchet Inception Distance)

Как работает:

  1. Берем 50k реальных и 50k сгенерированных изображений

  2. Пропускаем через Inception-v3 (получаем признаки)

  3. Считаем «расстояние» между распределениями:

FID = ||μ_real - μ_gen||^2 + Tr(Σ_real + Σ_gen - 2(Σ_real Σ_gen)^{1/2}) 

Интерпретация:

  • FID = 0 — идеальное совпадение

  • FID < 5 — фотореалистичные изображения

  • DiT-XL: FID = 2.27 (ImageNet 256×256)

2. IS (Inception Score)

IS = exp(E_x[KL(p(y|x) || p(y))]) 

Где:

  • p(y|x) — распределение классов для изображения

  • p(y) — общее распределение классов

  • Высокий IS = разнообразные и узнаваемые изображения


Почему DiT — это будущее Stable Diffusion?

✅ Преимущества перед U-Net:

Параметр

U-Net (SD 2.1)

DiT (SD 3)

Качество (FID)

3.85

2.27

Масштабируемость

Ограничена

Линейный рост

Разрешение

768×768

1024×1024

Текстовая привязка

Средняя

Точная

❌ Ограничения:

  1. Ресурсы: Обучение DiT-XL требует 500,000 GPU-hours

  2. Память: Генерация 1024px требует 48GB VRAM


Философский итог

DiT объединяет три революции ИИ:

  1. Сжатие данных (VAE)

  2. Трансформеры (ViT)

  3. Диффузионные процессы


Проверь себя

  1. Почему DiT работает с 32×32, а не 256×256?

  2. Как Classifier-Free Guidance улучшает генерацию?


Резюме

Diffusion Transformer (DiT):

  • Работает в латентном пространстве VAE (32x32x4)

  • Заменяет U-Net на трансформер с AdaLN

  • Оригинал: класс-условная генерация

Stable Diffusion 3:

  • Текст через текст энкодер и Cross-Attention

  • Classifier-Free Guidance для точности

  • Поддержка 1024px изображений

  • FID 2.27 — новый стандарт качества

Ссылки:

  1. Оригинальная статья DiT

  2. Stable Diffusion 3

  3. CFG в диффузионных моделях


ссылка на оригинал статьи https://habr.com/ru/articles/924410/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *