Рекомендательная система через поиск схожих изображний с помощью Resnet50

Коротко про рекомендательные системы

Глобально существует два подхода в создании рекомендательных систем. Контентно-ориентированная и коллаборативная фильтрация. Основополагающее предположение подхода коллаборативной фильтрации заключается в том, что если А и В покупают аналогичные продукты, А, скорее всего, купит продукт, который купил В, чем продукт, который купил случайный человек. В отличие от контентно-ориентированного подхода, здесь нет признаков, соответствующих пользователям или предметам. Рекомендательная система базируется на матрице взаимодействий пользователей. Контентно-ориентированная система базируется на знаниях о предметах. Например если пользователь смотрит шелковые футболки возможно ему будет интересно посмотреть на другие шелковые футболки.

В этой статье я хочу рассказать о подходе который основан на поиске схожих изображений. Зачем подготавливать дополнительнительные данные если почти все основные характеристики некоторых товаров, например одежда, можно отобразить на изображении.

исключение softmax
исключение softmax

Суть подхода заключается в извлчении признаков из изображений товаров. С помощью сверточной сети, в своем примере я использовал Resnet50, так как вектор признаков resnet имеет относительно небольшую размерность. Извлечь вектор признаков с помощью обученой сети очень просто. Нужно просто исключить softmax классификатор именно он определяет к какому классу относится изображение и мы получим на выходе вектор признаков. Далее необходимо сравнивать векторы и искать похожие. Чем более схожи изображения тем меньше евклидово расстояние между векторами.

Код и датасет

Датасет можно скачать отсюда https://www.kaggle.com/datasets/paramaggarwal/fashion-product-images-small

Инициализации обученой restnet50 из библиотеки pytorch и извлечении признаков из датасета

from torchvision.io import read_image from torchvision.models import resnet50, ResNet50_Weights import torch import glob import pickle from tqdm import tqdm from PIL import Image  def pil_loader(path):     # Некоторые изображения из датасета представленны не в RGB формате, необходимо их конверитровать в RGB     with open(path, 'rb') as f:         img = Image.open(f)         return img.convert('RGB')   # Инициализация модели обученой на датасете imagenet weights = ResNet50_Weights.DEFAULT model = resnet50(weights=weights) model.eval() preprocess = weights.transforms()  use_precomputed_embeddings = True emb_filename = 'fashion_images_embs.pickle' if use_precomputed_embeddings:      with open(emb_filename, 'rb') as fIn:         img_names, img_emb_tensors = pickle.load(fIn)       print("Images:", len(img_names)) else:     img_names  = list(glob.glob('images/*.jpg'))     img_emb = []     # извлечение признаков из изображений в датасете. У меня на CPU заняло около часа     for image in tqdm(img_names):         img_emb.append(             model(preprocess(pil_loader(image)).unsqueeze(0)).squeeze(0).detach().numpy()         )     img_emb_tensors = torch.tensor(img_emb)          with open(emb_filename, 'wb') as handle:         pickle.dump([img_names, img_emb_tensors], handle, protocol=pickle.HIGHEST_PROTOCOL)

Функция которая создает поисковый индекс с помощью faiss и уменьшает размерность векторов признаков

# Для сравнения векторов используется faiss import faiss                    from sklearn.decomposition import PCA  def build_compressed_index(n_features):     pca = PCA(n_components=n_features)     pca.fit(img_emb_tensors)     compressed_features = pca.transform(img_emb_tensors)     dataset = np.float32(compressed_features)     d = dataset.shape[1]     nb = dataset.shape[0]     xb = dataset      index_compressed = faiss.IndexFlatL2(d)     index_compressed.add(xb)     return [pca, index_compressed]  

Хэлперы для отображения результатов

import matplotlib.pyplot as plt import matplotlib.image as mpimg  def main_image(img_path, desc):     plt.imshow(mpimg.imread(img_path))     plt.xlabel(img_path.split('.')[0] + '_Original Image',fontsize=12)     plt.title(desc,fontsize=20)     plt.show()  def similar_images(indices, suptitle):     plt.figure(figsize=(15,10), facecolor='white')     plotnumber = 1         for index in indices[0:4]:         if plotnumber<=len(indices) :             ax = plt.subplot(2,2,plotnumber)             plt.imshow(mpimg.imread(img_names[index]))             plt.xlabel(img_names[index],fontsize=12)             plotnumber+=1     plt.suptitle(suptitle,fontsize=15)     plt.tight_layout()

Сама функция поиска. Принимает на вход количество признаков, что бы можно было поэксперементировать с достаточным количеством признаков

import numpy as np # поиск, можно искать по индексу из предварительно извлеченных изображений или передать новое изображение def search(query, factors):     if(type(query) == str):         img_path = query     else:         img_path = img_names[query]     one_img_emb = torch.tensor(model(preprocess(read_image(img_path)).unsqueeze(0)).squeeze(0).detach().numpy())     main_image(img_path, 'Query')     compressor, index_compressed = build_compressed_index(factors)     D, I = index_compressed.search(np.float32(compressor.transform([one_img_emb.detach().numpy()])),5)     similar_images(I[0][1:], "faiss compressed " + str(factors))

Виновник торжества. Вызов поиска

search(100,300) search("t-shirt.jpg", 500)

Выводы

В итоге за пару часов можно собрать довольно качественную рекомендательную систему основаную на схожести изображений, чего достаточно для некоторых случаев. Изображения не требуют предварительной подготовки, разметки и какой то метаинформации что значительно упрощает процесс.

Для повышения качества рекомендаций можно дообучить некторые слои сети на используемом датасете.


ссылка на оригинал статьи https://habr.com/ru/post/675052/

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *