1. 序列数据的处理
Item2vec 是基于自然语言处理模型 Word2vec 提出的,所以 Item2vec 要处理的是类似文本句子的观影序列:
代码语言:javascript复制def processItemSequence(spark, rawSampleDataPath):
# 读取 ratings 原始数据到 Spark 平台
ratingSamples = spark.read.format("csv").option("header", "true").load(rawSampleDataPath)
# sortUdf,用它实现每个用户的评分记录按照时间戳进行排序
sortUdf = udf(UdfFunction.sortF, ArrayType(StringType()))
# 用 where 语句过滤评分低的评分记录
# 用 groupBy userId 操作聚合每个用户的评分记录,DataFrame 中每条记录是一个用户的评分序列
# sortUdf
# 把每个用户的评分记录处理成一个字符串的形式,供后续训练过程使用。
userSeq = ratingSamples
.where(F.col("rating") >= 3.5)
.groupBy("userId")
.agg(sortUdf(F.collect_list("movieId"), F.collect_list("timestamp")).alias('movieIds'))
.withColumn("movieIdStr", array_join(F.col("movieIds"), " "))
userSeq.select("userId", "movieIdStr").show(5, truncate = False)
return userSeq.select('movieIdStr').rdd.map(lambda x: x[0].split(' '))
代码语言:javascript复制 ------ -------------------------------------------------------------------------------------------------------
|userId|movieIdStr |
------ -------------------------------------------------------------------------------------------------------
|10096 |858 50 593 457 |
|10351 |1 25 32 6 608 52 58 26 30 103 582 588 |
|10436 |661 107 60 1 919 223 260 899 480 592 593 356 588 344 47 595 736 367 500 34 39 141 586 2 750 104 368 317|
|1090 |356 597 919 986 |
|11078 |232 20 296 593 457 150 1 608 50 47 628 922 527 380 588 377 733 10 539 |
------ -------------------------------------------------------------------------------------------------------
2. 训练item2vec
代码语言:javascript复制def trainItem2vec(spark, samples, embLength, embOutputPath, saveToRedis, redisKeyPrefix):
# 设置模型参数
word2vec = Word2Vec().setVectorSize(embLength).setWindowSize(5).setNumIterations(10)
# 训练模型
model = word2vec.fit(samples)
# 训练结束,用模型查找与item"158"最相似的5个item
synonyms = model.findSynonyms("158", 5)
for synonym, cosineSimilarity in synonyms:
print(synonym, cosineSimilarity)
# 保存模型
embOutputDir = '/'.join(embOutputPath.split('/')[:-1])
if not os.path.exists(embOutputDir):
os.makedirs(embOutputDir)
with open(embOutputPath, 'w') as f:
for movie_id in model.getVectors():
vectors = " ".join([str(emb) for emb in model.getVectors()[movie_id]])
f.write(movie_id ":" vectors "n")
return model
代码语言:javascript复制48 0.9553923010826111
256 0.9461638331413269
31 0.9321570992469788
186 0.9115440845489502
355 0.8810520768165588