Законы масштабирования дистилляции

от автора

Рекомендация для читателей:

Прежде чем погрузиться в детали, советую ознакомиться с двумя отличными статьями инженера из Яндекса (статья 1статья 2). В них отлично объясняются принципы дистилляции, её применение в промышленных задачах и ключевые практические аспекты. Это идеальный старт для тех, кто только начинает знакомиться с темой.

Однако, если вы, как и я, стремитесь к глубокому пониманию — этого может оказаться недостаточно. В данном обзоре мы пойдём дальше:

  1. Математическая формализация: Разберём более глубако уравнения, лежащие в основе дистилляции, включая функцию потерь с температурным параметром, оптимизацию распределений и законы масштабирования из работы Apple.

  2. Примеры кода: Покажем, как реализовать дистилляцию на практике — от простых моделей на PyTorch до тонкой настройки гиперпараметров.

  3. Нюансы исследований: Ответим на вопросы, оставшиеся за рамками вводных материалов. Например, почему «слишком умный учитель» вредит ученику и как математически обосновать оптимальное соотношение их размеров.

Для кого это?
Если вы хотите не просто использовать дистилляцию «из коробки», а понимать, как и почему она работает — этот разбор для вас. Мы заглянем «под капот» методов, чтобы вы могли осознанно применять их в своих проектах.

p.s. Если хотите сразу разобрать законы масштабирования дистилляции, то пропускайте первую часть и сразу ко второй.

Part 1: Knowledge Distillation

Knowledge Distillation (Дистилляция знаний) — это метод обучения моделей-студентов (обычно меньшего размера и менее сложных) путем передачи «знаний» от предварительно обученной модели-учителя (обычно большей и более сложной). Основная идея заключается в том, что модель-учитель, обладающая большей емкостью и обученная на большом объеме данных, может передать не только свои «жесткие» предсказания (например, класс объекта), но и более богатую информацию о распределении вероятностей классов, которую модель-студент может использовать для более эффективного обучения.

Teacher и Student модели:

В парадигме Knowledge Distillation участвуют две основные модели:

  • Teacher (Учитель): Это большая, предварительно обученная модель, которая считается «экспертом» в решении определенной задачи. Учитель уже достиг высокой точности и обладает «знаниями», которые мы хотим передать студенту. Математически учитель представляется как функция p(y∣x), которая для входных данных xx выдает распределение вероятностей p по классам y.

  • Student (Студент): Это меньшая, более простая модель, которую мы хотим обучить. Цель студента — научиться имитировать поведение учителя, чтобы достичь сравнимой производительности, но при этом быть более эффективной с точки зрения вычислительных ресурсов, памяти или времени инференса. Студент представляется как функция q_θ​(y∣x), где θ — параметры модели, которые мы оптимизируем в процессе обучения.

Функция потерь (Loss Function) в Knowledge Distillation:

Общая цель Knowledge Distillation — минимизировать разницу между предсказаниями учителя и студента. Это формализуется через функцию потерьL, которая зависит от предсказаний учителяp(y|x) и студента q_{\theta}(y|x). Процесс обучения заключается в поиске оптимальных параметров $\theta$ для студента, которые минимизируют эту функцию потерь:

L(p(y|x), q_{\theta}(y|x)) \rightarrow \min_{\theta}

Это общее выражение, и конкретный вид функции потерь и способ дистилляции определяют различные подходы. Рассмотрим два основных подхода: hard-label и soft-label дистилляцию.

Это общее выражение, и конкретный вид функции потерь и способ дистилляции определяют различные подходы. Рассмотрим два основных подхода: hard-label и soft-label дистилляцию.

Hard-label Distillation для GPT моделей: объяснение на пальцах

Представьте, что у нас есть две модели:

  • Учитель (Teacher): Большая, мощная GPT модель, например, GPT-3 или что-то подобное. Она обладает огромным количеством знаний о языке и мире, и способна генерировать очень качественный и связный текст.

  • Студент (Student): Маленькая, более компактная GPT модель, например, уменьшенная версия GPT или Transformer меньшего размера. Она менее ресурсоемкая, но изначально уступает учителю в качестве генерации текста.

Наша цель — «научить» маленькую модель-студента генерировать текст так же хорошо, как и большая модель-учитель, используя метод Hard-label Distillation.

Шаги Hard-label Distillation в этом контексте:

  1. Генерация «жестких» меток учителем (Большой GPT):

    • Мы берем большой набор текстовых данных (например, обучающую выборку, на которой изначально обучался учитель, или просто большой корпус текстов).

    • Для каждого фрагмента текста (или запроса) из этого набора, мы просим большую модель-учителя сгенерировать текст. В контексте GPT, это означает, что мы подаем учителю входной текст (например, начало предложения или запрос) и просим его сгенерировать продолжение.

    • Учитель генерирует последовательность токенов, которые он считает наиболее вероятными для продолжения данного текста. Эти сгенерированные последовательности токенов и являются нашими «жесткими» метками.

    Пример:

    • Входной текст (запрос): «Столица Франции — это»

    • Учитель (Большая GPT) генерирует: «Париж.» (токены: «Па», «ри», «ж», «.»)

    • «Жесткая» метка: Последовательность токенов: («Па», «ри», «ж», «.»)

    Мы повторяем этот процесс для большого количества различных входных текстов, получая набор пар: (исходный входной текст, «жесткая» метка — последовательность токенов, сгенерированная учителем).

  2. Обучение студента (Маленький GPT) на «жестких» метках:

    • Теперь у нас есть синтетический датасет, состоящий из пар (исходный входной текст, «жесткая» метка). Мы будем использовать этот датасет для обучения маленькой модели-студента.

    • Мы обучаем студента предсказывать «жесткие» метки, сгенерированные учителем, используя стандартную задачу языкового моделирования. Это означает, что для каждого входного текста мы хотим, чтобы студент генерировал последовательность токенов, максимально похожую на «жесткую» метку, сгенерированную учителем.

    • В процессе обучения мы используем функцию потерь кросс-энтропии. Мы сравниваем распределение вероятностей токенов, предсказанное студентом, с «жесткой» меткой (которая по сути является распределением, где вероятность «правильного» токена равна 1, а всех остальных — 0). Мы стремимся минимизировать эту кросс-энтропию, заставляя студента «подражать» учителю в предсказании токенов.

    В нашем примере, если студент на вход «Столица Франции — это» предсказывает, например, «Лондон», то функция потерь будет высокой, так как «жесткая» метка учителя была «Париж». В процессе обучения студент будет корректировать свои параметры, чтобы в будущем для аналогичных запросов предсказывать «Париж» или что-то очень похожее на предсказание учителя.

Почему маленькая модель может предсказывать те же токены, что и большая?

  • Передача знаний через «жесткие» метки: Хотя Hard-label Distillation и теряет часть информации из распределения вероятностей учителя, она все равно эффективно передает ключевые знания о том, какие токены являются наиболее вероятными в определенных контекстах. Большая модель, будучи хорошо обученной, «знает», какие продолжения текста являются грамматически правильными, семантически уместными и стилистически подходящими. Генерируя «жесткие» метки, она как бы «подсказывает» маленькой модели, какие именно токены нужно предсказывать.

  • Фокус на наиболее важной информации: «Жесткие» метки концентрируются на наиболее вероятных токенах. В языковом моделировании часто бывает так, что для многих контекстов есть один или несколько доминирующих «правильных» продолжений. Hard-label Distillation помогает маленькой модели быстро освоить эти наиболее важные закономерности, игнорируя менее значимые детали, которые могут быть избыточными для достижения хорошего качества генерации.

  • Упрощение задачи обучения: Обучение на «жестких» метках превращает дистилляцию в стандартную задачу обучения с учителем. Это упрощает процесс обучения и позволяет использовать хорошо известные методы и оптимизаторы. Маленькой модели не нужно пытаться воспроизвести все тонкости распределения вероятностей учителя, ей достаточно научиться предсказывать наиболее вероятные токены, что является более простой задачей.

Важно отметить ограничения Hard-label Distillation:

  • Потеря «мягкой» информации: Как и указано в тексте, Hard-label Distillation теряет информацию о вероятностях других классов и «мягких» отношениях между классами. В контексте языковых моделей это означает, что студент может не улавливать все нюансы стиля, семантики и разнообразия, которые присутствуют в распределении вероятностей учителя. Например, учитель может знать, что «Париж» является самым вероятным ответом на «Столица Франции — это», но также понимать, что «Рим» или «Берлин» являются менее вероятными, но все же допустимыми ответами в определенных контекстах. Hard-label Distillation фокусируется только на «Париже», игнорируя эту «мягкую» информацию.

  • Потенциальное ухудшение разнообразия: Из-за фокусировки на «жестких» метках, студент может стать менее разнообразным в своих генерациях, чем учитель. Он может слишком точно копировать наиболее вероятные ответы учителя, упуская возможность генерировать альтернативные, но все еще качественные варианты.

Математическая формализация:

1. Генерация «жестких» меток учителем: Для каждого примераx^{(n)}из обучающей выборки, учительp(y|x)предсказывает распределение вероятностей классов. «Жесткая» меткаy^{(n)}выбирается как класс с максимальной вероятностью, предсказанной учителем. В контексте языков моделей, гдеyпредставляет собой последовательность токенов, учитель генерирует последовательность «жестких» метокy^{(1)}, \ldots y^{(N)}дляNпримеров. Здесьy^{(n)} = (y_1^{(n)}, \ldots, y_{T_n}^{(n)})представляет собой последовательность токенов длинойT_n.

y^{(1)}, \ldots y^{(N)} \sim p(y|x)

В более простом варианте, для классификации, y^{(n)} = \arg\max_{y} p(y|x^{(n)}). В случае последовательностей, учитель может генерировать целые последовательности наиболее вероятных токенов.

