Сегодня в ленте была очередная порция хайпа про ИИ. Смешно читать про «аппаратное ускорение AI» на пользовательских устройствах. Автор, вы сами попробуйте добраться до этого аппаратного ускорения, и если найдете как — напишите статью. А то элементарная попытка использования GPU для работы TensorFlow Lite приводит только к потерянному времени, а ускорители NPU больше не поддерживаются именно там, где должны были бы. То есть за хайпом вокруг «аппаратного ускорения ИИ» производители создали новую категорию устройств, и теперь стандартно ноутбук будет стоить в 2 раза больше, чем было раньше. А по факту пользоваться этим ускорением будут только компании‑производители, чтобы еще больше заработать денег на пользователях через рекламу, «правильные» модели и торговлю персональными данными.
А мы сегодня запустим TensorFlow Lite на устройствах разного класса и года выпуска и посмотрим, что там с производительностью и ускорением.
Для начала берем модель, подготовленную на предыдущем шаге, и копируем ее в ассеты проекта Android Studio. Можно взять модель с сохраненными весами после обучения (ruLearnModel.tflite), можно сохранить необученную модель (в моем случае ruLearnModel_noWeights.tflite). А можно сделать и то, и другое:
Дальше находим на гитхабе пример с обучением модели на устройстве и копируем оттуда зависимости для нашего проекта. Я посчитал по опыту предыдущей статьи, что лучше использовать старые версии библиотек, которые работают, чем новые, для которых вообще непонятно как писать код.
// Tensorflow lite dependencies implementation 'org.tensorflow:tensorflow-lite:2.9.0' implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0' implementation 'org.tensorflow:tensorflow-lite-support:0.4.2' implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.9.0'
Напомню, что текущая версия TensorFlow 2.17. Наша модель сделана на Python для версии 2.8.
Теперь по архитектуре. Создадим класс-singleton TFLiteHelper, в котором главное действующее лицо — это «interpreter» — экземпляр работающей модели TensorFlow Lite. Он может быть для моего приложения в единственном экземпляре и будет выполнять разные функции — предсказывать или учиться.
private TFLiteHelper(Context context, String courseName) { this.context = context; this.courseName = courseName; interpreter = new Interpreter(loadModelfile(context)); } public static TFLiteHelper getInstance(Context context, String courseName) { if (instance == null) instance = new TFLiteHelper(context, courseName); return instance; }
Передаем в конструктор context и название курса и запоминаем их для дальшейшего использования. Создаем новый интерпретатор, загружая модель из ассетов:
private MappedByteBuffer loadModelfile(Context context) { try { AssetFileDescriptor fileDescriptor = context.getAssets().openFd("ruLearnModel_noWeights.tflite"); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } catch (IOException e) { //добавим обработку исключений позже //в какой ситуации может не найтись модель, которая была в ассетах во время компиляции? //это был риторический вопрос } }
Интерпретатор загружен. Что делать дальше? Можно запустить предсказание, для этого нужны 1) веса модели (они есть, если мы сохранили модель с весами) 2) нормализация (мы можем взять пример уже нормализованных значений из предыдущей статьи. И тогда этот код будет работать:
private float[] doInference(TensorBuffer tb) { int tbSize = tb.getShape()[0]; Map<String, Object> inputs = new HashMap<>(); inputs.put("x", tb.getFloatArray()); Map<String, Object> outputs = new HashMap<>(); FloatBuffer output = FloatBuffer.allocate(tbSize); outputs.put("output_0", output); long millis = System.currentTimeMillis(); interpreter.runSignature(inputs, outputs, "infer"); long xxx = System.currentTimeMillis() - millis; System.out.println("model ran for: " + xxx + " milliseconds"); FloatBuffer buffer = (FloatBuffer) outputs.get("output_0"); float[] outPutArray = buffer.array(); return outPutArray; }
На входе нужен TensorBuffer — внутренний объект для TF Lite. Создается он так:
TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[]{numRowCount, 4}, DataType.FLOAT32); float[] featureArray = new float[]{ -1.2826425, 0.7050209, -0.6746891, 1.5294912 }; tensorBuffer.loadArray(featureArray);
где new int[]{numRowCount, 4} — это форма буфера. То есть он будет из произвольного количества рядов с 4 колонками. DataType.FLOAT32 — это чуть ли не единственный тип данных, нормально поддерживаемый TF Lite. В буфер надо загрузить одномерный массив, и tensorBuffer сам переупакует его в соответствии с заданной формой.
Для выходных значений нужно тоже подготовить буфер. В нашем случае для каждого ряда параметров на выходе будет одно значение float. Поэтому буфер должен быть равен <число рядов * размер float32>. Число рядов равно tb.getShape()[0] — первому элементу формы буфера. Дальше идет загадочная история с упаковкой буферов в Map. Это, по всей видимости, нужно для передачи именованных параметров, как было в Python (x=test_data):
y_lite = infer(x=test_data) y_original = m.infer(x=test_data)
То есть в нашем случае мы записываем имя параметра в ключ Map, а данные — в значение.
Засекаем системное время, запускаем интерпретатор с функцией infer и смотрим, сколько ушло на предсказание. В статье, где мы первый раз запускали модель, на предсказание всей таблицы в цикле по одному ряду ушло 167 миллисекунд. Здесь у нас есть возможность сделать это за один подход и результат равен 11 миллисекунд на том же устройстве.
Обучение на устройстве и вспомогательные функции
Для обучения лучше взять пустую модель, без весов. На самом деле, данных у нас не прибавилось, поэтому учить натренированную модель бессмысленно. Но вначале нужно заняться нормализацией, то есть пересчитать данные по формуле:
Для этого можно использовать следующий код:
private TensorBuffer doNormalize(int numRowCount, float[] mlArray) throws IOException { TensorBuffer notNormalizedTensorBuffer = TensorBuffer.createFixedSize(new int[]{numRowCount, 4}, DataType.FLOAT32); TensorBuffer normalizedTensorBuffer; //объявлять форму не надо, буфер вернет TensorProcessor notNormalizedTensorBuffer.loadArray(mlArray); float[] mean = new float[4]; float[] stddev = new float[4]; getMeansAndStddev(mean, stddev); //внутри метода файловая операция TensorProcessor processor = new TensorProcessor.Builder().add(new NormalizeOp(mean, stddev)).build(); normalizedTensorBuffer = processor.process(notNormalizedTensorBuffer); return normalizedTensorBuffer; }
Метод getMeansAndStddev(mean, stddev) возвращает через параметры средние значения и стандартные отклонения для датасета. У нас 4 параметра в модели, поэтому 4 средних по каждой колонке значений и 4 отклонения. Можно взять его из предыдущей статьи или посчитать самостоятельно самому или применить библиотечные функции из Apache Commons Statistics. В моем случае я читаю эти значения из файла, поэтому throws IOException.
Давайте посмотрим, как выглядит код тренировки модели:
public void doTrain(ArrayList<MLData> input, int num_epoch) throws IOException{ float[] dataArray = mlDataToArrays(input).getData(); float[] resultArray = mlDataToArrays(input).getResults(); calculateMeansAndStandardDeviation(input); //файловые операции TensorBuffer tb_in = doNormalize(input.size(), dataArray); float[][] dataArray2dim = new float[][]{tb_in.getFloatArray()}; float[][] resultArray2dim = new float[][]{resultArray}; long millis = System.currentTimeMillis(); for (int i = 0; i < num_epoch; i++) { Map<String, Object> inputs = new HashMap<>(); inputs.put("x", dataArray2dim); inputs.put("y", resultArray2dim); Map<String, Object> outputs = new HashMap<>(); FloatBuffer loss = FloatBuffer.allocate(1); outputs.put("loss", loss); interpreter.runSignature(inputs, outputs, "train"); } System.out.println("training took " + (System.currentTimeMillis() - millis) + " ms"); saveCheckpoint(); }
Начинается все с преобразования датасета, приходящего в метод из базы SQLite в форме ArrayList<упакованные в объект значения из таблицы с данными, на которых строилась модель>. Для работы TensorFlow нужны буферы или массивы, поэтому я переупаковываю данные соответственно в массивы для значений и результатов, на которых модель будет учиться. Дальше в функции calculateMeansAndStandardDeviation(input) по всему датасету считаю средние значения и стандартное отклонение, записываю эти значения в файл. Дальше почему-то данные и результаты надо переупаковать в двумерные массивы. Почему — непонятно и на самом деле код не работал, пока я не вспомнил, что и в Python было именно так:
test_data = np.array([[-1.2826425, 0.7050209, -0.6746891, 1.5294912], [2.559783, -0.18800573, 0.62913984, -0.8701883 ], [ 2.560703, -0.18800573, 1.672203, -0.87582266], [-1.2808022, 0.7050209, -0.6746891, 0.24507335]],dtype='float32')
Без этого преобразования происходит одна из ошибок TensorFlow с «очень полезными» описаниями типа «TypeError: src data type = 17 is not supported«. В данном случае такая:
Дальше как обычно, упаковка в Map и запуск соответствующей функции interpreter.runSignature(inputs, outputs, «train»). Заодно засекаем время. Что с результатами? (это все под дебагом, в реальности будет намного быстрее)
-
Телефон Xiaomi Mi 1 2017 года, 60 000 Antutu: 2000 мс
-
Телефон Xiaomi 11 Lite 5G 2020 года, 300 000 Antutu: 400 мс
-
Телефон Samsung Galaxy S22 2022 года, 1 000 000 Antutu: 265 мс.
Очевидно, что обучение с нуля такой модели на современном устройстве — не проблема. Можно переучивать модель хоть при каждом старте приложения и пользователь это не заметит. А если использовать Checkpoint’ы, то тем более. Так или иначе, нам нужно сохраниться, поэтому привожу код метода saveCheckpoint():
private void saveCheckpoint() { File outputFile = new File(context.getFilesDir(), courseName + ".ckpt"); Map<String, Object> inputs = new HashMap<>(); inputs.put("checkpoint_path", outputFile.getAbsolutePath()); Map<String, Object> outputs = new HashMap<>(); interpreter.runSignature(inputs, outputs, "save"); }
Метод Inference, который можно сделать публичным
После обучения у нас есть все необходимое, чтобы предоставить метод для предсказания значений, который будет принимать на вход данные из таблицы:
public ArrayList<Float> runInference(ArrayList<MLData> input) throws IOException{ openCheckpoint(); //файловая операция! float[] floatArray = mlDataToArrays(input).getData(); TensorBuffer tb_in = doNormalize(input.size(), floatArray); floatArray = doInference(tb_in); return mlArrayToData(floatArray); } private void openCheckpoint() throws IOException{ File outputFile = new File(context.getFilesDir(), courseName + ".ckpt"); if(outputFile.exists()) { Map<String, Object> inputs = new HashMap<>(); inputs.put("checkpoint_path", outputFile.getAbsolutePath()); Map<String, Object> outputs = new HashMap<>(); interpreter.runSignature(inputs, outputs, "restore"); } else throw new IOException("checkpoint not found"); }
Преобразуем данные из ArrayList в массив, нормализуем и запускаем метод doInference, который мы рассмотрели выше. И все, мы получаем предсказание на базе модели, которую мы обучили на телефоне!
Выводы
Машинное обучение для простых моделей вполне можно запускать на мобильных устройствах. Аппаратное ускорение в этом случае может и не понадобиться, не говоря о том, что даже на GPU мне не удалось запустить интерпретатор ни на одном из имеющихся телефонов. При этом доступ к NPU, даже если он сегодня еще и возможен, будет запрещен для TensorFlow Lite начиная с пятнадцатой версии Android. А ведь TensowFlow Lite — единственный официальный инструмент от Google для машинного обучения на Android!
Сколько миллиардов было потрачено на то, чтобы произвести все эти NPU, к которым просто нет доступа, поразительно. Даже на Orange Pi 3, который у меня дома выполняет функции NAS, есть «built-in AI accelerator NPU with 0.8Tops computing power«.
Несмотря на то, что документация по TensorFlow Lite устаревшая и неполная, все еще можно написать код для своего приложения. Ну а вам будет легче, ведь теперь на Хабре есть туториал 🙂
ссылка на оригинал статьи https://habr.com/ru/articles/839622/
Добавить комментарий