使用Faiss优化两个集合之间相似文章计算的问题

2022-11-07 13:04:03 浏览数 (1)

问题


在我们的舆情系统里,有一个需求是这样的:

从近期的标注的文章(数量比较稳定,约5万,数据存在MySQL中)里找到跟目标文章集合(数量不稳定,约1万,数据存在MySQL)里最相似的一篇文章,也就是每个目标集合的文章都要找到一个最相似的文章。

每一篇文章在入库前已经计算好simhash码。

现状


最笨的方法当前是当然是两层循环直接计算,但是这时间上显然是不可能的,1万乘以5万,那就是5亿次计算!

当然我们也没那么傻,已经优化成了使用numpy的矩阵运算,性能确实提升了很多,但是事实上客户反馈有时还是很慢,特别是数据比较多的时候。

优化方案


优化方案可以有多个:

方案1:把近期标注的数据直接迁移到ES里

这个很直接,但是对于我们来说有几个问题:

  1. 阿里云的ES得升级到7的版本(目前使用es6),但是阿里云没有能平滑升级的方式;
  2. 系统需要做比较大的改动,短时间很难完成;
  3. 即使迁移到es7,目标集合1万多次查询,时间肯定也很可观。

方案2:使用向量数据库(如Milvus)

这等于引入了一个新的存储,增加了系统的复杂度,保证各个存储之间的数据同步就是大问题。

方案3:使用向量引擎(如Faiss)

Faiss在FB刚开源出来的时候,就知道了,只是一直没有机会去使用,在我们的场景下一开始也没有使用,是因为考虑到要对近期标注的文章建索引,但是这个索引并不是稳定的,经常需要更新,建索引可能会得不偿失。另外,刚开始系统数据量不大,时间耗时问题也没有太明显。

只是最近又收到比较多客户的反馈,说这个等待时间比较久的问题,才重新测试这个Faiss。

测试发现,这个库是可以解决我们的问题的,大概是因为我们的目标集合也是有万级的数量的,平摊建索引的时间还是划算的。

Faiss的使用


安装:

代码语言:javascript复制
# 安装依赖
apt install libopenblas-dev -y
apt install libomp-dev -y
# 安装Faiss
pip install faiss

生成模拟数据:

代码语言:javascript复制
from numpy import random

data = random.randint(2, size=(50000, 64))
print(data.shape)

构建索引:

代码语言:javascript复制
def build_index(data, index_type="IndexFlatL2", quantizer_type="IndexFlatL2", nlist=1, m=1, nbits=8, dist_metric=0, *args, **kwargs):
    d = len(data[0])
    index = None
    if index_type == "IndexFlatL2":
        index = faiss.IndexFlatL2(d)
    elif index_type == "IndexFlatIP":
        index = faiss.IndexFlatIP(d)
    elif index_type == "IndexLSH":
        index = faiss.IndexLSH(d, nbits)
    elif index_type == "IndexPQ":
        index = faiss.IndexPQ(d, m, nbits)
    if index is not None:
        index.add(data)
        return index
    
    quantizer = None
    if quantizer_type == "IndexFlatL2":
        quantizer = faiss.IndexFlatL2(d)
    elif quantizer_type == "IndexFlatIP":
        quantizer = faiss.IndexFlatIP(d)
    elif  quantizer_type == "IndexLSH":
        quantizer = faiss.IndexLSH(d, nbits)
    elif quantizer_type == "IndexPQ":
        quantizer = faiss.IndexPQ(d, m, nbits)
    
    metric = faiss.METRIC_INNER_PRODUCT
    if dist_metric == 1:
        metric = faiss.METRIC_L2
    if index_type == "IndexIVFFlat":
        index = faiss.IndexIVFFlat(quantizer, d, nlist, metric)
    elif index_type == "IndexIVFPQ":
        index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)
    index.train(data)
    index.add(data)
    return index

start = time.time()
index = build_index(data.astype(np.float32))
print("Time:", time.time()-start)

在测试机上,这个耗时约为15毫秒(当然这是因为我们选择的是索引IndexFlatL2),可以接受。

模拟目标集合进行测试:

代码语言:javascript复制
# 模拟一个批次,10000条数据
aid = random.randint(2, size=(10000, 64))
print(aid.shape)

# 查询相似
index.nprobe = 1
start = time.time()
_, match = index.search(aid.astype(np.float32), 1)
print("time:", time.time()-start)

# 观察效果
res, max_val = [], 0
for i in range(aid.shape[0]):
    t = aid[i] == data[match[i][0]]
    n = np.count_nonzero(t)
    res.append(n)
    if n > max_val:
        max_val = n
    
print(np.average(res), max_val)

这里预测1万个向量,耗时才1秒左右,这个时间也是可以接受的,已经是大大低于原来的耗时了。

基于这个测试结果,可以估计,原来分钟级的操作会变成秒级的操作!

Faiss参考资料


1. Faiss入门及经验记录 https://zhuanlan.zhihu.com/p/357414033

0 人点赞