2. Обучение студента на «жестких» метках: Студентq_{\theta}(y|x)обучается максимизировать логарифмическую вероятность «жестких» меток, сгенерированных учителем. Это стандартная задача обучения с учителем, где целевыми метками являютсяy^{(1)}, \ldots y^{(N)}. Функция потерь, которую мы минимизируем (или эквивалентно, максимизируем отрицательную потерю), представляет собой ожидание логарифмической вероятности «жестких» меток под распределениемp(y|x)учителя.

\mathbb{E}_{p(y|x)} [\log q_{\theta}(y|x)] \rightarrow \max_{\theta}

В практической реализации, это ожидание аппроксимируется эмпирическим средним по обучающей выборке. Для последовательностей текста, функция потерь выглядит следующим образом:

\frac{1}{N} \sum_{n=1}^{N} \sum_{t=1}^{T_n} \log q_{\theta}(y_t^{(n)}|y_{<t}^{(n)})

Здесь:

* N — количество примеров в обучающей выборке.

* T_n — длина последовательности для $n$-го примера.

* y_t^{(n)}t-й токен в последовательности «жестких» меток дляn-го примера, сгенерированных учителем.

* y_{<t}^{(n)} = (y_1^{(n)}, \ldots, y_{t-1}^{(n)}) — префикс последовательности доt-го токена.

* q_{\theta}(y_t^{(n)}|y_{<t}^{(n)}) — вероятность предсказания студентомt-го токенаy_t^{(n)} при условии предыдущих токеновy_{<t}^{(n)}, параметризованная\theta.

Эта функция потерь представляет собой кросс-энтропию между распределением «жестких» меток, сгенерированных учителем, и предсказаниями студента. Мы стремимся максимизировать эту величину, что эквивалентно минимизации отрицательной логарифмической правдоподобности или кросс-энтропии.

Преимущества и недостатки Hard-label Distillation:

  • Преимущества: Простота реализации и понимания. Можно использовать стандартные методы обучения с учителем.

  • Недостатки: Потеря информации, содержащейся в распределении вероятностей учителя. «Жесткие» метки содержат только информацию о наиболее вероятном классе, игнорируя вероятности других классов и «мягкие» отношения между классами, которые учитель «знает». Это может ограничить эффективность передачи знаний.

Реализация Hard-label Distillation на основе Open R1

Ниже представлена реализация Hard-label Distillation с использованием подхода, применяемого в проекте Open R1. Процесс разделен на два этапа: генерация данных учителем и обучение ученика.

