Кастомные loss-функции в TensorFlow/Keras и PyTorch

от автора

Привет, Хабр!

Стандартные loss‑функции, такие как MSE или CrossEntropy, хороши, но часто им не хватает гибкости для сложных задач. Допустим, есть тот же проект с огромным дисбалансом классов, или хочется внедрить специфическую регуляризацию прямо в функцию потерь. Стандартный функционал тут бессилен — тут на помощь приходят кастомные loss’ы.

Custom Loss Functions в TensorFlow/Keras

TensorFlow/Keras радуют удобным API, но за простоту приходится платить вниманием к деталям.

Focal Loss

Focal Loss помогает сместить фокус обучения на сложные примеры, снижая влияние легко классифицируемых данных:

import tensorflow as tf from tensorflow.keras import backend as K  def focal_loss(gamma=2., alpha=0.25):     """     Реализация Focal Loss для задач с дисбалансом классов.     :param gamma: фокусирующий параметр для усиления влияния сложных примеров.     :param alpha: коэффициент балансировки классов.     :return: функция потерь, принимающая (y_true, y_pred).     """     def focal_loss_fixed(y_true, y_pred):         # Защита от log(0) – обрезаем значения предсказаний.         y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())         # Вычисляем кросс-энтропию для каждого примера.         cross_entropy = -y_true * tf.math.log(y_pred)         # Применяем вес для "тяжёлых" примеров.         weight = alpha * tf.pow(1 - y_pred, gamma)         loss = weight * cross_entropy         # Усредняем по батчу и классам.         return tf.reduce_mean(tf.reduce_sum(loss, axis=-1))     return focal_loss_fixed  # Пример использования Focal Loss: if __name__ == "__main__":     # Тестовые данные для отладки (да, я тоже люблю маленькие эксперименты)     y_true = tf.constant([[1, 0], [0, 1]], dtype=tf.float32)     y_pred = tf.constant([[0.9, 0.1], [0.2, 0.8]], dtype=tf.float32)          loss_fn = focal_loss(gamma=2.0, alpha=0.25)     loss_value = loss_fn(y_true, y_pred)     print("Focal Loss:", loss_value.numpy())

Интеграция кастомного loss в модель Keras

Создадим простую CNN‑модель для распознавания изображений и подключим Focal Loss:

import tensorflow as tf from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense  def create_model(input_shape=(28, 28, 1), num_classes=10):     model = Sequential([         Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),         MaxPooling2D(pool_size=(2, 2)),         Flatten(),         Dense(128, activation='relu'),         Dense(num_classes, activation='softmax')     ])     return model  # Компилируем модель с кастомной функцией потерь model = create_model() model.compile(optimizer='adam', loss=focal_loss(gamma=2.0, alpha=0.25), metrics=['accuracy'])  # Создадим тестовые данные (набор из случайных изображений и меток) import numpy as np X_train = np.random.rand(100, 28, 28, 1) y_train = tf.keras.utils.to_categorical(np.random.randint(0, 10, 100), num_classes=10)  print("Запускаем обучение модели с кастомным Focal Loss...") model.fit(X_train, y_train, epochs=3, batch_size=16)

Модель обучается и градиенты сходятся.

Нюансы вычисления градиентов

Нельзя забывать — любые операции, выполняемые с numpy, ломают автоматическое вычисление градиентов. Пример плохой практики:

import numpy as np import tensorflow as tf  def loss_with_numpy(y_true, y_pred):     # Плохая практика: переводим тензоры в numpy и разрываем градиентный поток.     y_true_np = y_true.numpy()  # Ой-ой, ошибка внутри GradientTape!     y_pred_np = y_pred.numpy()     loss_np = np.mean((y_true_np - y_pred_np) ** 2)     return tf.constant(loss_np, dtype=tf.float32)  if __name__ == "__main__":     x = tf.constant([[1.0], [2.0]])     y_true = tf.constant([[1.5], [2.5]])          with tf.GradientTape() as tape:         tape.watch(x)         y_pred = x * 2         try:             loss = loss_with_numpy(y_true, y_pred)             grad = tape.gradient(loss, x)             print("Gradient:", grad)         except Exception as e:             print("Ошибка при вычислении градиента:", e)

Оставайтесь в мире тензоров — TensorFlow умеет всё, что нужно, если вы не решите подмешать туда numpy.

Custom Loss Functions в PyTorch

Реализация кастомной loss через torch.autograd.Function

Начнем с простейшей реализации кастомной loss‑функции, которая считает квадратичную ошибку:

import torch  class CustomLossFunction(torch.autograd.Function):     @staticmethod     def forward(ctx, input, target):         """         Прямой проход: вычисляем MSE.         """         ctx.save_for_backward(input, target)         loss = torch.mean((input - target) ** 2)         return loss      @staticmethod     def backward(ctx, grad_output):         """         Обратный проход: аккуратно считаем градиенты.         """         input, target = ctx.saved_tensors         grad_input = grad_output * 2 * (input - target) / input.numel()         return grad_input, None  # Тестовый пример использования: if __name__ == "__main__":     x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)     y = torch.tensor([1.5, 2.5, 3.5])          loss = CustomLossFunction.apply(x, y)     print("Custom Loss (PyTorch):", loss.item())          loss.backward()     print("Gradient (PyTorch):", x.grad)

Focal Loss в PyTorch

Focal Loss существует не только в TensorFlow. В PyTorch можно сделать не хуже:

import torch import torch.nn as nn import torch.nn.functional as F  class FocalLoss(nn.Module):     def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):         super(FocalLoss, self).__init__()         self.alpha = alpha         self.gamma = gamma         self.reduction = reduction      def forward(self, inputs, targets):         # Если inputs – логиты, используем sigmoid для преобразования         BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')         pt = torch.exp(-BCE_loss)         focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss                  if self.reduction == 'mean':             return focal_loss.mean()         elif self.reduction == 'sum':             return focal_loss.sum()         else:             return focal_loss  # Тестируем Focal Loss в PyTorch: if __name__ == "__main__":     inputs = torch.tensor([[0.2, -1.0], [1.5, 0.3]], requires_grad=True)     targets = torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)          criterion = FocalLoss(alpha=0.25, gamma=2.0)     loss = criterion(inputs, targets)     print("Focal Loss (PyTorch):", loss.item())          loss.backward()     print("Gradients (Focal Loss):", inputs.grad)

Работа с эмбеддингами

Для задач, где нужно сравнивать схожесть объектов, подойдут Contrastive и Triplet Loss. Реализуем их в PyTorch.

Contrastive Loss

class ContrastiveLoss(nn.Module):     def __init__(self, margin=1.0):         super(ContrastiveLoss, self).__init__()         self.margin = margin      def forward(self, output1, output2, label):         # Евклидова дистанция между эмбеддингами         euclidean_distance = F.pairwise_distance(output1, output2)         loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +                                       (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))         return loss_contrastive  # Пример использования Contrastive Loss: if __name__ == "__main__":     output1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)     output2 = torch.tensor([[1.5, 2.5], [2.5, 3.5]], requires_grad=True)     # label: 0 для похожих пар, 1 для непохожих.     label = torch.tensor([0, 1], dtype=torch.float32)          criterion = ContrastiveLoss(margin=1.0)     loss = criterion(output1, output2, label)     print("Contrastive Loss:", loss.item())     loss.backward()

Triplet Loss

class TripletLoss(nn.Module):     def __init__(self, margin=1.0):         super(TripletLoss, self).__init__()         self.margin = margin      def forward(self, anchor, positive, negative):         pos_distance = F.pairwise_distance(anchor, positive, p=2)         neg_distance = F.pairwise_distance(anchor, negative, p=2)         losses = torch.relu(pos_distance - neg_distance + self.margin)         return losses.mean()  # Пример использования Triplet Loss: if __name__ == "__main__":     anchor = torch.tensor([[1.0, 2.0], [2.0, 3.0]], requires_grad=True)     positive = torch.tensor([[1.1, 2.1], [1.9, 2.9]], requires_grad=True)     negative = torch.tensor([[3.0, 4.0], [4.0, 5.0]], requires_grad=True)          criterion = TripletLoss(margin=1.0)     loss = criterion(anchor, positive, negative)     print("Triplet Loss:", loss.item())     loss.backward()

Если вам хочется поделиться опытом — пишите в комментариях.

Все актуальные методы и инструменты DS и ML можно освоить на онлайн-курсах OTUS: в каталоге можно посмотреть список всех программ, а в календаре — записаться на открытые уроки.


ссылка на оригинал статьи https://habr.com/ru/articles/892462/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *