在前一篇文章中,简单的以为可以将计算总数的部分去掉,但是如果去掉了总数的计算,就没法计算每个文章的热度(就是相似文章的数量),这点客户没法接受。
所以,还是得想法子继续优化。
1、回顾前一个版本的算法
在前一个版本的算法里,我们使用numpy实现成矩阵运算的形式,在数据量1.7万的时候,还是比直接使用es查询要慢。
分析这其中的原因,主要是计算量太大了,而es则可以充分利用已有的索引机制来提升效率。
那有没有办法降低算法的时间复杂度呢?
剪枝、分治、贪心、。。。。
典型的分治法,如二分法排序
2、分析数据的特征
一个simhash编码如:
代码语言:javascript复制0111001111110101011100100000111100110110010011011000010101110110
关于simhash的文章网上很多,这里不细说。
对于两个相似的文章,我们有理由假设他们的simhash码会有相当长的一个子串是相同的,这个假设显然是合理的。
于是,我们就可以采用分治法,将一个64位的simhash字符串均分成若干段,例如如果我们将上面的simhash串切成长度相等的4段:
代码语言:javascript复制0111001111110101
0111001000001111
0011011001001101
1000010101110110
根据前面的假设,那么另一篇相似的文章的simhash码也均分成4段,则至少有一段是和上面对应的一段是完全一样的。(对此,你可能有不同疑问,如果有这样一个文章,它的simhash码和这个文章的只是差了4个位,但是这4个位刚好又导致了这4个子串都不相同。确实,这是可能发生的,如果有人转载文章时,这里改一点,那里改一点,总体改动不多,但是改动的地方却是比较均匀,这就可能发生刚才的情形。不过对于我们来说,可以不考虑这种情况,因为这种情况本省发生的概率就比较小,其实为了降低时间延迟,牺牲一点点精度也是可以接受的)
因此切分之后的simhash子串的数量最大可能有2的16次方个(65536),每个子串可能对应若干个文章ID,最后把有交集的文章ID合并到一个类即可。
3、第三个优化版本
上面说起来好像挺简单,实现起来还是有点点复杂的。
3.1 数据切分:
代码语言:javascript复制import time
import json
import numpy as np
from typing import List
# 常量配置
sim_thr = 0.85 * 64 # 85%相似度阈值
# 加载1.7万的文章id及simhash值
with open('./article_simhash_17k.json') as f:
data = json.load(f)
print(len(data))
start = time.time()
new_data = {}
for item in data:
simhash, article_id = item['simhash'], item['id']
simhash = ['0']*(64-len(simhash)) list(simhash)
# 不足64位的前面补0,文章id放到第一个值
item = [article_id] simhash
simhash = ''.join(simhash)
for i in range(4): # 这个方式耗时:28s(会有一定的精度损失)
new_data.setdefault(str(i) simhash[i*16:i*16 16], []).append(item)
# for i in range(8): # 这个方式整体时间:39s
# if i*8 16 <= 64:
# new_data.setdefault(str(i) simhash[i*8:i*8 16], []).append(item)
all_np_data = [np.array(vals) for vals in new_data.values() if len(vals) > 1]
print('time: ', time.time()-start, len(new_data), len(all_np_data))
这个比较简单,耗时约0.46秒。
有一个小技巧,切分生成key时,加上了一个序号ID,这样就能保证只有相同位置的段才会完全相同。
3.2 初步聚类
对每一个切分好的段内部的文章进行聚类:
代码语言:javascript复制def cluster(np_data) -> List[List[str]]:
"""将相似的文章聚类在一起"""
results = []
while np_data.shape[0] > 1:
curr_row = np_data[0]
np_data = np_data[1:]
counts = np.count_nonzero(curr_row == np_data, axis=1)
if len(counts) > 0:
ls = list(np_data[counts > sim_thr][:, 0])
ls.append(curr_row[0])
results.append(ls)
np_data = np_data[counts <= sim_thr]
return results
all_cls = [] # List[List[str]], 文章id的初步聚类结果,每个类可以包含多个文章id
for np_data in all_np_data:
all_cls = cluster(np_data) # TODO 这里是可以并行的,待优化(不过这里耗时不是太多)
results = {item['id']: [] for item in data} # 每个id对应一个类别,空列表表示没有分到具体的类别里
for cls_id, aids in enumerate(all_cls): # 循环处理每个原始类别
for aid in aids: # 把该类的文章id都对应到类别id
results[aid].append(cls_id) # 一个文章可以包含多个初始分类
print('cluster time: ', time.time()-start, len(all_cls))
其中的cluster函数和前一篇文章的思路基本是一致的。
3.3 合并有交集的类别
上面的算法已经聚合了很多的列表,一篇文章最多可能被分到了4个类别上,需要对有交集的类别进行合并:
代码语言:javascript复制# 合并所有有交集的分类id
# TODO 这个步骤最耗时间,超过99%的时间都是消耗在这里
clses = [set(vals) for vals in results.values() if len(vals) > 1]
merge_cls = np.array(range(len(all_cls))) # 每个类别对应的原始id
print('before merge, for len:', len(clses))
for i, cls in enumerate(clses):
for cls_j in clses[i 1:]:
# 集合计算比较慢
# if len(cls.intersection(clses[j])) == 0:
if len(cls.intersection(cls_j)) == 0:
continue
# 有交集则对应的全修改为最小值
# cls = cls.union(cls_j) # time: 39s
cls.update(cls_j) # time: 38s
# 获取最小的类别
merge_cls[list(cls)] = min(cls)
print('merge time', time.time()-start, len(set(merge_cls)))
其实就是循环处理每一个类别,将后面和该类别有交集的类别都合并在一起,并取类别中最小的类别id作为新的类别id。
从打印的数据可以看到,这里循环的长度有1.3万多,两重循环就得2亿左右次,整个程序超过99%的时间都是消耗在了这里。
3.4 计算文章类别及热度
有了前面一步的结果,这个倒是比较容易实现的了:
代码语言:javascript复制# 计算每个分类id对应的文章id列表
cls_results = {}
for aid, _cls in results.items():
if len(_cls) > 0:
cls_id = merge_cls[_cls[0]] # 合并后的id
cls_results.setdefault(cls_id, []).append(aid)
# 生成最后的结果
# TODO 需要验证这里的相似文章的相似度怎么样
cls_results = [vals for vals in cls_results.values()]
print('hot max:', max([len(vals) for vals in cls_results]))
# 计算热度
cls_results = [(vals[0], len(vals)) for vals in cls_results]
# 补全热度为1的数据
for _id, _cls in results.items():
if len(_cls) == 0:
cls_results.append((_id, 1))
print('time: ', time.time() - start, ' len: ', len(results), len(cls_results))
assert len(data) == sum([val[1] for val in cls_results])
最后,在我的笔记本上运行大概耗时27秒,比之前的4分钟还是下降了很多的。
4、总结
通过分治法,牺牲一点精度,换来了时间消耗的减少,这是值得的。不过执行一次还需要27秒,这个还是有点多的,而且随着数据量的增大,这个耗时可能是指数级的,还需要继续优化。
合并类别那里还是有不少优化空间的,待续。。。