@misc{openr1,     title = {Open R1: A fully open reproduction of DeepSeek-R1},     url = {https://github.com/huggingface/open-r1},     author = {Hugging Face},     month = {January},     year = {2025} } 

Этап 1: Генерация «жестких» меток большой моделью (учителем)

import argparse from datasets import load_dataset from typing import Optional, Dict, Any  from distilabel.pipeline import Pipeline from distilabel.models import vLLM from distilabel.steps.tasks import TextGeneration  def build_hard_label_pipeline(     teacher_model: str,     base_url: str = "http://localhost:8000/v1",     prompt_column: Optional[str] = None,     prompt_template: str = "{{ instruction }}",     temperature: float = 0.0,     max_new_tokens: int = 4096,     input_batch_size: int = 32, ) -> Pipeline:     """     Description:     ---------------         Создает конвейер для генерации "жестких" меток с использованием модели-учителя.      Args:     ---------------         teacher_model: Идентификатор модели-учителя         base_url: URL сервера vLLM         prompt_column: Имя колонки в датасете, содержащей входные тексты         prompt_template: Шаблон для форматирования промптов         temperature: Температура для генерации (0.0 для "жестких" меток)         max_new_tokens: Максимальное количество генерируемых токенов         input_batch_size: Размер батча для входных данных      Returns:     ---------------         Настроенный конвейер Distilabel      Raises:     ---------------         Exception: В случае ошибки настройки конвейера      Examples:     ---------------         >>> pipeline = build_hard_label_pipeline("deepseek-ai/DeepSeek-R1")         >>> pipeline.run(dataset)     """     # Настраиваем параметры генерации с temperature=0 для получения детерминированных ответов     generation_kwargs: Dict[str, Any] = {         "max_new_tokens": max_new_tokens,         "temperature": temperature,         "top_p": 1.0,         "do_sample": False,          # Отключаем семплирование для получения "жестких" меток     }      with Pipeline(         name="hard-label-distillation",         description="Конвейер для генерации 'жестких' меток с использованием модели-учителя",     ) as pipeline:         # Настраиваем модель-учителя через vLLM         teacher = vLLM(             model=teacher_model,             tokenizer=teacher_model,             extra_kwargs={                 "tensor_parallel_size": 1,               # Можно увеличить для больших моделей                 "max_model_len": max_new_tokens + 2048,  # Добавляем запас для контекста             },             generation_kwargs=generation_kwargs,         )          # Настраиваем шаг генерации текста         text_generation = TextGeneration(             llm=teacher,             template=prompt_template,             num_generations=1,           # Для "жестких" меток нам нужна только одна генерация             input_mappings={"instruction": prompt_column} if prompt_column is not None else {},             input_batch_size=input_batch_size,         )      return pipeline  def generate_hard_labels(     dataset_name: str,     dataset_split: str = "train",     teacher_model: str = "deepseek-ai/DeepSeek-R1",     output_dataset: str = "my-username/hard-label-distill-dataset",     prompt_column: str = "problem",     prompt_template: str = "You will be given a problem. Please reason step by step, and put your final answer within \\boxed{}: {{ instruction }}",     max_examples: Optional[int] = None,     private: bool = False, ) -> Any:     """     Description:     ---------------         Генерирует "жесткие" метки с использованием модели-учителя и сохраняет результаты как набор данных на HuggingFace Hub.      Args:     ---------------         dataset_name: Имя исходного датасета         dataset_split: Имя сплита датасета         teacher_model: Модель-учитель для генерации "жестких" меток         output_dataset: Имя выходного датасета на HuggingFace Hub         prompt_column: Имя колонки, содержащей входные данные         prompt_template: Шаблон для форматирования промптов         max_examples: Максимальное количество примеров для обработки         private: Приватный ли выходной датасет      Returns:     ---------------         Датасет с "жесткими" метками      Raises:     ---------------         Exception: В случае ошибки генерации меток      Examples:     ---------------         >>> hard_label_dataset = generate_hard_labels("my-dataset", "train")         >>> hard_label_dataset.push_to_hub("my-username/hard-label-dataset")     """     # Загружаем исходный датасет     print(f"Загрузка датасета '{dataset_name}' (сплит: {dataset_split})...")     dataset = load_dataset(dataset_name, split=dataset_split)      # Ограничиваем количество примеров, если указано     if max_examples is not None and max_examples < len(dataset):         dataset = dataset.select(range(max_examples))      print(f"Создание конвейера для генерации 'жестких' меток с использованием {teacher_model}...")     pipeline = build_hard_label_pipeline(         teacher_model=teacher_model,         prompt_column=prompt_column,         prompt_template=prompt_template,     )      print(f"Запуск конвейера для генерации 'жестких' меток на {len(dataset)} примерах...")     # Генерируем "жесткие" метки     hard_label_dataset = pipeline.run(dataset=dataset)      # Сохраняем результаты на HuggingFace Hub     if output_dataset:         print(f"Сохранение результатов в '{output_dataset}'...")         hard_label_dataset.push_to_hub(output_dataset, private=private)         print(f"Датасет с 'жесткими' метками успешно сохранен в '{output_dataset}'.")      return hard_label_dataset  if __name__ == "__main__":     parser = argparse.ArgumentParser(description="Генерация 'жестких' меток с использованием модели-учителя")     parser.add_argument("--dataset", type=str, required=True, help="Имя исходного датасета")     parser.add_argument("--split", type=str, default="train", help="Сплит датасета")     parser.add_argument("--teacher-model", type=str, default="deepseek-ai/DeepSeek-R1", help="Модель-учитель")     parser.add_argument("--output-dataset", type=str, required=True, help="Имя выходного датасета")     parser.add_argument("--prompt-column", type=str, default="problem", help="Колонка с входными данными")     parser.add_argument("--prompt-template", type=str,                        default="You will be given a problem. Please reason step by step, and put your final answer within \\boxed{}: {{ instruction }}",                        help="Шаблон для форматирования промптов")     parser.add_argument("--max-examples", type=int, default=None, help="Максимальное количество примеров")     parser.add_argument("--private", action="store_true", help="Сделать выходной датасет приватным")      args = parser.parse_args()      generate_hard_labels(         dataset_name=args.dataset,         dataset_split=args.split,         teacher_model=args.teacher_model,         output_dataset=args.output_dataset,         prompt_column=args.prompt_column,         prompt_template=args.prompt_template,         max_examples=args.max_examples,         private=args.private,     )

Этап 2: Обучение модели-ученика на «жестких» метках

import logging import os import sys from dataclasses import dataclass, field from typing import Optional, Dict, Any  import datasets import torch import transformers from datasets import load_dataset from transformers import AutoTokenizer, set_seed from transformers.trainer_utils import get_last_checkpoint  from trl import SFTTrainer, ModelConfig, TrlParser, get_peft_config from open_r1.configs import SFTConfig from open_r1.utils.wandb_logging import init_wandb_training  logger = logging.getLogger(__name__)  @dataclass class HardLabelDistillConfig(SFTConfig):     """Конфигурация для обучения ученика с использованием Hard-label Distillation."""      dataset_name: str = field(         default=None, metadata={"help": "Датасет с 'жесткими' метками, сгенерированными учителем"}     )     input_column: str = field(         default="problem", metadata={"help": "Колонка с входными данными"}     )     target_column: str = field(         default="generation_0", metadata={"help": "Колонка с выходными данными (жесткими метками) учителя"}     )     max_seq_length: int = field(         default=2048, metadata={"help": "Максимальная длина последовательности"}     )  def train_student_model(config: HardLabelDistillConfig, model_args: ModelConfig) -> None:     """     Description:     ---------------     Обучает модель-ученика на 'жестких' метках, сгенерированных учителем.      Args:     ---------------         config: Конфигурация обучения         model_args: Конфигурация модели      Returns:     ---------------         None      Raises:     ---------------         Exception: В случае ошибки обучения модели      Examples:     ---------------         >>> train_student_model(config, model_args)     """     # Настраиваем логирование     logging.basicConfig(         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",         datefmt="%Y-%m-%d %H:%M:%S",         handlers=[logging.StreamHandler(sys.stdout)],     )     log_level = config.get_process_log_level()     logger.setLevel(log_level)     datasets.utils.logging.set_verbosity(log_level)     transformers.utils.logging.set_verbosity(log_level)      # Устанавливаем сид для воспроизводимости     set_seed(config.seed)      # Проверяем наличие последнего чекпоинта     last_checkpoint: Optional[str] = None     if os.path.isdir(config.output_dir):         last_checkpoint = get_last_checkpoint(config.output_dir)         if last_checkpoint is not None:             logger.info(f"Найден чекпоинт, продолжаем обучение с {last_checkpoint}")      # Инициализируем Weights & Biases, если нужно     if "wandb" in config.report_to:         init_wandb_training(config)      # Загружаем датасет с 'жесткими' метками     logger.info(f"Загрузка датасета с 'жесткими' метками: {config.dataset_name}")     dataset = load_dataset(config.dataset_name)      # Подготавливаем входные данные и метки для обучения     def prepare_dataset(examples: Dict[str, Any]) -> Dict[str, Any]:         """Форматирует данные для обучения с учителем."""         return {             "input_ids": examples[config.input_column],             "labels": examples[config.target_column],         }      # Трансформируем датасет     dataset = dataset.map(prepare_dataset, batched=True)      # Загружаем токенизатор     tokenizer = AutoTokenizer.from_pretrained(         model_args.model_name_or_path,         revision=model_args.model_revision,         trust_remote_code=model_args.trust_remote_code,     )      # Настраиваем chat_template, если указан     if config.chat_template is not None:         tokenizer.chat_template = config.chat_template      # Настраиваем параметры модели     torch_dtype = (         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)     )     model_kwargs: Dict[str, Any] = dict(         revision=model_args.model_revision,         trust_remote_code=model_args.trust_remote_code,         torch_dtype=torch_dtype,         use_cache=False if config.gradient_checkpointing else True,     )     config.model_init_kwargs = model_kwargs      # Создаем SFT тренер     trainer = SFTTrainer(         model=model_args.model_name_or_path,         args=config,         train_dataset=dataset["train"],         eval_dataset=dataset["validation"] if "validation" in dataset and config.eval_strategy != "no" else None,         processing_class=tokenizer,         peft_config=get_peft_config(model_args),     )      # Запускаем обучение     logger.info("Начало обучения модели-ученика...")     checkpoint: Optional[str] = None     if config.resume_from_checkpoint is not None:         checkpoint = config.resume_from_checkpoint     elif last_checkpoint is not None:         checkpoint = last_checkpoint      train_result = trainer.train(resume_from_checkpoint=checkpoint)     metrics = train_result.metrics     trainer.log_metrics("train", metrics)     trainer.save_metrics("train", metrics)     trainer.save_state()      # Сохраняем модель     logger.info(f"Сохранение модели в {config.output_dir}")     trainer.save_model(config.output_dir)      # Создаем карточку модели и загружаем на HuggingFace Hub, если нужно     kwargs: Dict[str, Any] = {         "dataset_name": config.dataset_name,         "tags": ["hard-label-distillation", "open-r1"],     }      if trainer.accelerator.is_main_process:         trainer.create_model_card(**kwargs)         # Восстанавливаем кэш для быстрого инференса         trainer.model.config.use_cache = True         trainer.model.config.save_pretrained(config.output_dir)      # Оцениваем модель, если нужно     if config.do_eval and "validation" in dataset:         logger.info("Оценка модели...")         metrics = trainer.evaluate()         trainer.log_metrics("eval", metrics)         trainer.save_metrics("eval", metrics)      # Загружаем модель на HuggingFace Hub, если нужно     if config.push_to_hub:         logger.info("Загрузка модели на HuggingFace Hub...")         trainer.push_to_hub(**kwargs)  if __name__ == "__main__":     # Создаем парсер аргументов     parser = TrlParser((HardLabelDistillConfig, ModelConfig))     config, model_args = parser.parse_args_and_config()      # Запускаем обучение     train_student_model(config, model_args)

Пример использования

# Этап 1: Генерация "жестких" меток с использованием модели-учителя python hard_label_distill.py \   --dataset AI-MO/NuminaMath-TIR \   --teacher-model deepseek-ai/DeepSeek-R1 \   --output-dataset username/hard-label-math-dataset \   --prompt-column problem  # Этап 2: Обучение модели-ученика на сгенерированных "жестких" метках accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml train_student.py \   --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \   --dataset_name username/hard-label-math-dataset \   --input_column problem \   --target_column generation_0 \   --learning_rate 1.0e-5 \   --num_train_epochs 2 \   --packing \   --max_seq_length 4096 \   --per_device_train_batch_size 8 \   --gradient_accumulation_steps 4 \   --gradient_checkpointing \   --bf16 \   --output_dir models/Qwen2.5-1.5B-Hard-Label-Distill

II. Soft-label Distillation: Дистилляция с использованием «мягких» меток

Концепция:

Soft-label distillation, предложенная Хинтоном и соавторами в их знаменитой статье «Distilling the Knowledge in a Neural Network» (2015), является более совершенным методом дистилляции знаний. В отличие от Hard-label distillation, этот подход использует не только «жесткие» метки, но и полное распределение вероятностей, предсказанное учителем, в качестве «мягких» меток (soft labels).

«Мягкие» метки содержат значительно больше информации, чем «жесткие», поскольку они отражают уверенность учителя в различных классах и отношения между ними. Например, учитель может предсказать для изображения собаки вероятности [0.8 для «собака», 0.15 для «волк», 0.03 для «лиса», 0.02 для других классов]. Эта информация гораздо богаче, чем просто метка «собака».

Ключевым компонентом метода является «temperature scaling» (масштабирование температуры), который делает распределение вероятностей более «мягким» и информативным путем деления логитов модели на параметр температуры T > 1.

Soft-label Distillation для GPT моделей: объяснение на пальцах

Представьте, что у нас есть две модели:

  • Учитель (Teacher): Большая, мощная GPT модель с 175 миллиардами параметров. Она обладает глубоким пониманием языка и мира.

  • Студент (Student): Компактная GPT модель с 1.5 миллиардами параметров. Намного быстрее и экономичнее, но изначально уступает учителю в качестве.

Наша цель — научить студента генерировать текст так же хорошо, как учитель, используя Soft-label Distillation.

Шаги Soft-label Distillation:

  1. Генерация «мягких» меток учителем:

    • Для запроса «Столица Франции — это» большая модель-учитель не просто выдает «Париж», но вычисляет вероятности для всех возможных следующих токенов:

      • «Париж»: 0.92

      • «город»: 0.03

      • «Рим»: 0.01

      • … (и тысячи других токенов с малыми вероятностями)

    • Проблема: это распределение слишком «острое» — один токен имеет почти всю вероятность. Чтобы извлечь больше полезных знаний, применяем temperature scaling:

    • Делим логиты на температуру T (например, T = 2.0) перед применением softmax:

      • «Париж»: 0.70 (уменьшилось с 0.92)

      • «город»: 0.08 (увеличилось с 0.03)

      • «Рим»: 0.05 (увеличилось с 0.01)

      • … (другие токены тоже получают больше вероятности)

    • Эти «смягченные» распределения сохраняют намного больше информации о том, что модель-учитель «знает».

  2. Обучение модели-студента:

    • Студент обучается не только предсказывать правильный токен, но и воспроизводить всё распределение вероятностей учителя.

    • Для этого используется КЛ-дивергенция (или кросс-энтропия) между распределениями учителя и студента.

    • Важно: распределение студента также «смягчается» с той же температурой T для сопоставимости.

    • Функция потерь умножается на T² для компенсации уменьшения градиентов.

  3. Комбинированное обучение:

    • Обычно используется комбинация двух функций потерь:

      • α · (Потери от «мягких» меток) + (1-α) · (Стандартные потери от «жестких» меток)

    • Где α — коэффициент, обычно от 0.5 до 0.9

Почему это работает лучше Hard-label Distillation?

  • «Темные знания» (Dark Knowledge): Как назвал Хинтон, относительные вероятности «неправильных» ответов содержат ценную информацию. Например, если модель путает «собаку» с «волком», но не с «самолетом», это важная информация.

  • Передача неопределенности: Студент учится не только правильным ответам, но и тому, в каких случаях стоит сомневаться.

  • Более богатый сигнал: Вместо одного бита информации на каждый пример (правильный/неправильный класс), студент получает информацию о всем распределении вероятностей.

Математическая формализация:

1. «Мягкие» метки учителя с температурой T:

Еслиz_i— логит для класса (токена)iот учителя, то «мягкая» метка с температурой T:

p_i^T = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}

Разберем каждый элемент формулы:

* p_i^T: Это «мягкая» вероятность дляi-го токена, с учетом температурыT. Именно это распределение вероятностей, сгенерированное учителем, мы будем использовать как «мягкую метку».

* z_i: Это логит (logit) дляi-го токена, выданный моделью-учителем. Логиты — это значения, которые модель выдает перед применением функции softmax. Они представляют собой «сырые» оценки того, насколько модель уверена в каждом токене. Чем больше логит, тем больше уверенность модели в этом токене.

* T: Это параметр температуры (temperature). Как мы разбирали уже выше, температура используется для «смягчения» распределения вероятностей.

* \exp(x): Это экспоненциальная функцияe^x.

* \sum_j \exp(z_j/T): Это сумма экспоненциальных значений логитов, деленных на температуру, для всех возможных токеновj. Эта сумма используется для нормализации, чтобы вероятности в итоге суммировались к 1.

Пошаговое объяснение:

1. Деление логитов на температуруz_i/T: Когда мы делим логиты на температуруT > 1, мы уменьшаем абсолютные значения логитов.

2. Экспоненцирование\exp(z_i/T): Экспоненциальная функция преобразует логиты в положительные значения.

3. Нормализация\frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}: Деление на сумму экспоненциальных значений всех логитов гарантирует, что полученные значенияp_i^T будут представлять собой вероятностное распределение, то есть будут неотрицательными и в сумме дадут 1. Это стандартная операция softmax, но с применением температуры.

