Привет! Меня зовут Никита Грибков, я Flutter-разработчик в AGIMA. В этой статье расскажу про фреймворк TensorFlow Lite, который позволяет интегрировать в мобильное приложение модели машинного обучения. Это полезная штука, если нужно реализовать фичи, связанные с распознаванием речи или с классификацией изображений. Покажу, как обучать модели и как затем с ними работать.
Технология позволяет создавать персонализированные и интеллектуальные решения для пользователей, поэтому пользуется высоким спросом. Если наша цель — сделать приложение более удобным и инклюзивным, то, скорее всего, придется использовать ML.
Вот несколько примеров задач, для которых технология 100% подходит:
-
классификация изображений: чтобы приложение могло распознавать объекты на фотографиях или видео (например, Google Lens);
-
обработка естественного языка (NLP): в приложениях с голосовыми ассистентами или чат-ботами ML обрабатывает речь и тексты (например, Siri или Google Assistant);
-
персонализация: алгоритмы ML анализируют поведение пользователей и предлагают персонализированный контент или рекомендации;
-
распознавание голоса: используется в приложениях для конвертации речи в текст и команд.
Существует несколько способов, как интегрировать модели машинного обучения в приложение. Можно воспользоваться ML Kit от Firebase или библиотеки на Dart. Но лично я пробовал работать с TensorFlow Lite (TFLite). Этот фреймворк можно считать самым распространенным решением в данном случае.
Его главное (но не единственное) преимущество — что он может работать в офлайне, когда устройство не подключено к интернету. Также мне нравится, что TFLite оптимизирован для работы на устройствах с ограниченными ресурсами, это удобно. Разберем, как фреймворк работает.
Подготовка модели для использования с TFLite
Прежде чем интегрировать TFLite во Flutter-приложение, необходимо подготовить модель. Это предполагает её обучение в TensorFlow и конвертацию в формат .tflite.
Шаг 1. Создание и обучение модели в TensorFlow
Для работы с машинным обучением вы можете обучить модель с помощью TensorFlow. Вот простой пример создания и обучения модели на Python:
import tensorflow as tf from tensorflow.keras import layers # Создание простой модели для классификации изображений model = tf.keras.Sequential([ layers.Flatten(input_shape=(28, 28)), layers.Dense(128, activation='relu'), layers.Dense(10, activation='softmax') ])
Компиляция модели:
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
Обучение модели на данных MNIST:
model.fit(train_images, train_labels, epochs=5)
Сохранение модели:
model.save("model.h5")
Сеть состоит из одного слоя для преобразования 28×28 пикселей в одномерный вектор, скрытого слоя с 128 нейронами и выходного слоя с 10 нейронами для 10 классов.
model = tf.keras.Sequential([ layers.Flatten(input_shape=(28, 28)), layers.Dense(128, activation='relu'), layers.Dense(10, activation='softmax') ])
Модель компилируется с использованием оптимизатора Adam и функции потерь Sparse Categorical Crossentropy.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
Затем обучается на данных MNIST в течение 5 эпох и сохраняется в файл «model.h5».
model.fit(train_images, train_labels, epochs=5)
Шаг 2: Конвертация модели в формат TFLite
После обучения модели ее нужно преобразовать в формат .tflite с помощью TFLite-конвертера.
Пример кода для конвертации модели:
# Загрузка модели model = tf.keras.models.load_model('model.h5') # Конвертация модели в формат TFLite converter = tf.lite.TFLiteConverter.from_keras_model(model) tflite_model = converter.convert() # Сохранение модели в формате .tflite with open('model.tflite', 'wb') as f: f.write(tflite_model)
Теперь у вас есть модель в формате .tflite, которую можно интегрировать в приложение на Flutter.
Интеграция TFLite в Flutter-приложение
Для работы с TFLite в Flutter нужно использовать плагин tflite_flutter
. Этот репозиторий — управляемый TensorFlow форк проекта — управляемый TensorFlow форк проекта [tflite_flutter_plugin]
Шаг 1. Установка необходимых зависимостей
Откройте файл pubspec.yaml вашего Flutterпроекта и добавьте зависимости:
dependencies: flutter: sdk: flutter tflite_flutter: ^0.11.0 tflite_flutter_helper_plus: ^0.0.2
Шаг 2. Подготовка модели
Скопируйте файл вашей модели model.tflite в папку проекта assets. Затем в файле pubspec.yaml укажите путь к модели в разделе assets:
flutter: assets: - assets/model.tflite - assets/labels.txt # если у вас есть файл с метками
Шаг 3. Загрузка и использование модели в коде Flutter
Теперь создадим код для загрузки модели и выполнения предсказаний на ее основе на стороне Flutter.
Импорт пакетов:
import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:tflite_flutter_helper_plus/tflite_flutter_helper_plus.dart';
Загрузка модели:
late Interpreter interpreter; Future<void> loadModel() async { try { // Загружаем модель из assets interpreter = await Interpreter.fromAsset('model.tflite'); print('Модель загружена успешно'); } catch (e) { print('Ошибка загрузки модели: $e'); } }
Этот код преобразует изображение в массив Float32List. Он берет каждый пиксель изображения, извлекает значения красного, зеленого и синего каналов, нормализует их с помощью заданных mean и std, а затем заполняет массив.
Float32List imageToByteListFloat32( img.Image image, int inputSize, double mean, double std) { var convertedBytes = Float32List(1 * inputSize * inputSize * 3); var buffer = Float32List.view(convertedBytes.buffer); int pixelIndex = 0; for (var i = 0; i < inputSize; i++) { for (var j = 0; j < inputSize; j++) { var pixel = image.getPixel(j, i); buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std); buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std); buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std); } } return convertedBytes; }
Выполнение предсказаний
Для выполнения предсказаний нужно преобразовать входные данные в подходящий формат, например, изображение в тензор (массив данных).
Future<void> classifyImage(File image) async { // Преобразуем изображение в тензор final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!; var inputImage = img.copyResize(imageInput, width: 28, height: 28); var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5); // Подготовка выходного тензора var output = List.filled(10, 0).reshape([1, 10]); // Выполнение предсказания _interpreter.run(input, output); setState(() { _result = 'Предсказание: ${output.toString()}'; }); }
Flutter с использованием TFLite
import 'dart:typed_data'; import 'package:flutter/material.dart'; import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:image_picker/image_picker.dart'; import 'dart:io'; import 'package:image/image.dart' as img; class MyHomePage extends StatefulWidget { @override _MyHomePageState createState() => _MyHomePageState(); } class _MyHomePageState extends State<MyHomePage> { late Interpreter _interpreter; File? _image; final picker = ImagePicker(); String _result = 'Нет предсказаний'; @override void initState() { super.initState(); loadModel(); } Future<void> loadModel() async { try { _interpreter = await Interpreter.fromAsset('model.tflite'); print('Модель загружена'); } catch (e) { print('Ошибка загрузки модели: $e'); } } Future<void> pickImage() async { final pickedFile = await picker.pickImage(source: ImageSource.gallery); setState(() { _image = File(pickedFile!.path); }); if (_image != null) { classifyImage(_image!); } } Future<void> classifyImage(File image) async { // Преобразуем изображение в тензор final img.Image imageInput = img.decodeImage(image.readAsBytesSync())!; var inputImage = img.copyResize(imageInput, width: 28, height: 28); var input = imageToByteListFloat32(inputImage, 28, 127.5, 127.5); // Подготовка выходного тензора var output = List.filled(10, 0).reshape([1, 10]); // Выполнение предсказания _interpreter.run(input, output); setState(() { _result = 'Предсказание: ${output.toString()}'; }); } Float32List imageToByteListFloat32( img.Image image, int inputSize, double mean, double std) { var convertedBytes = Float32List(1 * inputSize * inputSize * 3); var buffer = Float32List.view(convertedBytes.buffer); int pixelIndex = 0; for (var i = 0; i < inputSize; i++) { for (var j = 0; j < inputSize; j++) { var pixel = image.getPixel(j, i); buffer[pixelIndex++] = ((img.getRed(pixel) - mean) / std); buffer[pixelIndex++] = ((img.getGreen(pixel) - mean) / std); buffer[pixelIndex++] = ((img.getBlue(pixel) - mean) / std); } } return convertedBytes; } @override Widget build(BuildContext context) { return Scaffold( appBar: AppBar(title: Text('TFLite Classifier')), body: Column( children: [ _image == null ? Text('Выберите изображение') : Image.file(_image!), ElevatedButton( onPressed: pickImage, child: Text('Загрузить изображение'), ), Text(_result), ], ), ); } }
Оптимизация модели для мобильных устройств
Чтобы повысить производительность на мобильных устройствах, можно использовать такие подходы:
-
Квантизация модели. Она уменьшает размер модели и ускоряет работу за счет уменьшения точности числовых представлений.
-
Параллельное выполнение. Использование многоядерных процессоров для ускорения предсказаний.
Пример кода для квантизации модели:
converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert() with open('quantized_model.tflite', 'wb') as f: f.write(tflite_model)
Что в итоге
В итоге мы получаем приложение с уже работающими моделями ML. Самый долгий этап связан с обучением моделей, всё остальное — вопрос техники. Думаю, примеры выше помогут провести интеграцию быстро.
Если у вас остались вопросы — задавайте в комментариях, я отвечу. А вообще подписывайтесь на канал нашего коллеги Саши Ворожищева — он много пишет про Flutter и про мобильную разработку в целом.
Что еще почитать
ссылка на оригинал статьи https://habr.com/ru/articles/852500/
Добавить комментарий