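The previous article covered how to do feature engineering with PySpark.
This article looks at how to use PySpark to do feature engineering for a recommendation model. As before, we use the MovieLens dataset. We need to generate the sample labels, the movie features, and the user features, and finally split the samples into training and test sets.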
1. main
from pyspark import SparkConf
from pyspark.sql import SparkSession

if __name__ == '__main__':
    conf = SparkConf().setAppName('featureEngineering').setMaster('local')
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    # Placeholder: point this at your own resource directory.
    file_path = 'file://<resource directory>/'
    movieResourcesPath = file_path + "/webroot/sampledata/movies.csv"
    ratingsResourcesPath = file_path + "/webroot/sampledata/ratings.csv"
    # No schema inference is requested, so every CSV column is loaded as a string.
    movieSamples = spark.read.format('csv').option('header', 'true').load(movieResourcesPath)
    ratingSamples = spark.read.format('csv').option('header', 'true').load(ratingsResourcesPath)
    ratingSamplesWithLabel = addSampleLabel(ratingSamples)
    ratingSamplesWithLabel.show(5, truncate=False)
    samplesWithMovieFeatures = addMovieFeatures(movieSamples, ratingSamplesWithLabel)
    samplesWithUserFeatures = addUserFeatures(samplesWithMovieFeatures)
    # save samples as csv format
    splitAndSaveTrainingTestSamples(samplesWithUserFeatures, file_path + "/webroot/sampledata")
    # splitAndSaveTrainingTestSamplesByTimeStamp(samplesWithUserFeatures, file_path + "/webroot/sampledata")
2. addSampleLabel
We start with a quick statistical look at the historical rating data:
# F is pyspark.sql.functions; sampleCount is the total number of ratings
sampleCount = ratingSamples.count()
ratingSamples.groupBy('rating').count().orderBy('rating') \
    .withColumn('percentage', F.col('count') / sampleCount).show()
Looking at the rating distribution, the largest shares fall at 3.0 and 4.0:
+------+------+--------------------+
|rating| count|          percentage|
+------+------+--------------------+
|   0.5|  9788|0.008375561978987506|
|   1.0| 45018| 0.03852176636392108|
|   1.5| 11794|0.010092090108314123|
|   2.0| 87084| 0.07451751526135553|
|   2.5| 34269|0.029323879593167432|
|   3.0|323616| 0.27691723185451783|
|   3.5| 74376| 0.06364331811904114|
|   4.0|324804|  0.2779337998593234|
|   4.5| 53388| 0.04568395003414231|
|   5.0|204501| 0.17499088682722966|
+------+------+--------------------+
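We treat a rating of 3.5 or higher as the user liking the movie, which makes it a positive sample; anything lower is a negative sample: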
import pyspark.sql.functions as F
from pyspark.sql.functions import when

def addSampleLabel(ratingSamples):
    ratingSamples.show(5, truncate=False)
    ratingSamples.printSchema()
    sampleCount = ratingSamples.count()
    # rating >= 3.5 -> positive sample (1), otherwise negative sample (0)
    ratingSamples = ratingSamples.withColumn('label', when(F.col('rating') >= 3.5, 1).otherwise(0))
    return ratingSamples
+------+-------+------+----------+-----+
|userId|movieId|rating|timestamp |label|
+------+-------+------+----------+-----+
|1     |2      |3.5   |1112486027|1    |
|1     |29     |3.5   |1112484676|1    |
|1     |32     |3.5   |1112484819|1    |
|1     |47     |3.5   |1112484727|1    |
|1     |50     |3.5   |1112484580|1    |
+------+-------+------+----------+-----+
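As a quick sanity check on the 3.5 threshold, the counts in the distribution table above can be used to estimate how balanced the labels are. The sketch below is plain Python, not part of the original pipeline; the helper name `positive_rate` is made up for illustration, and the counts are copied from the table.

```python
# Rating counts copied from the distribution table above.
rating_counts = {
    0.5: 9788, 1.0: 45018, 1.5: 11794, 2.0: 87084, 2.5: 34269,
    3.0: 323616, 3.5: 74376, 4.0: 324804, 4.5: 53388, 5.0: 204501,
}
POSITIVE_THRESHOLD = 3.5  # same cut-off as addSampleLabel


def positive_rate(counts, threshold):
    """Return (positives, total, positives / total) for a rating histogram."""
    total = sum(counts.values())
    positives = sum(n for rating, n in counts.items() if rating >= threshold)
    return positives, total, positives / total


positives, total, rate = positive_rate(rating_counts, POSITIVE_THRESHOLD)
print(f"{positives} of {total} ratings are positive ({rate:.1%})")
```

With this threshold a little over half of the samples are positive, so the two classes are reasonably balanced.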
3. addMovieFeatures
def addMovieFeatures(movieSamples, ratingSamplesWithLabel):
    # join the basic movie attributes onto each rating sample
    samplesWithMovies1 = ratingSamplesWithLabel.join(movieSamples, on=['movieId'], how='left')
3.1 Release year feature
    # add releaseYear; the cleaned title is computed but dropped again right away
    samplesWithMovies2 = samplesWithMovies1.withColumn('releaseYear',
                                                       udf(extractReleaseYearUdf, IntegerType())('title')) \
        .withColumn('title', udf(lambda x: x.strip()[:-6].strip(), StringType())('title')) \
        .drop('title')
Here extractReleaseYearUdf cuts the year out of the title:
# a title looks like "Toy Story (1995)"
def extractReleaseYearUdf(title):
    # fall back to 1990 when the title is missing or too short to hold a year
    if not title or len(title.strip()) < 6:
        return 1990
    else:
        yearStr = title.strip()[-5:-1]
        return int(yearStr)
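One caveat with slicing by position: a title that is long enough but does not end in a parenthesised year makes `int(yearStr)` raise a `ValueError`. A more defensive variant could match the year with a regular expression instead. This is only a sketch; `extract_release_year_safe` and `DEFAULT_YEAR` are hypothetical names, and the 1990 fallback simply mirrors the function above.

```python
import re

DEFAULT_YEAR = 1990  # same fallback value as extractReleaseYearUdf
# A four-digit year in parentheses at the very end of the title.
_YEAR_SUFFIX = re.compile(r"\((\d{4})\)\s*$")


def extract_release_year_safe(title):
    """Return the trailing '(YYYY)' year of a title, or DEFAULT_YEAR if absent."""
    if not title:
        return DEFAULT_YEAR
    match = _YEAR_SUFFIX.search(title)
    return int(match.group(1)) if match else DEFAULT_YEAR


print(extract_release_year_safe("Toy Story (1995)"))
print(extract_release_year_safe("Untitled"))
```

It can be wrapped with `udf(extract_release_year_safe, IntegerType())` in the same way as the original function.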
3.2 Movie genre features
Take the first three genre labels as features:
    # split() takes a regular expression, so the pipe has to be escaped
    samplesWithMovies3 = samplesWithMovies2.withColumn('movieGenre1', split(F.col('genres'), "\\|")[0]) \
        .withColumn('movieGenre2', split(F.col('genres'), "\\|")[1]) \
        .withColumn('movieGenre3', split(F.col('genres'), "\\|")[2])
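To make the behaviour of the three columns concrete, here is a plain-Python imitation of what happens to a single `genres` string. It is a sketch rather than the Spark code path: `first_n_genres` is a hypothetical helper, and it assumes that indexing past the end of the split array yields null, which is what the null genre columns in the sample outputs of this article suggest.

```python
def first_n_genres(genres, n=3):
    """Return the first n genre labels of an 'A|B|C' string, padded with None."""
    # Python's str.split takes a literal separator, so no escaping is needed here.
    parts = genres.split("|") if genres else []
    return [parts[i] if i < len(parts) else None for i in range(n)]


print(first_n_genres("Comedy|Crime|Drama|Thriller"))
print(first_n_genres("Comedy"))
```

Movies with more than three genres lose the extra labels, and movies with fewer get empty slots that the downstream model has to handle.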
3.3 Rating features
Compute each movie's rating count, average rating, and rating standard deviation:
    # NUMBER_PRECISION is a module-level constant that is not shown in this article
    movieRatingFeatures = samplesWithMovies3.groupBy('movieId').agg(
        F.count(F.lit(1)).alias('movieRatingCount'),
        format_number(F.avg(F.col('rating')), NUMBER_PRECISION).alias('movieAvgRating'),
        F.stddev(F.col('rating')).alias('movieRatingStddev')) \
        .fillna(0) \
        .withColumn('movieRatingStddev', format_number(F.col('movieRatingStddev'), NUMBER_PRECISION))
    # attach the per-movie statistics to every sample of that movie
    samplesWithMovies4 = samplesWithMovies3.join(movieRatingFeatures, on=['movieId'], how='left')
    samplesWithMovies4.printSchema()
    samplesWithMovies4.show(5, truncate=False)
    return samplesWithMovies4
+-------+------+------+----------+-----+---------------------------+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+
|movieId|userId|rating|timestamp |label|genres                     |releaseYear|movieGenre1|movieGenre2|movieGenre3|movieRatingCount|movieAvgRating|movieRatingStddev|
+-------+------+------+----------+-----+---------------------------+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+
|296    |1     |4.0   |1112484767|1    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
|296    |8     |5.0   |833973081 |1    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
|296    |11    |3.5   |1230858799|1    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
|296    |13    |5.0   |849082366 |1    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
|296    |15    |3.0   |840206642 |0    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
+-------+------+------+----------+-----+---------------------------+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+
only showing top 5 rows
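The per-movie aggregation can be mirrored in plain Python to see what each column holds. The sketch below is an illustration, not the article's code: `movie_rating_features` is a hypothetical helper, and `NUMBER_PRECISION = 2` is an assumption based on the two decimals visible in the output above. Two details are worth noting: `F.stddev` is the sample standard deviation, which is undefined for a movie with a single rating (hence the `fillna(0)`), and `format_number` returns a formatted string rather than a number.

```python
from collections import defaultdict
from statistics import mean, stdev

NUMBER_PRECISION = 2  # assumed; matches the two decimals in the sample output


def movie_rating_features(ratings):
    """ratings: iterable of (movieId, rating) pairs -> per-movie feature dict."""
    by_movie = defaultdict(list)
    for movie_id, rating in ratings:
        by_movie[movie_id].append(float(rating))

    features = {}
    for movie_id, values in by_movie.items():
        # The sample standard deviation needs at least two ratings;
        # use 0 otherwise, mirroring the fillna(0) step.
        spread = stdev(values) if len(values) > 1 else 0.0
        features[movie_id] = {
            "movieRatingCount": len(values),
            # Strings, like the output of format_number.
            "movieAvgRating": f"{mean(values):,.{NUMBER_PRECISION}f}",
            "movieRatingStddev": f"{spread:,.{NUMBER_PRECISION}f}",
        }
    return features


demo = movie_rating_features([(296, "4.0"), (296, "5.0"), (296, "3.0"), (1, "3.5")])
print(demo[296])
print(demo[1])
```

Because the formatted columns are strings, a model that needs numeric inputs has to parse them back when reading the saved CSV.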
4. addUserFeatures
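On the user side, we mainly derive features from each user's viewing behaviour over their most recent 100 historical records, such as the movies they rated positively most recently, the number of movies rated, the release years of the rated movies, their past ratings, and the genres they watched most recently: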
def addUserFeatures(samplesWithMovieFeatures):
    extractGenresUdf = udf(extractGenres, ArrayType(StringType()))
    # Window over the same user's previous 100 ratings, excluding the current row.
    # Note: the main block loads 'timestamp' from CSV as a string, so this ordering
    # is lexicographic unless the column is cast to a numeric type first.
    userWindow = sql.Window.partitionBy('userId').orderBy('timestamp').rowsBetween(-100, -1)
    samplesWithUserFeatures = samplesWithMovieFeatures \
        .withColumn('userPositiveHistory',
                    F.collect_list(when(F.col('label') == 1, F.col('movieId')).otherwise(F.lit(None))).over(userWindow)) \
        .withColumn('userPositiveHistory', reverse(F.col('userPositiveHistory'))) \
        .withColumn('userRatedMovie1', F.col('userPositiveHistory')[0]) \
        .withColumn('userRatedMovie2', F.col('userPositiveHistory')[1]) \
        .withColumn('userRatedMovie3', F.col('userPositiveHistory')[2]) \
        .withColumn('userRatedMovie4', F.col('userPositiveHistory')[3]) \
        .withColumn('userRatedMovie5', F.col('userPositiveHistory')[4]) \
        .withColumn('userRatingCount', F.count(F.lit(1)).over(userWindow)) \
        .withColumn('userAvgReleaseYear', F.avg(F.col('releaseYear')).over(userWindow).cast(IntegerType())) \
        .withColumn('userReleaseYearStddev',
                    format_number(F.stddev(F.col('releaseYear')).over(userWindow), NUMBER_PRECISION)) \
        .withColumn('userAvgRating', format_number(F.avg(F.col('rating')).over(userWindow), NUMBER_PRECISION)) \
        .withColumn('userRatingStddev', format_number(F.stddev(F.col('rating')).over(userWindow), NUMBER_PRECISION)) \
        .withColumn('userGenres', extractGenresUdf(
            F.collect_list(when(F.col('label') == 1, F.col('genres')).otherwise(F.lit(None))).over(userWindow))) \
        .withColumn('userGenre1', F.col('userGenres')[0]) \
        .withColumn('userGenre2', F.col('userGenres')[1]) \
        .withColumn('userGenre3', F.col('userGenres')[2]) \
        .withColumn('userGenre4', F.col('userGenres')[3]) \
        .withColumn('userGenre5', F.col('userGenres')[4]) \
        .drop('genres', 'userGenres', 'userPositiveHistory') \
        .filter(F.col('userRatingCount') > 1)
    samplesWithUserFeatures.printSchema()
    samplesWithUserFeatures.show(5)
    samplesWithUserFeatures.filter(F.col('userId') == 1).orderBy(F.col('timestamp').asc()).show(
        truncate=False)
    return samplesWithUserFeatures
+-------+------+------+---------+-----+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+---------------+---------------+---------------+---------------+---------------+---------------+------------------+---------------------+-------------+----------------+----------+----------+----------+----------+----------+
|movieId|userId|rating|timestamp|label|releaseYear|movieGenre1|movieGenre2|movieGenre3|movieRatingCount|movieAvgRating|movieRatingStddev|userRatedMovie1|userRatedMovie2|userRatedMovie3|userRatedMovie4|userRatedMovie5|userRatingCount|userAvgReleaseYear|userReleaseYearStddev|userAvgRating|userRatingStddev|userGenre1|userGenre2|userGenre3|userGenre4|userGenre5|
+-------+------+------+---------+-----+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+---------------+---------------+---------------+---------------+---------------+---------------+------------------+---------------------+-------------+----------------+----------+----------+----------+----------+----------+
|    514| 10096|   3.0|954365410|    0|       1994|     Comedy|       null|       null|            1038|          3.50|             0.86|            858|           null|           null|           null|           null|              2|              1982|                14.85|         3.50|            0.71|     Crime|     Drama|      null|      null|      null|
|    608| 10096|   3.0|954365515|    0|       1996|     Comedy|      Crime|      Drama|            9505|          4.09|             0.93|            858|           null|           null|           null|           null|              3|              1986|                12.42|         3.33|            0.58|     Crime|     Drama|      null|      null|      null|
|     50| 10096|   5.0|954365515|    1|       1995|      Crime|    Mystery|   Thriller|           10221|          4.35|             0.75|            858|           null|           null|           null|           null|              4|              1988|                11.24|         3.25|            0.50|     Crime|     Drama|      null|      null|      null|
|    593| 10096|   4.0|954365552|    1|       1991|      Crime|     Horror|   Thriller|           13692|          4.18|             0.85|             50|            858|           null|           null|           null|              5|              1990|                10.12|         3.60|            0.89|     Crime|     Drama|   Mystery|  Thriller|      null|
|     25| 10096|   2.0|954365571|    0|       1995|      Drama|    Romance|       null|            4684|          3.69|             1.04|            593|             50|            858|           null|           null|              6|              1990|                 9.06|         3.67|            0.82|     Crime|  Thriller|     Drama|   Mystery|    Horror|
+-------+------+------+---------+-----+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+---------------+---------------+---------------+---------------+---------------+---------------+------------------+---------------------+-------------+----------------+----------+----------+----------+----------+----------+
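The helper `extractGenres` used by the UDF is not shown in this article. One plausible reading, consistent with the `userGenre` columns in the output above, is that it ranks genres by how often they occur in the user's positively rated history. The sketch below shows that idea together with a plain-Python imitation of the "previous 100 rows, newest first" window. Both function names are hypothetical, the row list is a simplified version of user 10096's history, and the genres of movie 858 are inferred from the first output row.

```python
from collections import defaultdict


def extract_genres(genres_list):
    """Rank genres by frequency over a list of 'A|B|C' strings."""
    counts = defaultdict(int)
    for genres in genres_list:
        for genre in genres.split("|"):
            counts[genre] += 1
    # sorted() is stable, so genres with equal counts keep first-seen order.
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    return [genre for genre, _ in ranked]


def positive_history(rows, index, window=100):
    """movieIds liked within the `window` rows before rows[index], newest first."""
    previous = rows[max(0, index - window):index]
    liked = [row["movieId"] for row in previous if row["label"] == 1]
    return liked[::-1]


# Simplified history of one user, already ordered by time.
rows = [
    {"movieId": 858, "label": 1, "genres": "Crime|Drama"},  # genres inferred
    {"movieId": 514, "label": 0, "genres": "Comedy"},
    {"movieId": 608, "label": 0, "genres": "Comedy|Crime|Drama"},
    {"movieId": 50, "label": 1, "genres": "Crime|Mystery|Thriller"},
    {"movieId": 593, "label": 1, "genres": "Crime|Horror|Thriller"},
    {"movieId": 25, "label": 0, "genres": "Drama|Romance"},
]

current = 5  # the row for movie 25
liked_genres = [row["genres"] for row in rows[:current] if row["label"] == 1]
print(positive_history(rows, current))
print(extract_genres(liked_genres))
```

Because the window ends one row before the current one, a sample never sees its own label, which keeps the target from leaking into the user features.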
5. Split train and test samples
Random split:
def splitAndSaveTrainingTestSamples(samplesWithUserFeatures, file_path):
    # keep a 10% sample, then split it 80/20 at random
    smallSamples = samplesWithUserFeatures.sample(0.1)
    training, test = smallSamples.randomSplit([0.8, 0.2])
    trainingSavePath = file_path + '/trainingSamples'
    testSavePath = file_path + '/testSamples'
    # repartition(1) writes each set as a single CSV part file
    training.repartition(1).write.option("header", "true").mode('overwrite') \
        .csv(trainingSavePath)
    test.repartition(1).write.option("header", "true").mode('overwrite') \
        .csv(testSavePath)
Split by time order:
def splitAndSaveTrainingTestSamplesByTimeStamp(samplesWithUserFeatures, file_path):
    # cast the string timestamp to a long so that the quantile is computed numerically
    smallSamples = samplesWithUserFeatures.sample(0.1).withColumn("timestampLong", F.col("timestamp").cast(LongType()))
    # approximate 80th percentile of the timestamps (relative error 0.05)
    quantile = smallSamples.stat.approxQuantile("timestampLong", [0.8], 0.05)
    splitTimestamp = quantile[0]
    # earlier samples go to training, later ones to test
    training = smallSamples.where(F.col("timestampLong") <= splitTimestamp).drop("timestampLong")
    test = smallSamples.where(F.col("timestampLong") > splitTimestamp).drop("timestampLong")
    trainingSavePath = file_path + '/trainingSamples'
    testSavePath = file_path + '/testSamples'
    training.repartition(1).write.option("header", "true").mode('overwrite') \
        .csv(trainingSavePath)
    test.repartition(1).write.option("header", "true").mode('overwrite') \
        .csv(testSavePath)
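The idea behind the time-based split is that every test sample is later than every training sample, so the model is evaluated on behaviour it could not have seen. A small plain-Python sketch of the same idea follows. `split_by_timestamp` is a hypothetical helper, and unlike `approxQuantile` it picks the cut point exactly rather than approximately.

```python
def split_by_timestamp(samples, train_fraction=0.8):
    """Split dict samples so that roughly train_fraction of them, the earliest, train."""
    if not samples:
        return [], []
    # Timestamps arrive as strings, so compare them numerically.
    ordered = sorted(samples, key=lambda s: int(s["timestamp"]))
    cut_index = max(0, int(len(ordered) * train_fraction) - 1)
    split_timestamp = int(ordered[cut_index]["timestamp"])
    training = [s for s in ordered if int(s["timestamp"]) <= split_timestamp]
    test = [s for s in ordered if int(s["timestamp"]) > split_timestamp]
    return training, test


samples = [{"timestamp": t} for t in
           ["1112484767", "954365410", "833973081", "1230858799", "849082366"]]
training, test = split_by_timestamp(samples)
print(len(training), len(test))
```

Samples that share the cut-off timestamp all land in the training set, so the actual ratio can drift slightly from 80/20, just as it can with the approximate quantile.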