Интуиция и эффект температуры:

* При высокой температуре (например, T = 2.0), распределение вероятностей становится более «мягким» или «ровным». Вероятности для менее вероятных токенов увеличиваются, а вероятность наиболее вероятного токена уменьшается. Это позволяет «вытащить» больше информации из распределения, включая «темные знания» о менее вероятных, но все же релевантных вариантах.

* При низкой температуре (приближающейся к T = 1.0, или даже меньше), распределение становится более «острым». Вероятность наиболее вероятного токена приближается к 1, а вероятности остальных токенов стремятся к 0. При T=1 это стандартный softmax. ПриT \rightarrow 0 распределение становится дельта-функцией, выбирая только токен с наибольшим логитом.

Figure_1.jpg

Figure_1.jpg

2. Аналогично для студента:

q_i^T = \frac{\exp(z_i^q/T)}{\sum_j \exp(z_j^q/T)}

гдеz_i^q — логит студента для классаi.

Аналогия с формулой учителя: Эта формула абсолютно идентична формуле для учителя, за исключением того, что здесь используются логиты, выданные моделью-студентом.

* q_i^T: «Мягкая» вероятность дляi-го токена, сгенерированная студентом с температуройT.

* z_i^q: Логит дляi-го токена, выданный моделью-студентом.

* Цель: Мы применяем ту же температуруT к распределению студента, чтобы сделать его сопоставимым с «мягкими» метками учителя. Это необходимо для корректного расчета функции потерь дистилляции.

3. Функция потерь для Soft-label Distillation:

L_{soft} = T^2 \cdot \text{KL}(p^T || q^T) = T^2 \cdot \sum_i p_i^T \log\frac{p_i^T}{q_i^T}

МножительT^2 компенсирует уменьшение градиентов из-за temperature scaling.

Разберем компоненты:

* L_{soft}: Функция потерь Soft-label Distillation. Это значение, которое мы хотим минимизировать в процессе обучения студента.

* T^2: Квадрат температуры. Этот множитель используется для масштабирования функции потерь и компенсации уменьшения градиентов, вызванного температурой.

* \text{KL}(p^T || q^T): KL-дивергенция ( Kullback-Leibler divergence) между распределением учителя $p^T$ и распределением студентаq^T.

* \sum_i p_i^T \log\frac{p_i^T}{q_i^T}: Это развернутая формула KL-дивергенции для дискретных распределений.

Пошаговое объяснение KL-дивергенции:

1. \frac{p_i^T}{q_i^T}: Отношение вероятности учителя к вероятности студента для каждого токенаi. Если студент предсказывает вероятностьq_i^T близкую к вероятности учителяp_i^T, это отношение будет близко к 1.

2. \log\frac{p_i^T}{q_i^T}: Логарифм этого отношения. Если отношение близко к 1, логарифм будет близок к 0. Еслиq_i^Tсильно отличается отp_i^T, логарифм будет иметь большее абсолютное значение (отрицательное, еслиq_i^T > p_i^T, и положительное, еслиq_i^T < p_i^T).

3. p_i^T \log\frac{p_i^T}{q_i^T}: Умножение наp_i^T взвешивает вклад каждого токена в общую дивергенцию. Токены, которые учитель считает более вероятными (высокоеp_i^T), вносят больший вклад в функцию потерь.

4. \sum_i p_i^T \log\frac{p_i^T}{q_i^T}: Суммирование по всем токенамi дает общую KL-дивергенцию. KL-дивергенция измеряет «расстояние» между двумя распределениями вероятностей. В контексте дистилляции, она измеряет, насколько распределение студентаq^T отличается от распределения учителяp^T.

РольT^2:

* Применение температурыT «смягчает» распределения, что может привести к уменьшению величины градиентов при обучении. Умножение наT^2 масштабирует функцию потерь, чтобы компенсировать это уменьшение и сделать градиенты более значимыми, особенно на ранних этапах обучения. Это эмпирическая коррекция, которая помогает стабилизировать и ускорить обучение.

* ЦельL_{soft}: МинимизируяL_{soft}, мы заставляем распределение вероятностей студентаq^T максимально приблизиться к распределению вероятностей учителяp^T. Студент учится не только предсказывать «правильный» токен, но и имитировать всю «манеру мышления» учителя, выраженную в распределении вероятностей.

4. Комбинированная функция потерь:

L = \alpha \cdot L_{soft} + (1-\alpha) \cdot L_{hard}

гдеL_{hard} — стандартная кросс-энтропия с истинными метками,\alpha — коэффициент баланса.

Разберем компоненты:

* L: Общая функция потерь, используемая для обучения студента.

* \alpha: Коэффициент баланса (обычно от 0.5 до 0.9). Он определяет, насколько сильно мы полагаемся на «мягкие» метки учителя по сравнению со стандартными «жесткими» метками.

* L_{soft}: Функция потерь Soft-label Distillation, которую мы разобрали выше.

* L_{hard}: Стандартная функция потерь «жестких» меток, обычно кросс-энтропия между предсказаниями студента и истинными (one-hot) метками.

L_{hard} (Стандартные потери «жестких» меток):

* В обычной задаче обучения языковой модели, мы имеем «жесткие» метки — это истинные следующие токены в обучающих данных. Например, для фразы «Столица Франции — это Париж», «Париж» является «жесткой» меткой.

* L_{hard} вычисляется как кросс-энтропия между распределением вероятностей, предсказанным студентом (обычно сT=1, то есть стандартный softmax), и one-hot вектором, представляющим истинный токен. Эта функция потерь заставляет студента предсказывать именно «правильный» токен.

КомбинированиеL_{soft} иL_{hard}:

* Комбинирование «мягких» и «жестких» потерь позволяет студенту учиться как у учителя (черезL_{soft}), так и из исходных данных (черезL_{hard}).

* Коэффициент\alpha позволяет настроить баланс.

* Высокое\alpha (например, 0.9) означает, что мы больше полагаемся на знания учителя, переданные через «мягкие» метки. Это может быть полезно, когда учитель обладает значительно лучшими знаниями, чем можно извлечь только из «жестких» меток.

* Низкое\alpha (например, 0.5) означает, что мы в равной степени учитываем как знания учителя, так и «жесткие» метки. Это может быть полезно, когда мы хотим, чтобы студент сохранил способность хорошо работать и на исходных данных, а не только имитировал учителя.

Практическая реализация Soft-label Distillation для GPT моделей

Программный код был заимствован из репозитория: https://github.com/arcee-ai/DistillKit

1. Конфигурация дистилляции

Первым шагом необходимо настроить параметры дистилляции, включая температуру и коэффициент баланса между мягкими и жесткими метками:

