图片相似度检索设计

2024-08-26 17:01:02 浏览数 (2)

背景

相似度检索的应用场景颇多,不管是互联网生态下的内容理解还是工业界质量检、人脸对比等,向量相似度检索技术的核心是通过向量表征的感兴趣区域并通过向量距离计算衡量输入样本的相似度。针对图片的相似度检索,主要包含图片裁剪、特征提取、PCA、聚类计算、相似度距离计算6个步骤,通常业界有6类常具有代表性的向量表征算法,他们是Word2vec,Doc2vec,DeepWalk,Graph2Vec,Asm2Vec,Log2Vec。本文基于公司的业务驱动,具体聊聊CV领域图片相似度检索技术的原理和实践案例。

模型——DINOv2

基于图像特征提取目前为止不二选择Meta AI提供的Dinov2模型,该模型基于Vit基础架构,DINOv2 可以抽取到强大的图像特征,且在下游任务上不需要微调,这使得它适合作为许多不同的应用中新的 BackBone。

Dinov2最核心创新点主要是:

1.自监督训练数据集构建

如下图所示Dinov2构建一套自监督训练数据集构建pipeline,包含数据源、去重、自监督图像检索(聚类检索),数据集产出。Dinov2将开源数据集和网上大量的未经标注的数据集经过后处理后(PCA 哈希去重、NSFW 过滤和模糊可识别的人脸)形成数据池,并基于该数据池,提取图像Embedding特征,基于Embedding采用聚类算法将相似向量的图片放在统一簇中,DinoV2根据查询图像的Embedding在聚类产生的簇中检索N张最相似的图像。最后将这些相似的图像和查询图像一起预训练,最终形成1.42亿张图像,命名为 LVD-142M 数据集

自监督检索技术生成数据集LVD-142M

2.自监督训练方式——知识蒸馏

DINOv2 使用了两种目标函数来训练网络

  • 第一种: Image-level 的目标函数,其使用 ViT 的 cls token 的特征,通过比较从同一图像的不同部分得到的学生网络和教师网络的 cls token 输出来计算交叉熵损失
  • 第二种: Patch-level 的目标函数,通过随机屏蔽学生网络输入的一些 patch(不是教师网络),并对每个被屏蔽的 patch 的特征进行交叉熵损失的计算。

这两种目标函数的权重单独调整,以便在不同尺度上获得更好的性能。

Teacher模型和 student模型网络结构相同,但是参数不同;

图片裁剪:监督student模型学习到从局部到全局的响应

  • local views:局部视角,student模型接收所有的crops图;
  • global views: 全局视角,teacher模型接收的只是global views的裁剪图;

数据库选择

业界最成熟的向量数据库要属Faiss了,国内常用的向量数据库参考Milvus。

数据库

成熟度

功能

性能

是否开源

厂商

Faiss

Facebook AI团队研发开源数据库,目前最为成熟的近似近邻搜索库

支持相似度搜索支持聚类支持向量做簇内归一化支持基于聚类、PCA的检索方式 (分布式检索不支持)

支持十亿量级向量存储

Meta

Milvus

国内一款开源的向量相似度搜索引擎,2019年正式开源

支持数据分区分片、数据持久化、增量数据摄取、标量向量混合查询、time travel等功能集成了FAISS、SPTAG等向量搜索库

支持万亿级向量数据建立索引

Zilliz

Example

代码语言:javascript复制
import torch
import torch.nn as nn
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.image as mpimg 
from PIL import Image
from sklearn.decomposition import PCA
import matplotlib

import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
 
def get_features_vec(img_path, vit_model):
    patch_h = 75
    patch_w = 50
    feat_dim = 384 # vits14

    transform = T.Compose([
        T.GaussianBlur(9, sigma=(0.1, 2.0)),
        T.Resize((patch_h * 14, patch_w * 14)),
        T.CenterCrop((patch_h * 14, patch_w * 14)),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    if vit_model == 'dinov2_vits14': 
        dinov2_vits14 = torch.hub.load('', 'dinov2_vits14',source='local').cuda()
    else:
        dinov2_vits14 = torch.hub.load('', 'dinov2_vitg14',source='local').cuda()
     
    features = torch.zeros(4, patch_h * patch_w, feat_dim)
    imgs_tensor = torch.zeros(4, 3, patch_h * 14, patch_w * 14).cuda()

    img = Image.open(img_path).convert('RGB')
    imgs_tensor[0] = transform(img)[:3]
    with torch.no_grad():
        features_dict = dinov2_vits14.forward_features(imgs_tensor)
        features = features_dict['x_norm_patchtokens']
        
    features_means = features.mean(dim=1))
    return features, features_means


def means_cos(image_features1, image_features2):
    cos = nn.CosineSimilarity(dim=0)
    sim = cos(image_features1[0],image_features2[0]).item()
    print('Similarity:', sim)
    sim = (sim 1)/2
    return sim

def main():
    img_path1 = f'luse1.jpg'
    #img_path2 = f'outdoor.jpg'
    img_path2 = f'luse2.jpg'

    features1, features1_means = get_features_vec(img_path1, 'dinov2_vits14')
    features2, features2_means = get_features_vec(img_path2, 'dinov2_vits14')
   
    sim = means_cos(features1_means, features2_means)
    print("sim:", sim)

if __name__=='__main__':
    main()

参考:

CLIP与DINOv2的图像相似度对比-CSDN博客

Similarities:精准相似度计算与语义匹配搜索工具包,多维度实现多种算法,覆盖文本、图像等领域,支持文搜、图搜文、图搜图匹配搜索-腾讯云开发者社区-腾讯云

0 人点赞