0. 前言
背景:需要在pyspark上例行化word2vec,但是加载预训练的词向量是一个大问题,因此需要先上传到HDFS,然后通过代码再获取。调研后发现pyspark虽然有自己的word2vec方法,但是好像无法加载预训练txt词向量。
因此大致的步骤应分为两步:
1.从hdfs获取词向量文件
2.对pyspark dataframe内的数据做分词 向量化的处理
1. 获取词向量文件
开源的词向量文件很多,基本上都是key-value形式的txt文档,以腾讯AI Lab的词向量为例。
(https://ai.tencent.com/ailab/nlp/en/embedding.html)
首先需要将词向量txt文件上传到hdfs里,接着在代码里通过使用sparkfile来实现把文件下发到每一个worker:
代码语言:javascript复制from pyspark.sql import SparkSession
from pyspark import SparkFiles
# 将hdfs的词向量下发到每一个worker
sparkContext = spark.sparkContext
sparkContext.addPyFile("hdfs://******/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt")
# 使用文件的方法:就和本地使用文件时"/***/***"一样
SparkFiles.get("tencent-ailab-embedding-zh-d100-v0.2.0-s.txt")
这一步的耗时主要在词向量下发到每一个worker这一步上。如果词向量文件较大可能耗时较高。
2. 分词 向量化的处理
预训练词向量下发到每一个worker后,下一步就是对数据进行分词和获取词向量,采用udf函数来实现以上操作:
代码语言:javascript复制import pyspark.sql.functions as f
# 定义分词以及向量化的udf
@f.udf(StringType())
def generate_embedding(title, subtitle=None):
cut_title = jieba.lcut(title.lower())
if subtitle is None:
cut_sentence = cut_title
else:
cut_subtitle = jieba.lcut(title.lower())
cut_sentence = cut_title cut_subtitle
res_embed = []
for word in cut_sentence:
# 未登录单词这里选择不处理, 也可以用unk替代
try:
res_embed.append(model.get_vector(word))
except:
pass
# 对词向量做avg_pooling
if len(res_embed)==0:
avg_vectors = np.zeros(100)
else:
res_embed_arr = np.array(res_embed)
avg_vectors = res_embed_arr.mean(axis=(0))
avg_vectors = np.round(avg_vectors,decimals=6)
# 转换成所需要的格式
tmp = []
for j in avg_vectors:
tmp.append(str(j))
output = ','.join(tmp)
return output
这里如果需要使用用户自定义jieba词典的时候就会有一个问题,我怎么在pyspark上实现jieba.load_userdict()
- 如果在pyspark里面直接使用该方法,加载的词典在执行udf的时候并没有真正的产生作用,从而导致无效加载。
- 另外如果在udf里面直接使用该方法,会导致计算每一行dataframe的时候都去加载一次词典,导致重复加载耗时过长。
- 还有一些其他方法,比如将jieba作为参数传入柯里化的udf或者新建一个jieba的Tokenizer实例,作为参数传入udf或者作为全局变量等同样也不行,因为jieba中有线程锁,无法序列化。
因此需要一种方式,在每一个worker上只加载一次。
首先在main方法里将用户自定义词典下发到每一个worker:
代码语言:javascript复制# 将hdfs的词典下发到每一个worker
sparkContext.addPyFile("hdfs://xxxxxxx/word_dict.txt")
接着在udf内首行添加jieba.dt.initialized
判断是否需要加载词典:
if not jieba.dt.initialized:
jieba.load_userdict(SparkFiles.get("word_dict.txt"))
至此完美解决这个问题~
参考:
https://github.com/fxsjy/jieba/issues/387