问题
在我们的舆情系统里,有一个需求是这样的:
从近期的标注的文章(数量比较稳定,约5万,数据存在MySQL中)里找到跟目标文章集合(数量不稳定,约1万,数据存在MySQL)里最相似的一篇文章,也就是每个目标集合的文章都要找到一个最相似的文章。
每一篇文章在入库前已经计算好simhash码。
现状
最笨的方法当前是当然是两层循环直接计算,但是这时间上显然是不可能的,1万乘以5万,那就是5亿次计算!
当然我们也没那么傻,已经优化成了使用numpy的矩阵运算,性能确实提升了很多,但是事实上客户反馈有时还是很慢,特别是数据比较多的时候。
优化方案
优化方案可以有多个:
方案1:把近期标注的数据直接迁移到ES里
这个很直接,但是对于我们来说有几个问题:
- 阿里云的ES得升级到7的版本(目前使用es6),但是阿里云没有能平滑升级的方式;
- 系统需要做比较大的改动,短时间很难完成;
- 即使迁移到es7,目标集合1万多次查询,时间肯定也很可观。
方案2:使用向量数据库(如Milvus)
这等于引入了一个新的存储,增加了系统的复杂度,保证各个存储之间的数据同步就是大问题。
方案3:使用向量引擎(如Faiss)
Faiss在FB刚开源出来的时候,就知道了,只是一直没有机会去使用,在我们的场景下一开始也没有使用,是因为考虑到要对近期标注的文章建索引,但是这个索引并不是稳定的,经常需要更新,建索引可能会得不偿失。另外,刚开始系统数据量不大,时间耗时问题也没有太明显。
只是最近又收到比较多客户的反馈,说这个等待时间比较久的问题,才重新测试这个Faiss。
测试发现,这个库是可以解决我们的问题的,大概是因为我们的目标集合也是有万级的数量的,平摊建索引的时间还是划算的。
Faiss的使用
安装:
代码语言:javascript复制# 安装依赖
apt install libopenblas-dev -y
apt install libomp-dev -y
# 安装Faiss
pip install faiss
生成模拟数据:
代码语言:javascript复制from numpy import random
data = random.randint(2, size=(50000, 64))
print(data.shape)
构建索引:
代码语言:javascript复制def build_index(data, index_type="IndexFlatL2", quantizer_type="IndexFlatL2", nlist=1, m=1, nbits=8, dist_metric=0, *args, **kwargs):
d = len(data[0])
index = None
if index_type == "IndexFlatL2":
index = faiss.IndexFlatL2(d)
elif index_type == "IndexFlatIP":
index = faiss.IndexFlatIP(d)
elif index_type == "IndexLSH":
index = faiss.IndexLSH(d, nbits)
elif index_type == "IndexPQ":
index = faiss.IndexPQ(d, m, nbits)
if index is not None:
index.add(data)
return index
quantizer = None
if quantizer_type == "IndexFlatL2":
quantizer = faiss.IndexFlatL2(d)
elif quantizer_type == "IndexFlatIP":
quantizer = faiss.IndexFlatIP(d)
elif quantizer_type == "IndexLSH":
quantizer = faiss.IndexLSH(d, nbits)
elif quantizer_type == "IndexPQ":
quantizer = faiss.IndexPQ(d, m, nbits)
metric = faiss.METRIC_INNER_PRODUCT
if dist_metric == 1:
metric = faiss.METRIC_L2
if index_type == "IndexIVFFlat":
index = faiss.IndexIVFFlat(quantizer, d, nlist, metric)
elif index_type == "IndexIVFPQ":
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)
index.train(data)
index.add(data)
return index
start = time.time()
index = build_index(data.astype(np.float32))
print("Time:", time.time()-start)
在测试机上,这个耗时约为15毫秒(当然这是因为我们选择的是索引IndexFlatL2),可以接受。
模拟目标集合进行测试:
代码语言:javascript复制# 模拟一个批次,10000条数据
aid = random.randint(2, size=(10000, 64))
print(aid.shape)
# 查询相似
index.nprobe = 1
start = time.time()
_, match = index.search(aid.astype(np.float32), 1)
print("time:", time.time()-start)
# 观察效果
res, max_val = [], 0
for i in range(aid.shape[0]):
t = aid[i] == data[match[i][0]]
n = np.count_nonzero(t)
res.append(n)
if n > max_val:
max_val = n
print(np.average(res), max_val)
这里预测1万个向量,耗时才1秒左右,这个时间也是可以接受的,已经是大大低于原来的耗时了。
基于这个测试结果,可以估计,原来分钟级的操作会变成秒级的操作!
Faiss参考资料
1. Faiss入门及经验记录 https://zhuanlan.zhihu.com/p/357414033