""" Здесь temperature: 2.0 соответствует параметру T в формулах, который "смягчает" распределение вероятностей, а alpha: 0.5 - это коэффициент α, который определяет соотношение между потерями от мягких и жестких меток. """  config = {     "project_name": "distil-multilayer",    # Название проекта     "dataset": {         "name": "mlabonne/FineTome-100k",   # Название датасета         "split": "train",                   # Раздел датасета для тренировки         "num_samples": 1000,                # Количество образцов для тренировки (можно ограничить)         "seed": 42                          # Значение для инициализации генератора случайных чисел     },     "models": {         "teacher": "arcee-ai/Arcee-Spark",  # Модель учителя         "student": "Qwen/Qwen2-1.5B"        # Модель студента     },     "tokenizer": {         "max_length": 4096,                 # Максимальная длина токенов         "chat_template": (             "{% for message in messages %}"             "{% if loop.first and messages[0]['role'] != 'system' %}"             "{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}"             "{% endif %}"             "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"             "{% endfor %}"             "{% if add_generation_prompt %}"             "{{ '<|im_start|>assistant\n' }}"             "{% endif %}"         )                                    # Шаблон для форматирования сообщений в чате     },     "training": {         "output_dir": "./results",           # Директория для сохранения результатов         "num_train_epochs": 3,               # Количество эпох для тренировки         "per_device_train_batch_size": 1,    # Размер батча для тренировки на одном устройстве         "gradient_accumulation_steps": 8,    # Количество шагов для накопления градиентов         "save_steps": 1000,                  # Шаги для сохранения модели         "logging_steps": 2,                  # Шаги для логирования         "save_total_limit": 2,               # Лимит на количество сохраняемых моделей         "learning_rate": 2e-5,               # Скорость обучения         "weight_decay": 0.01,                # Коэффициент регуляризации         "warmup_ratio": 0.2,                 # Доля шагов для разгона скорости обучения         "lr_scheduler_type": "linear",       # Тип планировщика скорости обучения         "resume_from_checkpoint": None,      # Путь к чекпоинту для возобновления тренировки (если есть)         "fp16": False,                       # Использовать ли 16-битное число с плавающей точкой         "bf16": True,                        # Использовать ли BFloat16         "max_grad_norm": 1.0,                # Максимальная норма градиента         "group_by_length": False             # Группировать ли батчи по длине     },     "distillation": {         "temperature": 2.0,                  # Температура для дистилляции         "alpha": 0.5                         # Коэффициент альфа для дистилляции     },     "model_config": {         "use_flash_attention": True          # Использовать ли Flash Attention     } }

2. Подготовка моделей учителя и студента

Для дистилляции необходимо загрузить как модель-учитель (более крупную), так и модель-студент (более компактную):

import torch from typing import Dict, Any from transformers import AutoModelForCausalLM  def load_models_with_flash_attention(config: Dict[str, Any]) -> Dict[str, AutoModelForCausalLM]:     """     Description:     ---------------         Загружает модели с настройкой флеш-внимания для ускорения.      Args:     ---------------         config: Конфигурация моделей и параметров      Returns:     ---------------         Словарь с загруженными моделями      Raises:     ---------------         KeyError: Если в конфигурации отсутствуют необходимые ключи      Examples:     ---------------         >>> config = {         ...     "model_config": {"use_flash_attention": True},         ...     "models": {"teacher": "teacher_model_path", "student": "student_model_path"}         ... }         >>> load_models_with_flash_attention(config)         {'teacher_model': <transformers.models.model_name.model.ModelName object>,          'student_model': <transformers.models.model_name.model.ModelName object>}     """     # Настройки для загрузки моделей     model_kwargs: Dict[str, Any] = {"torch_dtype": torch.bfloat16}      # Проверка на использование flash attention     if config["model_config"]["use_flash_attention"]:         model_kwargs["attn_implementation"] = "flash_attention_2"      # Загрузка моделей     teacher_model = AutoModelForCausalLM.from_pretrained(config["models"]["teacher"], **model_kwargs)     student_model = AutoModelForCausalLM.from_pretrained(config["models"]["student"], **model_kwargs)      return {"teacher_model": teacher_model, "student_model": student_model}  # Вызов функции models = load_models_with_flash_attention(config)  # Теперь models содержит загруженные модели teacher_model = models["teacher_model"] student_model = models["student_model"]

3. Реализация функции потерь с мягкими метками

Ключевым компонентом является функция потерь Soft-label Distillation. Рассмотрим её реализацию из файла distil_logits.py:

""" Это прямая реализация формулы KL-дивергенции. Обратите внимание на следующие ключевые моменты:  1. Логиты масштабируются температурой T перед применением функций softmax/log_softmax. 2. Потери умножаются на T² для компенсации уменьшения градиентов, как описано в теории. 3. Финальная функция потерь комбинирует мягкие метки (KL-дивергенция) и жесткие метки (original_loss) с коэффициентом α. """  from typing import Any import torch import torch.nn.functional as F  def distillation_loss(     self,     student_logits: torch.Tensor,     teacher_logits: torch.Tensor,     inputs: Any,     original_loss: torch.Tensor,     config: Dict[str, Any] ) -> torch.Tensor:     """     Description:     ---------------         Вычисляет потери дистилляции между логитами студента и учителя.      Args:     ---------------         student_logits: Логиты студента.         teacher_logits: Логиты учителя.         inputs: Входные данные.         original_loss: Исходные потери.         config: Конфигурация моделей и параметров.      Returns:     ---------------         Общие потери, включающие дистилляционные потери и исходные потери.      Raises:     ---------------         KeyError: Если в конфигурации отсутствуют необходимые ключи.      Examples:     ---------------         >>> config = {         ...     "distillation": {"temperature": 2.0, "alpha": 0.5},         ...     "tokenizer": {"max_length": 512}         ... }         >>> student_logits = torch.randn(3, 512)         >>> teacher_logits = torch.randn(3, 512)         >>> inputs = ...         >>> original_loss = torch.tensor(0.5)         >>> distillation_loss(self, student_logits, teacher_logits, inputs, original_loss, config)         tensor(0.25)     """     # Приведение размерностей логитов учителя и студента к одинаковому размеру     student_logits, teacher_logits = pad_logits(         student_logits.to(self.model.device),         teacher_logits.to(self.model.device)     )      # Масштабирование логитов с помощью температуры T     temperature = config["distillation"]["temperature"]     student_logits_scaled = student_logits / temperature     teacher_logits_scaled = teacher_logits / temperature      # Расчёт KL-дивергенции между распределениями учителя и студента     loss_kd = F.kl_div(         F.log_softmax(student_logits_scaled, dim=-1),  # log(q_i^T)         F.softmax(teacher_logits_scaled, dim=-1),      # p_i^T         reduction='batchmean'     ) * (temperature ** 2) / config["tokenizer"]["max_length"]      # Комбинирование потерь от мягких и жестких меток     alpha = config["distillation"]["alpha"]     total_loss = alpha * loss_kd + (1 - alpha) * original_loss      return total_loss

4. Обработка различных размеров словарей

Поскольку модели учителя и студента могут иметь разный размер словаря токенов, необходима дополнительная функция для согласования размерности их логитов:

""" Эта функция добавляет нулевые логиты к меньшему распределению, чтобы обеспечить одинаковую размерность для сравнения распределений. """  from typing import Tuple import torch  def pad_logits(     student_logits: torch.Tensor,     teacher_logits: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]:     """     Description:     ---------------         Приводит размерности логитов студента и учителя к одинаковому размеру.      Args:     ---------------         student_logits: Логиты студента.         teacher_logits: Логиты учителя.      Returns:     ---------------         Кортеж из логитов студента и учителя с одинаковыми размерностями.      Raises:     ---------------         ValueError: Если размерности логитов не совпадают и не могут быть приведены к одинаковому размеру.      Examples:     ---------------         >>> student_logits = torch.randn(3, 512)         >>> teacher_logits = torch.randn(3, 510)         >>> pad_logits(student_logits, teacher_logits)         (tensor([...]), tensor([...]))     """     # Определение размеров логитов     student_size, teacher_size = student_logits.size(-1), teacher_logits.size(-1)      # Если размеры не совпадают, добавляем паддинг     if student_size != teacher_size:         pad_size = abs(student_size - teacher_size)         pad_tensor = torch.zeros(             (*teacher_logits.shape[:-1], pad_size),             dtype=teacher_logits.dtype,             device=teacher_logits.device         )          # Возвращаем логиты с добавленным паддингом         if student_size < teacher_size:             return torch.cat([student_logits, pad_tensor], dim=-1), teacher_logits         else:             return student_logits, torch.cat([teacher_logits, pad_tensor], dim=-1)      # Возвращаем логиты без изменений, если размеры совпадают     return student_logits, teacher_logits

5. Кастомный тренер для дистилляции

Для интеграции процесса дистилляции в процесс обучения создаётся специальный класс тренера, который переопределяет функцию вычисления потерь:

""" Этот класс: 1. Получает выходы (логиты) как от студента, так и от учителя 2. Замораживает веса учителя с помощью `torch.no_grad()` 3. Вычисляет комбинированную функцию потерь с использованием потерь от мягких и жестких меток """  from typing import Dict, Any, Union, Tuple import torch import torch.nn.functional as F from transformers import SFTTrainer  class LogitsTrainer(SFTTrainer):     """     Description:     ---------------         Класс для обучения модели с использованием дистилляции логитов.     """      def compute_loss(         self,         model: torch.nn.Module,         inputs: Dict[str, Any],         return_outputs: bool = False     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:         """         Description:         ---------------             Вычисляет комбинированную функцию потерь для модели студента и учителя.          Args:         ---------------             model: Модель студента.             inputs: Входные данные.             return_outputs: Флаг для возврата выходов модели.          Returns:         ---------------             Комбинированная функция потерь и, если указано, выходы модели.          Raises:         ---------------             ValueError: Если входные данные не соответствуют ожидаемым.          Examples:         ---------------             >>> model = ...             >>> inputs = ...             >>> trainer = LogitsTrainer()             >>> trainer.compute_loss(model, inputs, return_outputs=True)             (tensor(0.5), ...)         """         # Перемещение входных данных на устройство модели         inputs = {k: v.to(model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}          # Перемещение модели учителя на устройство модели         self.teacher_model = self.teacher_model.to(model.device)          # Получение модулей моделей, если они существуют         student_model = model.module if hasattr(model, 'module') else model         teacher_model = self.teacher_model.module if hasattr(self.teacher_model, 'module') else self.teacher_model          # Получение выходов моделей         student_outputs = student_model(**inputs)         with torch.no_grad():  # Учитель не обучается             teacher_outputs = teacher_model(**inputs)          # Вычисление комбинированной функции потерь         custom_loss = self.distillation_loss(             student_outputs.logits,             teacher_outputs.logits,             inputs,             student_outputs.loss         )          # Возврат потерь и выходов модели, если указано         if return_outputs:             return custom_loss, student_outputs         return custom_loss      def pad_logits(         self,         student_logits: torch.Tensor,         teacher_logits: torch.Tensor     ) -> Tuple[torch.Tensor, torch.Tensor]:         """         Description:         ---------------             Приводит размерности логитов студента и учителя к одинаковому размеру.          Args:         ---------------             student_logits: Логиты студента.             teacher_logits: Логиты учителя.          Returns:         ---------------             Кортеж из логитов студента и учителя с одинаковыми размерностями.          Raises:         ---------------             ValueError: Если размерности логитов не совпадают и не могут быть приведены к одинаковому размеру.          Examples:         ---------------             >>> student_logits = torch.randn(3, 512)             >>> teacher_logits = torch.randn(3, 510)             >>> trainer = LogitsTrainer()             >>> trainer.pad_logits(student_logits, teacher_logits)             (tensor([...]), tensor([...]))         """         # Определение размеров логитов         student_size, teacher_size = student_logits.size(-1), teacher_logits.size(-1)          # Если размеры не совпадают, добавляем паддинг         if student_size != teacher_size:             pad_size = abs(student_size - teacher_size)             pad_tensor = torch.zeros(                 (*teacher_logits.shape[:-1], pad_size),                 dtype=teacher_logits.dtype,                 device=teacher_logits.device             )              # Возвращаем логиты с добавленным паддингом             if student_size < teacher_size:                 return torch.cat([student_logits, pad_tensor], dim=-1), teacher_logits             else:                 return student_logits, torch.cat([teacher_logits, pad_tensor], dim=-1)          # Возвращаем логиты без изменений, если размеры совпадают         return student_logits, teacher_logits      def distillation_loss(         self,         student_logits: torch.Tensor,         teacher_logits: torch.Tensor,         inputs: Any,         original_loss: torch.Tensor     ) -> torch.Tensor:         """         Description:         ---------------             Вычисляет потери дистилляции между логитами студента и учителя.          Args:         ---------------             student_logits: Логиты студента.             teacher_logits: Логиты учителя.             inputs: Входные данные.             original_loss: Исходные потери.          Returns:         ---------------             Общие потери, включающие дистилляционные потери и исходные потери.          Raises:         ---------------             KeyError: Если в конфигурации отсутствуют необходимые ключи.          Examples:         ---------------             >>> config = {             ...     "distillation": {"temperature": 2.0, "alpha": 0.5},             ...     "tokenizer": {"max_length": 512}             ... }             >>> student_logits = torch.randn(3, 512)             >>> teacher_logits = torch.randn(3, 512)             >>> inputs = ...             >>> original_loss = torch.tensor(0.5)             >>> trainer = LogitsTrainer()             >>> trainer.distillation_loss(student_logits, teacher_logits, inputs, original_loss)             tensor(0.25)         """         # Приведение размерностей логитов учителя и студента к одинаковому размеру         student_logits, teacher_logits = self.pad_logits(             student_logits.to(self.model.device),             teacher_logits.to(self.model.device)         )          # Масштабирование логитов с помощью температуры T         temperature = config["distillation"]["temperature"]         student_logits_scaled = student_logits / temperature         teacher_logits_scaled = teacher_logits / temperature          # Расчёт KL-дивергенции между распределениями учителя и студента         loss_kd = F.kl_div(             F.log_softmax(student_logits_scaled, dim=-1),  # log(q_i^T)             F.softmax(teacher_logits_scaled, dim=-1),      # p_i^T             reduction='batchmean'         ) * (temperature ** 2) / config["tokenizer"]["max_length"]          # Комбинирование потерь от мягких и жестких меток         alpha = config["distillation"]["alpha"]         total_loss = alpha * loss_kd + (1 - alpha) * original_loss          return total_loss

6. Подготовка тренера и запуск обучения

После определения всех компонентов можно инициализировать тренер и запустить процесс дистилляции:

""" Обратите внимание, что модель-учитель добавляется к тренеру как атрибут, чтобы она была доступна внутри функции `compute_loss`. """  # Импорт необходимых библиотек from transformers import TrainingArguments from accelerate import Accelerator  # Инициализация accelerator accelerator = Accelerator()  # Аргументы обучения training_arguments = TrainingArguments(**config["training"])  # Проверка наличия предобработанного датасета if 'tokenized_dataset' not in locals():     # Если датасет не предобработан, выполняем необходимую предобработку     # Код предобработки датасета должен быть здесь...     print("Необходимо сначала выполнить предобработку датасета!")  # Создание кастомного SFT тренера trainer = LogitsTrainer(     model=student_model,     train_dataset=tokenized_dataset["train"],     eval_dataset=tokenized_dataset["test"],     tokenizer=student_tokenizer,     args=training_arguments,     max_seq_length=config["tokenizer"]["max_length"],     dataset_text_field="text", )  # Добавление модели-учителя к тренеру trainer.teacher_model = teacher_model  # Подготовка к распределенному обучению trainer = accelerator.prepare(trainer)  # Запуск обучения trainer.train(resume_from_checkpoint=config["training"]["resume_from_checkpoint"])  # Сохранение финальной модели trainer.save_model(config["training"]["output_dir"])  print(f"Обучение завершено. Модель сохранена в {config['training']['output_dir']}")

Преимущества Soft-label Distillation:

  • Более полная передача знаний: Студент получает доступ к «темным знаниям» учителя — информации о сложных случаях, тонких различиях между классами и степени неопределенности.

  • Лучшие результаты: Студенты, обученные этим методом, обычно демонстрируют производительность ближе к учителю по сравнению с Hard-label Distillation.

  • Улучшенная генерализация: Модели лучше работают на новых данных, так как учатся не только «что» предсказывать, но и «с какой уверенностью».

  • Контроль через температуру: Параметр T позволяет настраивать степень «мягкости» дистилляции. Более высокие значения T делают распределение более равномерным, помогая передать больше информации о маловероятных классах.

  • Совместимость с другими методами: Легко комбинируется с другими техниками улучшения моделей.

Недостатки Soft-label Distillation:

  • Вычислительные затраты: Для языковых моделей с большими словарями (50,000+ токенов) хранение и передача полных распределений вероятностей требует значительных ресурсов.

  • Сложность реализации: Требует доступа к логитам/вероятностям учителя, а не только к финальным предсказаниям.

  • Настройка гиперпараметров: Необходимо тщательно подбирать температуру T и коэффициент α для оптимальных результатов.

  • Зависимость от качества учителя: Если учитель имеет систематические ошибки, они могут передаться студенту.

Сравнение Hard-label и Soft-label Distillation:

Аспект

Hard-label Distillation

Soft-label Distillation

Передаваемая информация

Только итоговые классы/токены

Полные распределения вероятностей

Температура

Не используется

Используется для «смягчения» распределений

Сложность реализации

Простая

Средняя

Вычислительные требования

Низкие

Средние-высокие

Объем хранимых данных

Малый

Большой (особенно для языковых моделей)

Качество получаемой модели

Хорошее

Лучшее

Способность передавать неопределенность

Низкая

Высокая

Эффективность для языковых моделей

Средняя

Высокая

В заключение, Soft-label Distillation предлагает более мощный метод передачи знаний от учителя к ученику, особенно для сложных задач, где важны тонкие различия между классами и понимание неопределенности. Ключевое отличие от Hard-label Distillation заключается в использовании полных распределений вероятностей и temperature scaling, что позволяет извлечь «темные знания» и научить студента не только выдавать правильные ответы, но и воспроизводить тонкие нюансы рассуждений учителя.

Part 2: Законы масштабирования дистилляции

После того, как DeepSeek представил в open source свой метод дистилляции знаний для R1, исследователи из Apple и Оксфордского университета быстро предложили закон масштабирования дистилляции и уже 28 февраля завершили все эксперименты и загрузили 67-страничную статью на arXiv.

Рассмотрим мотивацию исследования, которая сводится к следующим пунктам:

  1. Текущее состояние исследований законов масштабирования моделей: В последние годы исследования выявили взаимосвязь между производительностью языковых моделей, их размером и объемом данных для обучения. Однако систематических исследований законов масштабирования в контексте дистилляции пока не проводилось.

  2. Проблема стоимости вывода модели: С увеличением размера языковых моделей значительно возрастает стоимость вывода. Исследование того, как снизить стоимость вывода без потери производительности, становится важной задачей.

  3. Эффективность и производительность дистилляции: Теоретически, дистилляция может снизить стоимость вывода, однако в академических кругах нет единого мнения относительно методов дистилляции, особенно в том, как рационально распределить вычислительные ресурсы для создания наиболее мощных моделей, что остается большой неопределенностью.

Экстраполяция закона масштабирования дистилляции

Экстраполяция закона масштабирования дистилляции

Экстраполяции закона масштабирования дистилляции. Закон масштабирования дистилляции (Уравнение 8) аппроксимирован на слабых учениках L_S > 2.3 для ряда учителей с потерями L_T . Сплошные линии представляют прогнозируемое поведение модели для невидимых учителей при заданной конфигурации ученика (интерполяция), а пунктирные линии представляют прогнозируемое поведение модели за пределами видимых учителей и для области сильных учеников ( L_S \leq 2.3 ).

Закон масштабирования дистилляции

Традиционный закон масштабирования (Scaling Laws) для больших моделей демонстрирует, что производительность языковой модели (LM) может улучшаться с увеличением вычислительных ресурсов, если модель следует оптимальной вычислительной парадигме обучения. Однако постоянный рост затрат на инференс делает этот подход все менее практичным, что заставляет исследователей искать альтернативные методы, включая переобучение и дистилляцию, для создания небольших, но мощных моделей.

Исследователи провели обширные эксперименты, используя модели-студенты и модели-учителя с параметрами от 143 миллионов до 12,6 миллиардов и объемом данных до 512 миллиардов токенов. Целью было изучить взаимосвязь между производительностью модели и вычислительными ресурсами в процессе дистилляции, а также найти способы оптимизации распределения этих ресурсов.

В следующей таблице показано значение символов, используемых в этой статье:

Таблица. Выражения, связанные с законами масштабирования, используемые в данной работе. В каждом случае S всегда относится к ученику, а не к обучению с учителем.

Выражение

Значение

N / N_S / N_T

Количество параметров модели/ученика/учителя, не связанных с эмбеддингом. В тексте, когда мы упоминаем параметры, мы всегда имеем в виду параметры, не связанные с эмбеддингом, если не указано иное. Подробности см. в Приложении H.2.

D / D_T

Количество токенов, на которых предобучена модель/учитель.

D_S

Количество токенов, на которых дистиллирован ученик.

M \equiv D / N

Соотношение токенов на параметр, или MM-соотношение. В работе Hoffmann et al. (2022), M принимает оптимальное значение M^∗≈20, что является эмпирическим правилом Chinchilla.

L \approx L(N, D)

Кросс-энтропия модели, которая представляет собой валидационную кросс-энтропию модели на данных, оцениваемую по закону масштабирования с учителем для модели с N параметрами, обученной на D токенах. (Уравнение 1).

L_T \approx L(N_T, D_T)

Кросс-энтропия учителя, которая представляет собой валидационную кросс-энтропию учителя на данных, оцениваемую по закону масштабирования с учителем для учителя с N_T​ параметрами, обученного на D_T​ токенах.

L_S \approx L_S(N_S, D_S, L_T)

Кросс-энтропия ученика, которая представляет собой валидационную кросс-энтропию ученика на данных, оцениваемую по нашему закону масштабирования дистилляции для ученика с N_S​ параметрами, дистиллированного на D_S​ токенах с использованием учителя с потерей предобучения L_T​ (Уравнение 8).

\tilde{L}_S \approx L(N_S, D_S)

Кросс-энтропия ученика с учителем, которая представляет собой валидационную кросс-энтропию ученика на данных, если бы ученик был обучен с учителем, оцениваемую по закону масштабирования с учителем для ученика с N_S​ параметрами, обученного на D_S​ токенах.

Пояснение: Кросс-энтропия — это метрика, измеряющая расхождение между предсказанным распределением вероятностей модели и истинным распределением. Чем ниже кросс-энтропия, тем лучше модель предсказывает правильные токены. Это основной показатель качества языковой модели.

Пояснение к правилу Чинчиллы: Исследование Hoffmann et al. (2022) установило эмпирическое правило оптимального соотношения между количеством параметров модели и количеством токенов для обучения — примерно 20 токенов на каждый параметр. Это правило позволяет эффективно распределять вычислительные ресурсы при обучении крупных языковых моделей.

Формализация закона масштабирования дистилляции

Центральным вкладом исследования является формулировка закона масштабирования дистилляции:

L_S(N_S, D_S, L_T) = L_T + \frac{1}{L_{c_0}^T} \left( 1 + \left( \frac{L_T}{\tilde{L}_S^{d_1}} \right)^{1/f_1} \right)^{-c_1f_1} \left( \frac{A}{N_S^{\alpha'}} + \frac{B}{D_S^{\beta'}} \right)^{\gamma'}

Объяснение переменных:

L_S(N_S, D_S, L_T)кросс-энтропия студента (мера ошибки предсказания; чем ниже, тем лучше модель).

L_Tкросс-энтропия учителя (мера ошибки предсказания большой модели).

N_Sколичество неэмбеддинговых параметров студента (основные обучаемые параметры модели).

D_Sколичество токенов, использованных для обучения студента при дистилляции.

\tilde{L}_S = L(N_S, D_S)потенциальная кросс-энтропия студента при обычном обучении без дистилляции, определяемая классическим законом масштабирования:

L(N, D) = E - \frac{A}{N^\alpha} - \frac{B}{D^\beta}

\{c_0, c_1, d_1, f_1, \alpha', \beta', \gamma'\}коэффициенты, определяемые эмпирически.

A иBположительные коэффициенты, зависящие от архитектуры модели и характеристик набора данных.

Физический смысл формулы:

1. Базовая часть:L_T — студент не может быть лучше учителя.

2. Модифицирующая часть: Остальная часть формулы описывает, насколько эффективно студент может приблизиться к учителю в зависимости от своего размера, количества данных и качества учителя.

Ключевые выводы:

1. Студент не может превзойти учителя (всегдаL_S \geq L_T). Кросс-энтропия (L) — это мера ошибки модели. Чем ниже значение L, тем лучше модель предсказывает данные.

2. Чем ближе потенциальная производительность студента к производительности учителя, тем эффективнее дистилляция.

3. При фиксированном учителе закон масштабирования дистилляции не превосходит обычный закон масштабирования.

Практическое применение:

Этот закон позволяет оптимально распределить вычислительные ресурсы между учителем и студентом и прогнозировать эффективность дистилляции.

То есть: Этот закон описывает, как качество маленькой модели зависит от трех факторов: размера самой модели, количества данных для обучения и качества большой модели-учителя. Ключевой вывод: студент никогда не может быть лучше учителя, но насколько близко он подойдет к учителю, зависит от его собственных возможностей и объема тренировки.

Коэффициенты смешивания в процессе дистилляции знаний

Рассмотрев общий закон масштабирования дистилляции, важно также понять практические аспекты реализации этого процесса, в частности, как управлять балансом между имитацией учителя и самостоятельным обучением модели-ученика.

Основная идея дистилляции знаний заключается в переносе информации от большой модели-учителя к компактной модели-ученику. В этом процессе прогнозируемое распределение вероятностей модели-учителя используется в качестве целевого распределения для модели-ученика. Обучение происходит путем минимизации расхождения Кульбака-Лейблера (KL-дивергенции) между распределениями ученика и учителя:

{L}_{\text{KD}} \left( z_T^{(i)}, z_S^{(i)} \right) = -\tau^2 \sum_{a=1}^V \sigma_a \left( \frac{z_T^{(i)}}{\tau} \right) \log \sigma_a \left( \frac{z_S^{(i)}}{\tau} \right)

где:

z_T^{(i)} иz_S^{(i)} — выходные логиты моделей учителя и ученика соответственно

\tau — температура дистилляции, контролирующая «сглаженность» распределения вероятностей учителя

\sigma_a — функция softmax, преобразующая логиты в вероятности

V — размер словаря

Комбинированная функция потерь для модели-ученика объединяет несколько компонентов:

{L}_S\big(x^{(i)}, \boldsymbol{z}_T^{(i)},\boldsymbol{z}_S^{(i)}\big) = (1-\lambda)\,{L}_{\textrm{NTP}}(x^{(i)},\boldsymbol{z}_S^{(i)}) + \lambda\,{L}_{\textrm{KD}}(\boldsymbol{z}_T^{(i)},\boldsymbol{z}_S^{(i)}) + \lambda_Z\,{L}_Z(\boldsymbol{z}_S^{(i)})

где:

{L}_{\textrm{NTP}} — потеря при предсказании следующего токена (стандартная кросс-энтропия)

{L}_{\textrm{KD}}— потеря при дистилляции знаний (KL-дивергенция)

{L}_Z— регуляризационная Z-потеря, стабилизирующая обучение путем нормализации логитов

\lambda— коэффициент смешивания, определяющий баланс между обучением на «чистых» данных и имитацией учителя

\lambda_Z— весовой коэффициент для Z-потери

Экспериментальное определение оптимальных параметров дистилляции

Для определения влияния параметров дистилляции на эффективность закона масштабирования, исследователи провели серию экспериментов. Чтобы исключить влияние данных и сосредоточиться именно на роли модели-учителя, эксперименты проводились в режиме «чистой дистилляции» с λ=1. Результаты показали, что такой выбор λ даёт результаты, статистически сопоставимые с использованием оптимальных значений λ^∗.

Во всех экспериментах использовалась фиксированная температура дистилляции τ=1, которая эмпирически показала наилучшую эффективность для обучения модели-ученика.

Коэффициенты смешивания λ

Коэффициенты смешивания λ

Коэффициенты смешивания\lambda.

(a) Модели-ученики шести размеров N_S \in \{198M, 266M, \ldots, 2.72B\}, обученные с соотношением M = D_S/N_S = 20, дистиллируются от моделей-учителей размеров N_T \in \{546M, 975M, \ldots, 7.75B\}, обученных с соотношениеM = D_T/N_T = 20, с различными значениями коэффициента смешивания \lambda \in [0, 1]. Значения\lambda = 0 и\lambda = 1 соответствуют стандартному обучению и чистой дистилляции соответственно.

(b) Оптимальные коэффициенты смешивания\lambda^* = \arg \min_{\lambda} {L}(\lambda), дающие наименьшую потерю на валидационном наборе для каждой пары учитель-ученик.

Эти эксперименты подтверждают, что параметры дистилляции оказывают существенное влияние на итоговую производительность модели-ученика, и их оптимальный выбор напрямую связан с размерами моделей учителя и ученика, что согласуется с общим законом масштабирования дистилляции.

Вывод

Дистилляция знаний — это метод, позволяющий передать способности большой нейронной модели (учителя) меньшей и вычислительно эффективной модели (ученику). Процесс основан на обучении модели-ученика имитировать распределение вероятностей модели-учителя путём минимизации расхождения Кульбака-Лейблера между их предсказаниями.

Эффективность дистилляции определяется балансом нескольких компонентов в функции потерь:

  • Стандартной кросс-энтропии при предсказании следующего токена

  • KL-дивергенции при имитации учителя

  • Регуляризационной Z-потери для стабилизации обучения

Два ключевых параметра контролируют этот процесс:

  • Коэффициент смешивания λ, регулирующий баланс между самостоятельным обучением и имитацией учителя

  • Температура дистилляции τ, влияющая на «сглаженность» распределения вероятностей

Экспериментальные исследования демонстрируют, что режим «чистой дистилляции» (λ = 1) при температуре τ = 1 часто даёт результаты, сопоставимые с оптимально подобранными параметрами. Однако наиболее важным открытием является то, что идеальные значения этих параметров системно зависят от соотношения размеров конкретной пары моделей учитель-ученик.

Это открытие соответствует общему закону масштабирования дистилляции и имеет прямое практическое применение: для достижения максимальной эффективности при практической реализации дистилляции необходим индивидуальный подбор параметров с учётом размеров используемых моделей, что позволяет существенно улучшить итоговую производительность компактной модели при сохранении её вычислительной эффективности.

Эксперимент с фиксированным учитилем и разными учениками

Размер модели учителя и объем обучающих данных на которых обучался учитель, фиксированы, а размер модели ученика и объем дистилляционных данных варьируются. Цель состоит в том, чтобы изучить, как производительность модели ученика меняется в зависимости от ее размера и объема обработанных дистилляционных данных в условиях фиксированной модели учителя. Таким образом, можно определить оптимальную производительность модели студента при различных масштабах и объемах данных.

Figure_5

Figure_5
Figure_6

Figure_6

Из результатов эксперимента можно заметить, что:

  • При высокой вычислительной мощности, чем больше масштаб параметров модели ученика, тем меньше его функция потерь, и чем больше масштаб модели учителя, тем очевиднее эта тенденция.

  • Когда размер моделей ученика и учителя определен, становится понятно, что чем больше вычислительная мощность, тем лучше будет работать модель ученика.

  • При низкой вычислительной мощности производительность модели сначала улучшится, а затем ослабнет с размером модели. Здесь легко понять, что более крупные модели не полностью обучаются при меньшей вычислительной мощности.

  • В особых случаях модель ученика может превзойти модель учителя и показать способность к обобщению. Я лично предполагаю, что модель учителя может быть недообучена в таких сценариях.

Эксперимент с фиксированным учеником и разными учителями

Размер модели ученика и объем данных дистилляции фиксированы, а размер модели учителя и объем обучающих данных варьируются. Цель состоит в том, чтобы изучить, как эффективность модели учителя влияет на конечную эффективность модели ученика. Таким образом, можно определить оптимальный размер модели учителя и объем обучающих данных для максимизации производительности модели ученика.

Figure_7

Figure_7

Как видно из результатов, чем больше параметры у модели учителя, тем ниже перекрестная энтропия модели ученика. Это показывает, что для достижения наилучшего эффекта дистилляции производительность модели учителя должна соответствовать возможностям модели ученика.

Дистилляция против контролируемого обучения

Чтобы понять, когда дистилляция приносит пользу, на следующем рисунке сравнивается производительность дистилляции и контролируемого обучения при фиксированных вычислительных ресурсах. Результаты показывают, что контролируемое обучение всегда превосходит дистилляцию при наличии достаточного количества вычислений или данных у учащихся. При умеренном бюджете данных дистилляция имеет преимущества, однако при наличии больших объемов данных контролируемое обучение превосходит дистилляцию.

Подводя итог, можно сказать, что при ограниченных вычислительных ресурсах дистилляция обычно более эффективна, чем контролируемое обучение. Это связано с тем, что дистилляция может быстрее усваивать эффективные представления признаков под руководством модели учителя, тем самым достигая более высокой производительности при меньших вычислительных ресурсах.

Figure_8

Figure_8

Выбор модели учителя

  • Сила обучающего сигнала: Модели учителей разных размеров могут обеспечивать разную силу обучающего сигнала, которая обычно измеряется с помощью потери перекрестной энтропии. Более крупная модель учителя может обеспечить более сильный сигнал обучения (более низкая перекрестная энтропия), тем самым помогая модели ученика лучше учиться.

  • Увеличение затрат: использование более крупной модели учителя повлечет за собой более высокие затраты из-за необходимости вычисления логитов модели учителя. Это означает, что более крупная модель учителя не только более затратна в обучении, но и потребляет больше вычислительных ресурсов при использовании для дистилляции.

На рисунке ниже показано изменение потери перекрестной энтропии модели студента при различных бюджетах данных дистилляции. Результаты показывают, что оптимальная потеря учителя (представленная красной линией) уменьшается по степенному закону с увеличением численности учащихся до тех пор, пока потеря ученика не сравняется с оптимальной потерей учителя.

Figure_9

Figure_9

Как видно на другом рисунке ниже, по мере увеличения объема данных дистилляции, перекрестная энтропия оптимальной модели учителя постепенно уменьшается. Таким образом, можно сделать вывод, что: когда вычислительные ресурсы ограничены, выбор меньшей модели учителя может снизить затраты на вывод, при этом обеспечивая эффективные сигналы обучения для модели ученика.

Figure_10

Figure_10

Рассчитайте оптимальную дистилляцию

Целью вычислительно оптимальной дистилляции является определение способа создания модели студента желаемого размера с наименьшей перекрестной энтропией при заданном вычислительном бюджете. В частности, необходимо найти оптимальный объем данных для обучения учащихся, размер модели учителя, данные для обучения, чтобы минимизировать перекрестную энтропию студента и при этом удовлетворить ограничениям вычислительного бюджета.

На рисунке ниже мы видим:

  • Контролируемое обучение всегда соответствует наилучшему варианту настройки дистилляции при достаточном вычислительном бюджете: Контролируемое обучение всегда соответствует наилучшему варианту настройки дистилляции при определенном общем вычислительном бюджете. Это означает, что контролируемое обучение может достичь той же производительности, что и дистилляция, если вычислительный бюджет достаточно велик.

  • Если в вычисления включено обучение учителя, перекрестная энтропия учащихся всегда выше, чем в контролируемой обстановке: это означает, что если вашей единственной целью является создание наилучшей модели с целевым размером и у вас нет доступа к учителю, вам следует выбрать контролируемое обучение вместо обучения учителя и последующей дистилляции. Напротив, если цель состоит в том, чтобы выделить семейство моделей или использовать учителя в качестве обслуживающей модели, то выделение может оказаться более выгодным с вычислительной точки зрения, чем контролируемое обучение.

  • Меньшие модели с большей вероятностью получат выгоду от контролируемого предварительного обучения, в то время как более крупные модели с большей вероятностью получат выгоду от дистилляции: Меньшие модели с большей вероятностью получат выгоду от контролируемого обучения при больших вычислительных бюджетах, в то время как более крупные модели с большей вероятностью получат выгоду от дистилляции при больших вычислительных бюджетах.

Figure_11

Figure_11

На рисунке ниже показаны тенденции изменения оптимального размера учителя и объема обучающих данных по мере изменения вычислительного бюджета. Токены моделей студентов и преподавателей масштабируются по степенному закону, причем токены студентов растут быстрее. Размер лучшей модели учителя сначала увеличивается, пока не станет немного больше ученика, а затем стабилизируется. Это связано с тем, что использование большой модели учителя для вывода обходится дорого, и по мере увеличения количества токенов учеников более эффективным становится переобучение модели учителя.

Figure_12

Figure_12

Ключевые результаты исследования

В результате исследований авторы пришли к следующим выводам:

  1. Предсказуемость производительности через закон масштабирования: Производительность модели-студента размером N_S​, полученной путем дистилляции из модели-учителя размером N_T​ с использованием D_S​ токенов, может быть предсказана с помощью разработанного закона масштабирования дистилляции.

    Практическое значение: Это позволяет заранее оценить, какой результат можно получить от процесса дистилляции, не проводя дорогостоящих экспериментов. Компания может спланировать свои ресурсы и решить, стоит ли вкладываться в дистилляцию, или лучше выбрать другой подход к созданию эффективной модели.

  2. Влияние параметров учителя на студента: Размер модели-учителя NTNT​ и количество токенов для её обучения D_T​ определяют кросс-энтропию модели-учителя L_T=L_T(N_T,D_T), которая, в свою очередь, влияет на кросс-энтропию модели-студента.

    Наглядный пример: Представьте учителя как источник знаний для студента. Если учитель сам недостаточно образован (высокая кросс-энтропия), он не сможет хорошо обучить студента, какими бы способностями студент ни обладал.

  3. Феномен «разрыва в способностях»: Исследование выявило интересный эффект — более сильный учитель может привести к худшему студенту, что объясняется «разрывом в способностях» (capacity gap). Влияние кросс-энтропии модели-учителя на потери модели-студента следует степенному закону, который переключается между двумя режимами в зависимости от относительной способности к обучению студента и учителя. Исследование показало, что важен именно разрыв в способности к обучению (гипотезное пространство и оптимизационная способность) между учителем и студентом, а не просто их относительный размер.

    Аналогия для понимания: Представьте, что профессор квантовой физики пытается обучить первоклассника. Несмотря на высокую квалификацию профессора, первоклассник не сможет усвоить сложный материал из-за разрыва в способностях к обучению. Аналогично, если модель-учитель слишком сложна и «мыслит» на уровне, недоступном модели-студенту, эффективность обучения снижается.

  4. U-образная зависимость ошибки студента: Эмпирически подтверждается U-образная зависимость ошибки студента от размера учителя при фиксированном размере студента, что теоретически обосновывается разрывом в емкости между ними.

    Визуальное представление: Если изобразить ошибку студента на графике, где по горизонтальной оси отложен размер учителя, мы увидим U-образную кривую. Это означает, что существует оптимальный размер учителя для данного студента — не слишком маленький (недостаточно знаний) и не слишком большой (слишком сложное представление знаний).

Практические рекомендации

Результаты исследования показывают, что дистилляция становится более эффективной, чем обучение с учителем, при соблюдении следующих условий:

  1. Общее количество вычислений или токенов для студента не превышает пороговое значение, связанное с размером студента, согласно новому закону масштабирования.

    Практический сценарий: Для компании с ограниченным бюджетом на вычисления, которая хочет создать модель размером 1 миллиард параметров, дистилляция может быть оптимальным выбором, если доступно менее 20 миллиардов токенов для обучения (согласно правилу Чинчиллы).

  2. Модель-учитель уже существует, или обучение модели-учителя имеет применение за пределами одной дистилляции.

    Бизнес-кейс: Если компания уже обучила крупную модель для своих основных задач, имеет смысл использовать её для дистилляции меньших, специализированных моделей для развертывания на мобильных устройствах или в средах с ограниченными ресурсами.


🔥Не пропустите важные обновления и углубленные материалы!🔥  

Хотите быть в курсе самых свежих обзоров и исследований в мире ML и AI? Переходите по ссылкам ниже, чтобы получить доступ к эксклюзивному контенту:  

📌 Все обзоры также доступны в нашем Telegram канале TheWeeklyBrief📢

📌 Более подробный обзор с математической формализацией и программным кодом ждет вас в нашем репозитории Weekly-arXiv-ML-AI-Research-Review 👩‍💻📂✨  

Не упустите шанс глубже погрузиться в мир технологий! 🚀


ссылка на оригинал статьи https://habr.com/ru/articles/891284/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *