Tribuo и регрессия: как строить предсказательные модели на Java

от автора

Привет, Хабр!

В этой статье наш взор упадет на на Tribuo — библиотеку машинного обучения на Java от Oracle.

Tribuo поддерживает различные алгоритмы для классификации, регрессии, кластеризации и многого другого. Но сегодня мы сосредоточимся на регрессии — фундаментальной задаче, которая позволяет предсказывать непрерывные значения. Одним из главных плюсов Tribuo является её удобный API, который позволяет быстро строить модели и оценивать их эффективность.

Установка и настройка проекта

Начнём с самого начала — настройки проекта. Для работы с Tribuo понадобится Java 11 или новее. Также будем использовать Maven для управления зависимостями. Если вы используете другую систему сборки, то принципы останутся теми же, только синтаксис зависимостей будет другим.

Добавим Tribuo в наш pom.xml:

<dependencies>     <!-- Tribuo Core -->     <dependency>         <groupId>org.tribuo</groupId>         <artifactId>tribuo-all</artifactId>         <version>4.3.0</version>     </dependency>     <!-- Для работы с CSV -->     <dependency>         <groupId>org.tribuo</groupId>         <artifactId>tribuo-data</artifactId>         <version>4.3.0</version>     </dependency>     <!-- Для линейной регрессии -->     <dependency>         <groupId>org.tribuo</groupId>         <artifactId>tribuo-regression-linear</artifactId>         <version>4.3.0</version>     </dependency>     <!-- Для CART регрессии -->     <dependency>         <groupId>org.tribuo</groupId>         <artifactId>tribuo-regression-cart</artifactId>         <version>4.3.0</version>     </dependency> </dependencies>

Обновляем зависимости и все готово!

Подготовка данных

Работа с данными — это первый и один из самых важных шагов в машинном обучении. Для примера возьмём простой CSV-файл с данными о ценах домов:

square_feet,rooms,has_garage,price 1200,3,1,250000 1400,3,0,300000 1600,4,1,350000 ...

Здесь:

  • square_feet — площадь дома в квадратных футах,

  • rooms — количество комнат,

  • has_garage — наличие гаража (1 — есть, 0 — нет),

  • price — цена дома (целевая переменная).

Используем CSVLoader для загрузки данных:

import org.tribuo.data.csv.CSVLoader; import org.tribuo.data.csv.CSVDataSource; import org.tribuo.regression.Regressor; import org.tribuo.regression.RegressorFactory;  import java.nio.file.Paths;  public class DataLoader {     public static CSVDataSource<Regressor> loadData(String filePath, String targetColumn) throws IOException {         CSVLoader<Regressor> loader = new CSVLoader<>(new RegressorFactory());         return loader.loadData(Paths.get(filePath), targetColumn);     } }

После этого нужно разделить данные, чтобы оценить модель на невидимых данных. Используем TrainTestSplitter:

import org.tribuo.data.split.TrainTestSplitter; import org.tribuo.regression.Regressor; import org.tribuo.data.dataset.Dataset;  public class DataSplitter {     public static TrainTestSplitter<Regressor> splitData(Dataset<Regressor> data, double trainFraction, long seed) {         return new TrainTestSplitter<>(data, trainFraction, seed);     } }

trainFraction — доля данных для обучения (например, 0.7 для 70%), а seed — случайное зерно для воспроизводимости.

Построение регрессионной модели

Теперь перейдем к созданию моделей. Начнём с простой линейной регрессии.

import org.tribuo.Model; import org.tribuo.regression.Regressor; import org.tribuo.regression.linear.LinearRegressionTrainer;  public class LinearRegressionModel {     public static Model<Regressor> trainModel(Dataset<Regressor> trainData) {         LinearRegressionTrainer trainer = new LinearRegressionTrainer(0.01, LinearRegressionTrainer.LossType.SQUARED);         return trainer.train(trainData);     } }

LinearRegressionTrainer — тренер для линейной регрессии. LossType.SQUARED — тип функции потерь (в данном случае квадратичная).

Для более сложных задач можно использовать CARTRegressionTrainer, который реализует алгоритм случайных лесов.

import org.tribuo.regression.cart.CARTRegressionTrainer;  public class RandomForestRegressionModel {     public static Model trainModel(Dataset trainData, int numTrees) {         CARTRegressionTrainer trainer = new CARTRegressionTrainer(numTrees);         return trainer.train(trainData);     } }

numTrees — количество деревьев в лесу.

Оценка модели

После обучения модели важно понять, насколько она хороша. Для этого используем RegressionEvaluator.

import org.tribuo.regression.RegressionEvaluator; import org.tribuo.Model; import org.tribuo.regression.Regressor;  public class ModelEvaluator {     public static void evaluateModel(Model model, Dataset testData) {         RegressionEvaluator evaluator = new RegressionEvaluator();         var evaluation = evaluator.evaluate(model, testData);         System.out.println(evaluation);     } }

evaluate — метод, который принимает модель и тестовые данные, возвращает метрики.

Пример прогнозирования цен на дома

Соберём всё вместе и создадим полный пример.

import org.tribuo.Model; import org.tribuo.data.csv.CSVDataSource; import org.tribuo.data.dataset.Dataset; import org.tribuo.data.split.TrainTestSplitter; import org.tribuo.regression.Regressor; import org.tribuo.regression.RegressionEvaluator; import org.tribuo.regression.linear.LinearRegressionTrainer; import org.tribuo.regression.cart.CARTRegressionTrainer;  import java.io.IOException;  public class HousePricePrediction {     public static void main(String[] args) {         try {             // Шаг 1: Загрузка данных             CSVDataSource dataSource = DataLoader.loadData("house_prices.csv", "price");             Dataset data = dataSource.getDataset();              // Шаг 2: Разделение данных             TrainTestSplitter splitter = DataSplitter.splitData(data, 0.7, 42L);             Dataset trainData = splitter.getTrainingDataset();             Dataset testData = splitter.getTestDataset();              // Шаг 3: Обучение линейной регрессии             Model linearModel = LinearRegressionModel.trainModel(trainData);             System.out.println("Линейная регрессия обучена!");              // Шаг 4: Обучение Random Forest регрессии             Model rfModel = RandomForestRegressionModel.trainModel(trainData, 100);             System.out.println("Random Forest регрессия обучена!");              // Шаг 5: Оценка моделей             System.out.println("Оценка линейной регрессии:");             ModelEvaluator.evaluateModel(linearModel, testData);              System.out.println("Оценка Random Forest регрессии:");             ModelEvaluator.evaluateModel(rfModel, testData);          } catch (IOException e) {             System.err.println("Ошибка при загрузке данных: " + e.getMessage());         }     } }

Параметры тренера

Tribuo позволяет настраивать множество параметров тренера. Например, для линейной регрессии можно настроить:

  • Learning Rate (коэффициент обучения): Влияет на скорость сходимости.

  • Loss Type (тип функции потерь): Помимо квадратичной, доступны другие типы, такие как абсолютная ошибка.

Пример с другими параметрами:

LinearRegressionTrainer trainer = new LinearRegressionTrainer(     0.05, // Коэффициент обучения     LinearRegressionTrainer.LossType.HUBER, // Функция потерь Хьюбера     1.0 // Параметр delta для функции Хьюбера );

Как избежать переобучения

Кросс-валидация — мощный инструмент для оценки модели. Внедрим 5-кратную кросс-валидацию.

import org.tribuo.data.cross.CrossValidator; import org.tribuo.data.cross.CrossValidationResult; import org.tribuo.regression.Regressor;  public class CrossValidationExample {     public static void performCrossValidation(Dataset data, int folds) {         LinearRegressionTrainer trainer = new LinearRegressionTrainer(0.01, LinearRegressionTrainer.LossType.SQUARED);         CrossValidator crossValidator = new CrossValidator&lt;&gt;(trainer, folds);         CrossValidationResult result = crossValidator.evaluate(data);         System.out.println(result);     } }

folds — количество разбиений (например, 5 для 5-кратной кросс-валидации). evaluate — метод, который выполняет кросс-валидацию и возвращает результат.

Как подобрать гиперпараметры

Подбор гиперпараметров может значительно бустануть качество модели. Рассмотрим простой пример перебора параметров с использованием Grid Search.

import org.tribuo.regression.linear.LinearRegressionTrainer; import org.tribuo.Model; import org.tribuo.regression.Regressor;  public class HyperparameterTuning {     public static void gridSearch(Dataset trainData, Dataset testData) {         double[] learningRates = {0.001, 0.01, 0.1};         LinearRegressionTrainer.LossType[] lossTypes = {             LinearRegressionTrainer.LossType.SQUARED,             LinearRegressionTrainer.LossType.HUBER         };          for (double lr : learningRates) {             for (LinearRegressionTrainer.LossType lt : lossTypes) {                 LinearRegressionTrainer trainer = new LinearRegressionTrainer(lr, lt);                 Model model = trainer.train(trainData);                 RegressionEvaluator evaluator = new RegressionEvaluator();                 var evaluation = evaluator.evaluate(model, testData);                 System.out.println("Learning Rate: " + lr + ", Loss Type: " + lt);                 System.out.println(evaluation);             }         }     } }

Перебираем различные значения learning rate и loss type, после чего обучаем модель для каждой комбинации и оцениваем её на тестовых данных.

Визуализация результатов

В самом Tribuo нет инструментов для визуализации, но можно экспортировать результаты и использовать, например, JFreeChart или JavaFX для построения графиков.

Простой пример экспорта:

import java.io.FileWriter; import java.io.IOException; import java.util.List;  public class PredictionExporter {     public static void exportPredictions(Model model, Dataset testData, String outputPath) throws IOException {         try (FileWriter writer = new FileWriter(outputPath)) {             writer.append("Actual,Predicted\n");             for (var example : testData) {                 double actual = example.getOutput().getOutput();                 double predicted = model.predict(example).getOutput();                 writer.append(actual + "," + predicted + "\n");             }         }     } }

Подробнее с Tribuo можно ознакомиться здесь.


Все актуальные методы и инструменты DS и ML можно освоить на онлайн-курсах OTUS: в каталоге можно посмотреть список всех программ, а в календаре — записаться на открытые уроки.

Один из уроков пройдет 11 ноября и будет посвящен теме «Временные ряды Фурье и вейвлет-анализ». Этот урок будет особенно интересен ML-инженерам, которые начинают знакомство с временными рядами и хотят выйти за границы модели SARIMA. Подробнее


ссылка на оригинал статьи https://habr.com/ru/articles/853300/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *