Интеграция TFLite во Flutter: внедряем модели машинного обучения в мобильное приложение

от автора

Привет! Меня зовут Никита Грибков, я Flutter-разработчик в AGIMA. В этой статье расскажу про фреймворк TensorFlow Lite, который позволяет интегрировать в мобильное приложение модели машинного обучения. Это полезная штука, если нужно реализовать фичи, связанные с распознаванием речи или с классификацией изображений. Покажу, как обучать модели и как затем с ними работать.

Технология позволяет создавать персонализированные и интеллектуальные решения для пользователей, поэтому пользуется высоким спросом. Если наша цель — сделать приложение более удобным и инклюзивным, то, скорее всего, придется использовать ML.

Вот несколько примеров задач, для которых технология 100% подходит:

  • классификация изображений: чтобы приложение могло распознавать объекты на фотографиях или видео (например, Google Lens);

  • обработка естественного языка (NLP): в приложениях с голосовыми ассистентами или чат-ботами ML обрабатывает речь и тексты (например, Siri или Google Assistant);

  • персонализация: алгоритмы ML анализируют поведение пользователей и предлагают персонализированный контент или рекомендации;

  • распознавание голоса: используется в приложениях для конвертации речи в текст и команд.

Существует несколько способов, как интегрировать модели машинного обучения в приложение. Можно воспользоваться ML Kit от Firebase или библиотеки на Dart. Но лично я пробовал работать с TensorFlow Lite (TFLite). Этот фреймворк можно считать самым распространенным решением в данном случае.

Его главное (но не единственное) преимущество — что он может работать в офлайне, когда устройство не подключено к интернету. Также мне нравится, что TFLite оптимизирован для работы на устройствах с ограниченными ресурсами, это удобно. Разберем, как фреймворк работает.

Подготовка модели для использования с TFLite

Прежде чем интегрировать TFLite во Flutter-приложение, необходимо подготовить модель. Это предполагает её обучение в TensorFlow и конвертацию в формат .tflite.

Шаг 1. Создание и обучение модели в TensorFlow

Для работы с машинным обучением вы можете обучить модель с помощью TensorFlow. Вот простой пример создания и обучения модели на Python:

import tensorflow as tf from tensorflow.keras import layers  # Создание простой модели для классификации изображений model = tf.keras.Sequential([     layers.Flatten(input_shape=(28, 28)),     layers.Dense(128, activation='relu'),     layers.Dense(10, activation='softmax') ])

Компиляция модели:

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

Обучение модели на данных MNIST:

model.fit(train_images, train_labels, epochs=5)

Сохранение модели:

model.save("model.h5")

Сеть состоит из одного слоя для преобразования 28×28 пикселей в одномерный вектор, скрытого слоя с 128 нейронами и выходного слоя с 10 нейронами для 10 классов.

model = tf.keras.Sequential([     layers.Flatten(input_shape=(28, 28)),     layers.Dense(128, activation='relu'),     layers.Dense(10, activation='softmax') ])

Модель компилируется с использованием оптимизатора Adam и функции потерь Sparse Categorical Crossentropy.

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

Затем обучается на данных MNIST в течение 5 эпох и сохраняется в файл «model.h5».

model.fit(train_images, train_labels, epochs=5)

Шаг 2: Конвертация модели в формат TFLite

После обучения модели ее нужно преобразовать в формат .tflite с помощью TFLite-конвертера.

Пример кода для конвертации модели:

# Загрузка модели model = tf.keras.models.load_model('model.h5')  # Конвертация модели в формат TFLite converter = tf.lite.TFLiteConverter.from_keras_model(model) tflite_model = converter.convert()  # Сохранение модели в формате .tflite with open('model.tflite', 'wb') as f:     f.write(tflite_model)

Теперь у вас есть модель в формате .tflite, которую можно интегрировать в приложение на Flutter.

Интеграция TFLite в Flutter-приложение

Для работы с TFLite в Flutter нужно использовать плагин tflite_flutter. Этот репозиторий — управляемый TensorFlow форк проекта — управляемый TensorFlow форк проекта [tflite_flutter_plugin]

Шаг 1. Установка необходимых зависимостей

Откройте файл pubspec.yaml вашего Flutterпроекта и добавьте зависимости:

dependencies:   flutter:     sdk: flutter   tflite_flutter: ^0.11.0   tflite_flutter_helper_plus: ^0.0.2

Шаг 2. Подготовка модели

Скопируйте файл вашей модели model.tflite в папку проекта assets. Затем в файле pubspec.yaml укажите путь к модели в разделе assets:

flutter:   assets:     - assets/model.tflite     - assets/labels.txt # если у вас есть файл с метками

Шаг 3. Загрузка и использование модели в коде Flutter

Теперь создадим код для загрузки модели и выполнения предсказаний на ее основе на стороне Flutter.

Импорт пакетов:

import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:tflite_flutter_helper_plus/tflite_flutter_helper_plus.dart';

Загрузка модели:

late Interpreter interpreter; Future<void> loadModel() async {   try {     // Загружаем модель из assets     interpreter = await Interpreter.fromAsset('model.tflite');     print('Модель загружена успешно');   } catch (e) {     print('Ошибка загрузки модели: $e');   } }

Этот код преобразует изображение в массив Float32List. Он берет каждый пиксель изображения, извлекает значения красного, зеленого и синего каналов, нормализует их с помощью заданных mean и std, а затем заполняет массив.

Float32List imageToByteListFloat32(       img.Image image, int inputSize, double mean, double std) {     var convertedBytes = Float32List(1 * inputSize * inputSize * 3);     var buffer = Float32List.view(convertedBytes.buffer);     int pixelIndex = 0;     for (var i = 0; i < inputSize; i++) {       for (var j = 0; j < inputSize; j++) {         var pixel = image.getPixel(j, i);         buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std);         buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std);         buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std);       }     }     return convertedBytes;   }

Выполнение предсказаний

Для выполнения предсказаний нужно преобразовать входные данные в подходящий формат, например, изображение в тензор (массив данных).

 Future<void> classifyImage(File image) async {     // Преобразуем изображение в тензор     final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!;     var inputImage = img.copyResize(imageInput, width: 28, height: 28);     var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5);      // Подготовка выходного тензора     var output = List.filled(10, 0).reshape([1, 10]);      // Выполнение предсказания     _interpreter.run(input, output);      setState(() {       _result = 'Предсказание: ${output.toString()}';     });   }

Flutter с использованием TFLite

import 'dart:typed_data'; import 'package:flutter/material.dart'; import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:image_picker/image_picker.dart'; import 'dart:io'; import 'package:image/image.dart' as img;  class MyHomePage extends StatefulWidget {   @override   _MyHomePageState createState() => _MyHomePageState(); }  class _MyHomePageState extends State<MyHomePage> {   late Interpreter _interpreter;   File? _image;   final picker = ImagePicker();   String _result = 'Нет предсказаний';    @override   void initState() {     super.initState();     loadModel();   }    Future<void> loadModel() async {     try {       _interpreter = await Interpreter.fromAsset('model.tflite');       print('Модель загружена');     } catch (e) {       print('Ошибка загрузки модели: $e');     }   }    Future<void> pickImage() async {     final pickedFile = await picker.pickImage(source: ImageSource.gallery);      setState(() {       _image = File(pickedFile!.path);     });      if (_image != null) {       classifyImage(_image!);     }   }    Future<void> classifyImage(File image) async {     // Преобразуем изображение в тензор     final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!;     var inputImage = img.copyResize(imageInput, width: 28, height: 28);     var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5);      // Подготовка выходного тензора     var output = List.filled(10, 0).reshape([1, 10]);      // Выполнение предсказания     _interpreter.run(input, output);      setState(() {       _result = 'Предсказание: ${output.toString()}';     });   }    Float32List imageToByteListFloat32(       img.Image image, int inputSize, double mean, double std) {     var convertedBytes = Float32List(1 * inputSize * inputSize * 3);     var buffer = Float32List.view(convertedBytes.buffer);     int pixelIndex = 0;     for (var i = 0; i < inputSize; i++) {       for (var j = 0; j < inputSize; j++) {         var pixel = image.getPixel(j, i);         buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std);         buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std);         buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std);       }     }     return convertedBytes;   }    @override   Widget build(BuildContext context) {     return Scaffold(       appBar: AppBar(title: Text('TFLite Classifier')),       body: Column(         children: [           _image == null ? Text('Выберите изображение') : Image.file(_image!),           ElevatedButton(             onPressed: pickImage,             child: Text('Загрузить изображение'),           ),           Text(_result),         ],       ),     );   } }

Оптимизация модели для мобильных устройств

Чтобы повысить производительность на мобильных устройствах, можно использовать такие подходы:

  • Квантизация модели. Она уменьшает размер модели и ускоряет работу за счет уменьшения точности числовых представлений.

  • Параллельное выполнение. Использование многоядерных процессоров для ускорения предсказаний.

Пример кода для квантизации модели:

converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert()  with open('quantized_model.tflite', 'wb') as f:     f.write(tflite_model) 

Что в итоге

В итоге мы получаем приложение с уже работающими моделями ML. Самый долгий этап связан с обучением моделей, всё остальное — вопрос техники. Думаю, примеры выше помогут провести интеграцию быстро.

Если у вас остались вопросы — задавайте в комментариях, я отвечу. А вообще подписывайтесь на канал нашего коллеги Саши Ворожищева — он много пишет про Flutter и про мобильную разработку в целом.

Что еще почитать


ссылка на оригинал статьи https://habr.com/ru/articles/852500/


Комментарии

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *