RecBole — «комбайн» на PyTorch для любых рекомендаций

от автора

Привет, Хабр!

Сегодня разберём RecBole — универсальный фреймворк на PyTorch, который отвечает на три насущных вопроса любого ML-инженера рекомендаций:

  • Как быстро обкатать десятки алгоритмов (от классического MF до SASRec и KGAT) на собственном датасете — без сотни скриптов?

  • Как хранить все настройки в одном YAML, а не в трёх сотнях аргументов CLI?

  • Как получить честное сравнение метрик и сразу вынести лучший чекпоинт в прод?

Рассмотрим подробнее под катом.

Установка и подготовка данных

pip install recbole>=1.2 python -m recbole.quick_start.run_recbole --model=BPR --dataset=ml-1m

У RecBole есть встроенная заготовка датасетов (ml-1m, yelp, amazon-*). Свой датасет кидаем в папку dataset/<name>/ в формате Atomic Files:

файл

обязательные поля

комментарий

<name>.inter

user_id, item_id, rating, timestamp

минимум две первые колонки

<name>.item

item_id, genre, year, …

любые side-фичи

<name>.user

user_id, age, city, …

optional

Parquet читается быстрее, но RecBole «проглатывает» и CSV.

Автоматический сплит

В recbole.yaml достаточно:

split_ratio: [0.8, 0.1, 0.1]   # train/valid/test group_by_user: True            # чтобы у каждого юзера были все статусы

Всё, никаких ручных датафреймов на pandas.

Разбираемся с API

run_recbole

from recbole.quick_start import run_recbole  run_recbole(     model='LightGCN',          # любая из 90+ моделей     dataset='ml-1m',           # или путь к своему набору     config_dict={              # приоритет над YAML и CLI         'epochs': 50,         'topk': 10,         'neg_sampling': {'uniform': 1},         'seed': 42,            # чтобы метрики не «плавали»     } )

Что происходит под капотом

Шаг

Вызов

Под капотом

Конфигурация

Config

собирает всё из recbole.yaml, аргументов CLI и config_dict, давая приоритет последнему. Можно вызвать config.save() и получить итоговый YAML для репродюса.

Дата

create_dataset

читает Atomic Files, авто-инференсит типы полей (int/float/token/sequence), пишет мета-JSON в processed/*.json.

Семплеры

create_sampler

строит Sampler (point-wise, pair-wise, full-sort). Хотите динамический негатив — передайте neg_sampling.dynamic: 1 и получите новый семплер без правки кода.

Лоадеры

create_dataloader

лениво подгружает батчи; для огромных данных ставьте lazy_loading: True, чтобы не держать всё в памяти.

Модель

Model

вытягивается рефлексией из recbole.model. Хотите кастом — наследуйтесь от BaseModel, регайте через register_model.

Тренер

Trainer

инициализирует оптимизатор/скедьюлер, early-stopping, логгер. Для knowledge distillation есть KnowledgeDistillationTrainer.

Эвалар

Evaluator

считает HR@K, NDCG@K, MRR, MAP; full_sort_topk ранжирует весь каталог, а не sampled-негативы.

Вывод

~

сохраняет лучший чек-пойнт + лог в /saved/LightGCN-<timestamp>/

Хотите логировать в W&B — добавьте wandb: True в YAML. Нужен mixed-precision — train_stage: fp16. Гиперпараметры через CLI: python run_recbole.py --learning_rate=5e-4 --dropout_prob=0.3.

Гранулярный контроль

Иногда однострочник — роскошь, и нужен доступ к каждому объекту. Тогда:

from recbole.config import Config from recbole.data import create_dataset, data_preparation from recbole.utils import init_seed from recbole.model.general_recommender import LightGCN from recbole.trainer import Trainer  # 1. Конфиг из файла + CLI config = Config(model='LightGCN', dataset='ml-1m')      # читает recbole.yaml config['epochs'] = 30                                   # оверрайд «на лету»  # 2. Dataset init_seed(config['seed']) dataset = create_dataset(config)                        # <RecDataset 1 1000209>  # 3. Sampler / Dataloader train_data, valid_data, test_data = data_preparation(config, dataset)  # 4. Модель model = LightGCN(config, dataset).to(config['device'])  # 5. Тренер trainer = Trainer(config, model) best_valid_score, best_valid_result = trainer.fit(     train_data, valid_data, saved=True, show_progress=True)  score, result = trainer.evaluate(test_data, load_best_model=True) print(result)     # {'Recall@10': 0.1627, 'NDCG@10': 0.0894, ...}

Config

# recbole.yaml (кусочек) MODEL_TYPE: Sequential     # автоматически подскажет, что у модели есть max_seq_length epochs: 40 neg_sampling:   dynamic: 1 eval_args:   mode: full                # full-sort evaluation   order: RO                 # рейтинг -> онлайн   split: {'RS': [0.8,0.1,0.1]} checkpoint_dir: ./saved/ wandb: True 

Переопределение при импорте:

cfg = Config(model='SASRec', dataset='ml-1m',              config_dict={'epochs': 10, 'dropout_prob': 0.2})

Доступ к параметрам — по ключу: cfg['topk'], cfg.final_config_dict — готовый словарь для логирования.

Dataset и друзья

from recbole.data import Dataset dset = Dataset(config)      # наследник torch.utils.data.Dataset len(dset.field2type)        # {'user_id': 'token', 'item_id': 'token', ...}
  • Custom поля — добавьте колонку в .inter и опишите тип в YAML:

    FIELD_TYPES: {'price': float, 'brand': token}
  • Sequence -> unrolled. Для последовательных моделей (SASRec, GRU4Rec) RecBole сам создаёт hist_seq и target_item.

  • Lazy loading для >10 GB дат:

    lazy_loading: True

Самплеры и “кормушки”

from recbole.data import (     create_samplers, create_dataloader, data_preparation )  samplers = create_samplers(config, dataset)       # TrainSampler / FullSortSampler train_loader, valid_loader, test_loader = create_dataloader(     config, dataset, samplers)
  • Популярный негатив:

    neg_sampling:   popularity: 1

    Внутри PopularitySampler — item-frequency softmax.

  • Dynamic Sampler считает свежие негативы каждую эпоху, спасая от информации-leakage.

Пишем свою модель

from recbole.model.abstract_recommender import GeneralRecommender from recbole.model.loss import BPRLoss import torch.nn as nn import torch  class MyDotMF(GeneralRecommender):     def __init__(self, config, dataset):         super().__init__(config, dataset)         self.embedding_size = config['embedding_size']         self.user_embedding = nn.Embedding(             dataset.num(self.USER_ID), self.embedding_size)         self.item_embedding = nn.Embedding(             dataset.num(self.ITEM_ID), self.embedding_size)         self.loss_fct = BPRLoss()      def forward(self, interaction):         user = interaction[self.USER_ID]         pos_item = interaction[self.ITEM_ID]         user_e = self.user_embedding(user)         item_e = self.item_embedding(pos_item)         scores = (user_e * item_e).sum(-1)         return scores      def calculate_loss(self, interaction):         pos_score = self.forward(interaction)         neg_items = interaction[self.NEG_ITEM_ID]         neg_e = self.item_embedding(neg_items)         neg_score = (user_e.unsqueeze(1) * neg_e).sum(-1)         return self.loss_fct(pos_score, neg_score)

Регистрируем:

from recbole.utils import register_model register_model('MyDotMF', MyDotMF)

Теперь в YAML достаточно model: MyDotMF.

Минимальный кейс:

mkdir -p dataset/shop python - <<'PY' import pandas as pd, pyarrow.parquet as pq df = pq.read_table('orders.parquet').to_pandas() df[['user_id','sku','ts']].to_csv(     'dataset/shop/shop.inter', sep='\t', index=False) PY

recbole.yaml:

field_separator: "\t" USER_ID_FIELD: user_id ITEM_ID_FIELD: sku TIME_FIELD: ts  model: SASRec epochs: 20 learning_rate: 1e-3 neg_sampling: ~ LABEL_FIELD: click topk: 20 metrics: ['Recall', 'NDCG', 'MRR'] device: cuda

Запуск:

from recbole.quick_start import run_recbole run_recbole(dataset='shop')

RecBole сам сделает сплит, залиогирует Recall@20, сохранит чек-пойнт и итоговый YAML в saved/SASRec-shop-<ts>/.

Фичи

Правильный neg_sampling — бесплатный буст к NDCG

neg_sampling:   uniform: 1 # или   popularity: 1 # или   dynamic: 1            # поддерживается с v1.2

dynamic может давать +5 % NDCG@10 vs uniform.

Knowledge-Graph модели

Если берёте KGAT/CFKG/TransRec, добавляйте файл графа:

knowledge_graph_file: shop.kg

Формат тривиальный: head relation tail. RecBole сам построит adjacency matrix.

GPU-OOM ловушка

Параметр train_batch_size умноженный на количество GPU → ваша фактическая матрица эмбеддингов. Когда загоняете SASRec на A100 40 GB, не забывайте, что скрытая матрица self-attention растёт квадратично от max_seq_length.

train_batch_size: 512     # ок max_seq_length: 200       # ок n_layers: 4               # ок

Уехали в 1024×512×6 — здравствуй, CUDA OOM.

Экспорт в прод

torch.save(model.state_dict(), 'lightgcn.pt') # inference model = LightGCN(config, dataset) model.load_state_dict(torch.load('lightgcn.pt', map_location='cpu')) model.eval()

Никаких RecBole-зависимостей в рантайме: чистый PyTorch внутри Docker.

Итоги

RecBole закрывает 80 % типовых задач ресёрча и «ML-прототипов» в одном пакете: вам остаётся только решать, какую модель кормить продакшену. Да, бывают кейсы, где нужен Sparkили multi-tower архитектура под рекламу – тогда пляшем руками. Но для большинства продуктовых рекомендателей «поднять бейзлайн» быстрее RecBole сегодня мало что умеет.


Если вы работаете с рекомендательными системами или только собираетесь внедрять их в продукт, обязательно загляните в RecBole — мощный фреймворк на PyTorch, который закрывает до 80% задач ресёрча и ML‑прототипирования «из коробки». Поддержка 90+ моделей, единый YAML для всей конфигурации, автоматическая обработка данных, гибкий negative sampling и честные метрики — всё это помогает не тратить время на рутину и быстрее выходить в прод.

Чтобы разобраться в возможностях RecBole и не тратить недели на документацию, присоединяйтесь к нашему циклу открытых уроков:

Каждый урок — это практическое погружение: от запуска бейзлайна на своём датасете до кастомизации модели и экспорта в inference. Присоединяйтесь — и проверьте на практике, насколько RecBole может упростить вашу работу.


ссылка на оригинал статьи https://habr.com/ru/articles/924426/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *