
Введение
Машинное обучение уже везде и, пожалуй, почти невозможно найти софт, не использующий его прямо или косвенно. Давайте создадим небольшое приложение, способное загружать изображения на сервер для последующего распознавания с помощью ML. А после сделаем их доступными через мобильное приложение с текстовым поиском по содержимому.
Мы будем использовать Flask для нашего REST API, Flutter для мобильного приложения и Keras для машинного обучения. В качестве базы данных для хранения информации о содержимом изображений используем MongoDB, а для получения информации возьмём уже натренированную модель ResNet50. При необходимости мы сможем заменить модель, используя методы save_model() и load_model(), доступные в Keras. Последний потребует около 100 Мб при первоначальной загрузке модели. Почитать о других доступных моделях можно в документации.
Начнём с Flask
Если вы незнакомы с Flask, то создать роут на нём можно просто добавив к контроллеру декоратор app.route(‘/’), где app — переменная приложения. Пример:
from flask import Flask app = Flask(__name__) @app.route('/') def hello_world(): return 'Hello, World!'
При запуске и переходе по дефолтному адресу 127.0.0.1:5000/ мы увидим ответ Hello World! О том, как сделать что-то посложнее, можно почитать в документации.
Приступим же к созданию полноценного бэкенда:
import os import tensorflow as tf from tensorflow.keras.models import load_model from tensorflow.keras.preprocessing import image as img from keras.preprocessing.image import img_to_array import numpy as np from PIL import Image from keras.applications.resnet50 import ResNet50,decode_predictions,preprocess_input from datetime import datetime import io from flask import Flask,Blueprint,request,render_template,jsonify from modules.dataBase import collection as db
Как можно заметить импорты содержат tensorflow, который мы будем использовать как бэкенд для keras, а так же numpy для работы с мультиразмерными массивами.
mod = Blueprint('backend', __name__, template_folder='templates', static_folder='./static') UPLOAD_URL = 'http://192.168.1.103:5000/static/' model = ResNet50(weights='imagenet') model._make_predict_function()
На первой строчке мы создаём блюпринт для более удобной организации приложения. Из-за этого надо будет использовать mod.route(‘/’) для декорирования контроллера. Предварительно натренированная на imagenet модель Resnet50 нуждается в вызове _make_predict_function() для инициализации. Без этого шага есть вероятность получить ошибку. А другую модель можно использовать, заменив строку
model = ResNet50(weights='imagenet')
на
model = load_model('saved_model.h5')
Вот как будет выглядеть контроллер:
@mod.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': # проверяем, что прислали файл if 'file' not in request.files: return "someting went wrong 1" user_file = request.files['file'] temp = request.files['file'] if user_file.filename == '': return "file name not found ..." else: path = os.path.join(os.getcwd()+'\\modules\\static\\'+user_file.filename) user_file.save(path) classes = identifyImage(path) db.addNewImage( user_file.filename, classes[0][0][1], str(classes[0][0][2]), datetime.now(), UPLOAD_URL+user_file.filename) return jsonify({ "status":"success", "prediction":classes[0][0][1], "confidence":str(classes[0][0][2]), "upload_time":datetime.now() })
В коде выше загруженное изображение передаётся в метод identifyImage(file_path), который реализован так:
def identifyImage(img_path): image = img.load_img(img_path, target_size=(224,224)) x = img_to_array(image) x = np.expand_dims(x, axis=0) # images = np.vstack([x]) x = preprocess_input(x) preds = model.predict(x) preds = decode_predictions(preds, top=1) print(preds) return preds
Сначала мы преобразуем изображение к размеру 224*224, т.к. именно он нужен нашей модели. Затем передаём в model.predict() предварительно обработанные байты изображения. Теперь наша модель может предсказать, что находится на изображении (top=1 нужен чтобы получить единственный самый вероятный результат).
Сохраним полученные данные о содержимом изображения в MongoDB с помощью функции db.addData(). Вот релевантная часть кода:
from pymongo import MongoClient from bson import ObjectId client = MongoClient("mongodb://localhost:27017") # host uri db = client.image_predition #Select the database image_details = db.imageData def addNewImage(i_name, prediction, conf, time, url): image_details.insert({ "file_name":i_name, "prediction":prediction, "confidence":conf, "upload_time":time, "url":url }) def getAllImages(): data = image_details.find() return data
Так как мы использовали блюпринт, код для API можно разместить в отдельном файле:
from flask import Flask,render_template,jsonify,Blueprint mod = Blueprint('api',__name__,template_folder='templates') from modules.dataBase import collection as db from bson.json_util import dumps @mod.route('/') def api(): return dumps(db.getAllImages())
Как можно заметить, для возвращения данных БД мы используем json. Посмотреть на результат можно по адресу 127.0.0.1:5000/api
Выше, разумеется, только самые важные куски кода. Полностью проект можно посмотреть в GitHub репозитории. А больше о Pymongo можно почитать здесь.
Создаём приложение Flutter
Мобильная версия будет получать изображения и данные об их содержимом по REST API. Вот что получится в итоге:

ImageData класс инкапсулирует данные об изображении:
import 'dart:convert'; import 'package:http/http.dart' as http; import 'dart:async'; class ImageData { // static String BASE_URL ='http://192.168.1.103:5000/'; String uri; String prediction; ImageData(this.uri,this.prediction); } Future<List<ImageData>> LoadImages() async { List<ImageData> list; //complete fetch .... var data = await http.get( 'http://192.168.1.103:5000/api/'); var jsondata = json.decode(data.body); List<ImageData> newslist = []; for (var data in jsondata) { ImageData n = ImageData(data['url'],data['prediction']); newslist.add(n); } return newslist; }
Здесь мы получаем json, преобразуем его в список объектов ImageData и возвращаем во Future Builder с помощью функции LoadImages()
Загрузка изображений на сервер
uploadImageToServer(File imageFile)async { print("attempting to connecto server......"); var stream = new http.ByteStream(DelegatingStream.typed(imageFile.openRead())); var length = await imageFile.length(); print(length); var uri = Uri.parse('http://192.168.1.103:5000/predict'); print("connection established."); var request = new http.MultipartRequest("POST", uri); var multipartFile = new http.MultipartFile('file', stream, length, filename: basename(imageFile.path)); //contentType: new MediaType('image', 'png')); request.files.add(multipartFile); var response = await request.send(); print(response.statusCode); }
Чтобы сделать Flask доступным в локальной сети отключите режим дебага и найдите ipv4 адрес, используя ipconfig. Запустить локальный сервер можно так:
app.run(debug=False, host='192.168.1.103', port=5000)
Иногда файрвол может мешать приложению обращаться к локалхосту, тогда его придётся перенастроить или отключить.
Весь исходный код приложения доступен на гитхабе. Вот ссылки, которые помогут разобраться в происходящем:
Keras : https://keras.io/
Flutter : https://flutter.dev/
MongoDB : https://www.tutorialspoint.com/mongodb/
Курс Гарварда по Python и Flask: https://www.youtube.com/watch?v=j5wysXqaIV8&t=5515s (особенно важны лекции 2,3,4)
GitHub : https://github.com/SHARONZACHARIA
ссылка на оригинал статьи https://habr.com/ru/post/460995/
Добавить комментарий