Feature engineering for a MovieLens recommendation model with PySpark

2021-03-20 14:17:25

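An earlier article covered how to do feature engineering with PySpark.

In this article we look at how to use PySpark to build features for a recommendation model. We again use the MovieLens dataset and go through four steps: generating sample labels, generating movie features, generating user features, and finally splitting the samples into training and test sets.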

1. main

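# imports used by all the functions in this article
from pyspark import SparkConf
import pyspark.sql as sql
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import udf, when, split, format_number, reverse
from pyspark.sql.types import IntegerType, StringType, ArrayType, LongType

# decimal places kept by format_number; 2 matches the sample output shown in this article
NUMBER_PRECISION = 2

if __name__ == '__main__':
    conf = SparkConf().setAppName('featureEngineering').setMaster('local')
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    # placeholder: 资源目录 stands for your local resource directory
    file_path = 'file://资源目录/'
    movieResourcesPath = file_path + "/webroot/sampledata/movies.csv"
    ratingsResourcesPath = file_path + "/webroot/sampledata/ratings.csv"
    # no schema is given, so every CSV column is read as a string
    movieSamples = spark.read.format('csv').option('header', 'true').load(movieResourcesPath)
    ratingSamples = spark.read.format('csv').option('header', 'true').load(ratingsResourcesPath)
    ratingSamplesWithLabel = addSampleLabel(ratingSamples)
    ratingSamplesWithLabel.show(5, truncate=False)
    samplesWithMovieFeatures = addMovieFeatures(movieSamples, ratingSamplesWithLabel)
    samplesWithUserFeatures = addUserFeatures(samplesWithMovieFeatures)
    # save samples as csv format
    splitAndSaveTrainingTestSamples(samplesWithUserFeatures, file_path + "/webroot/sampledata")
    # splitAndSaveTrainingTestSamplesByTimeStamp(samplesWithUserFeatures, file_path + "/webroot/sampledata")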

2. addSampleLabel

We first compute some summary statistics on the historical rating data:

ratingSamples.groupBy('rating').count().orderBy('rating').withColumn('percentage',
                                                                     F.col('count') / sampleCount).show()

Looking at the rating distribution, ratings of 3.0 and 4.0 account for the largest shares:

+------+------+--------------------+
|rating| count|          percentage|
+------+------+--------------------+
|   0.5|  9788|0.008375561978987506|
|   1.0| 45018| 0.03852176636392108|
|   1.5| 11794|0.010092090108314123|
|   2.0| 87084| 0.07451751526135553|
|   2.5| 34269|0.029323879593167432|
|   3.0|323616| 0.27691723185451783|
|   3.5| 74376| 0.06364331811904114|
|   4.0|324804|  0.2779337998593234|
|   4.5| 53388| 0.04568395003414231|
|   5.0|204501| 0.17499088682722966|
+------+------+--------------------+

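We treat a rating of 3.5 or higher as a sign that the user likes the movie, which makes it a positive sample; anything lower is a negative sample: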

def addSampleLabel(ratingSamples):
    ratingSamples.show(5, truncate=False)
    ratingSamples.printSchema()
    sampleCount = ratingSamples.count()
    ratingSamples = ratingSamples.withColumn('label', when(F.col('rating') >= 3.5, 1).otherwise(0))
    return ratingSamples

Output of ratingSamplesWithLabel.show(5, truncate=False):

+------+-------+------+----------+-----+
|userId|movieId|rating|timestamp |label|
+------+-------+------+----------+-----+
|1     |2      |3.5   |1112486027|1    |
|1     |29     |3.5   |1112484676|1    |
|1     |32     |3.5   |1112484819|1    |
|1     |47     |3.5   |1112484727|1    |
|1     |50     |3.5   |1112484580|1    |
+------+-------+------+----------+-----+
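
As a quick sanity check on the 3.5 threshold, the counts from the rating distribution table can be summed to see how balanced the resulting labels are. The snippet below is a plain-Python sketch, not part of the original script; the `rating_counts` dictionary is simply the table typed in by hand.

```python
# Rating counts copied from the distribution table in this section.
rating_counts = {
    0.5: 9788, 1.0: 45018, 1.5: 11794, 2.0: 87084, 2.5: 34269,
    3.0: 323616, 3.5: 74376, 4.0: 324804, 4.5: 53388, 5.0: 204501,
}

total = sum(rating_counts.values())
# Ratings of 3.5 and above become label 1, everything else label 0.
positives = sum(n for rating, n in rating_counts.items() if rating >= 3.5)
positive_share = positives / total

print(f"{positives} of {total} samples are positive ({positive_share:.1%})")
```

With these counts roughly 56% of the samples end up positive, so the two classes are fairly balanced.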

3. addMovieFeatures

def addMovieFeatures(movieSamples, ratingSamplesWithLabel):
    # join in the basic movie attributes
    samplesWithMovies1 = ratingSamplesWithLabel.join(movieSamples, on=['movieId'], how='left')

3.1 Release year feature

    # add releaseYear, then drop title (the cleaned-up title is discarded by the drop)
    samplesWithMovies2 = samplesWithMovies1.withColumn('releaseYear',
                                                       udf(extractReleaseYearUdf, IntegerType())('title')) \
        .withColumn('title', udf(lambda x: x.strip()[:-6].strip(), StringType())('title')) \
        .drop('title')

Here extractReleaseYearUdf cuts the year out of the title:

# title looks like "Toy Story (1995)"
def extractReleaseYearUdf(title):
    if not title or len(title.strip()) < 6:
        return 1990
    else:
        yearStr = title.strip()[-5:-1]
    return int(yearStr)
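
The slicing version above calls `int()` on whatever sits in the last characters, so a title of six or more characters that does not end in a year (for example "Untitled") would raise a ValueError. As a hedged alternative, here is a minimal sketch that uses a regular expression and falls back to the same default year. The name `extract_release_year_safe` and the constant `DEFAULT_YEAR` are hypothetical and not part of the original code.

```python
import re

DEFAULT_YEAR = 1990  # same fallback value as extractReleaseYearUdf


def extract_release_year_safe(title):
    """Hypothetical variant: return the trailing "(YYYY)" year, else the fallback."""
    if not title:
        return DEFAULT_YEAR
    # Look for four digits in parentheses at the very end of the title.
    match = re.search(r"\((\d{4})\)\s*$", title)
    return int(match.group(1)) if match else DEFAULT_YEAR


print(extract_release_year_safe("Toy Story (1995)"))
print(extract_release_year_safe("Untitled"))
```

It could be wrapped with `udf(extract_release_year_safe, IntegerType())` in the same way as the original function.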

3.2 Movie genre features

We take the first three genre tags as features:

    # genres looks like "Comedy|Crime|Drama"; the pipe is escaped because split takes a regex
    samplesWithMovies3 = samplesWithMovies2.withColumn('movieGenre1', split(F.col('genres'), "\\|")[0]) \
        .withColumn('movieGenre2', split(F.col('genres'), "\\|")[1]) \
        .withColumn('movieGenre3', split(F.col('genres'), "\\|")[2])
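
To make the effect of the three indexed columns concrete, here is a small plain-Python illustration of the same idea: split on "|" and keep the first three positions, leaving missing positions empty (they show up as null in the sample output of this article). The helper name `first_genres` is made up for this sketch.

```python
def first_genres(genres, n=3):
    """Split a "|"-separated genre string and pad with None up to n slots."""
    parts = genres.split("|") if genres else []
    return (parts + [None] * n)[:n]


print(first_genres("Comedy|Crime|Drama|Thriller"))
print(first_genres("Comedy"))
```

A movie with more than three genres loses the extra tags, and a movie with fewer gets empty slots, which a downstream model has to handle.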

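3.3 Rating features

For each movie we compute the number of ratings, the average rating, and the standard deviation of the ratings: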

    # stddev is undefined for a movie with a single rating; fillna(0) replaces the missing value
    movieRatingFeatures = samplesWithMovies3.groupBy('movieId').agg(
        F.count(F.lit(1)).alias('movieRatingCount'),
        format_number(F.avg(F.col('rating')), NUMBER_PRECISION).alias('movieAvgRating'),
        F.stddev(F.col('rating')).alias('movieRatingStddev')).fillna(0) \
        .withColumn('movieRatingStddev', format_number(F.col('movieRatingStddev'), NUMBER_PRECISION))
    # attach the per-movie statistics to every sample of that movie
    samplesWithMovies4 = samplesWithMovies3.join(movieRatingFeatures, on=['movieId'], how='left')
    samplesWithMovies4.printSchema()
    samplesWithMovies4.show(5, truncate=False)
    return samplesWithMovies4

Output of samplesWithMovies4.show(5, truncate=False):

+-------+------+------+----------+-----+---------------------------+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+
|movieId|userId|rating|timestamp |label|genres                     |releaseYear|movieGenre1|movieGenre2|movieGenre3|movieRatingCount|movieAvgRating|movieRatingStddev|
+-------+------+------+----------+-----+---------------------------+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+
|296    |1     |4.0   |1112484767|1    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
|296    |8     |5.0   |833973081 |1    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
|296    |11    |3.5   |1230858799|1    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
|296    |13    |5.0   |849082366 |1    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
|296    |15    |3.0   |840206642 |0    |Comedy|Crime|Drama|Thriller|1994       |Comedy     |Crime      |Drama      |14616           |4.17          |0.98             |
+-------+------+------+----------+-----+---------------------------+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+
only showing top 5 rows
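
The three per-movie statistics can be checked by hand on a small list of ratings. The sketch below uses Python's `statistics` module and assumes, like Spark's `stddev`, the sample standard deviation; a movie with a single rating gets 0, mirroring the `fillna(0)` step. The function name `rating_stats` and the example ratings are illustrative only, and unlike `format_number` the sketch returns numbers rather than formatted strings.

```python
from statistics import mean, stdev


def rating_stats(ratings, precision=2):
    """Return (count, average, sample standard deviation) for one movie."""
    count = len(ratings)
    avg = round(mean(ratings), precision)
    # The sample standard deviation needs at least two values; use 0 otherwise.
    spread = round(stdev(ratings), precision) if count > 1 else 0.0
    return count, avg, spread


print(rating_stats([4.0, 5.0, 3.5, 5.0, 3.0]))
print(rating_stats([4.0]))
```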

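4. addUserFeatures

For the user side, we build features from each user's viewing behavior over the most recent 100 historical ratings before the current sample: the most recently liked movies, the number of rated movies, the release years of rated movies, the historical ratings, the genres of recently watched movies, and so on: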

def addUserFeatures(samplesWithMovieFeatures):
    extractGenresUdf = udf(extractGenres, ArrayType(StringType()))
    # up to 100 earlier ratings of the same user, current row excluded
    # (timestamp is read from CSV as a string, so this ordering is lexicographic)
    w = sql.Window.partitionBy('userId').orderBy('timestamp').rowsBetween(-100, -1)
    samplesWithUserFeatures = samplesWithMovieFeatures \
        .withColumn('userPositiveHistory',
                    F.collect_list(when(F.col('label') == 1, F.col('movieId')).otherwise(F.lit(None))).over(w)) \
        .withColumn('userPositiveHistory', reverse(F.col('userPositiveHistory'))) \
        .withColumn('userRatedMovie1', F.col('userPositiveHistory')[0]) \
        .withColumn('userRatedMovie2', F.col('userPositiveHistory')[1]) \
        .withColumn('userRatedMovie3', F.col('userPositiveHistory')[2]) \
        .withColumn('userRatedMovie4', F.col('userPositiveHistory')[3]) \
        .withColumn('userRatedMovie5', F.col('userPositiveHistory')[4]) \
        .withColumn('userRatingCount', F.count(F.lit(1)).over(w)) \
        .withColumn('userAvgReleaseYear', F.avg(F.col('releaseYear')).over(w).cast(IntegerType())) \
        .withColumn('userReleaseYearStddev',
                    format_number(F.stddev(F.col('releaseYear')).over(w), NUMBER_PRECISION)) \
        .withColumn('userAvgRating', format_number(F.avg(F.col('rating')).over(w), NUMBER_PRECISION)) \
        .withColumn('userRatingStddev', format_number(F.stddev(F.col('rating')).over(w), NUMBER_PRECISION)) \
        .withColumn('userGenres', extractGenresUdf(
            F.collect_list(when(F.col('label') == 1, F.col('genres')).otherwise(F.lit(None))).over(w))) \
        .withColumn('userGenre1', F.col('userGenres')[0]) \
        .withColumn('userGenre2', F.col('userGenres')[1]) \
        .withColumn('userGenre3', F.col('userGenres')[2]) \
        .withColumn('userGenre4', F.col('userGenres')[3]) \
        .withColumn('userGenre5', F.col('userGenres')[4]) \
        .drop('genres', 'userGenres', 'userPositiveHistory') \
        .filter(F.col('userRatingCount') > 1)
    samplesWithUserFeatures.printSchema()
    samplesWithUserFeatures.show(5)
    # inspect one user's samples in time order
    samplesWithUserFeatures.filter(F.col('userId') == 1).orderBy(F.col('timestamp').asc()).show(truncate=False)
    return samplesWithUserFeatures
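
The helper `extractGenres` wrapped by the UDF is not listed in this article. A minimal sketch of what it could look like, assuming it receives the list of "|"-separated genre strings of the user's liked movies and returns the genres ordered from most to least frequent:

```python
from collections import Counter


def extractGenres(genres_list):
    """Assumed helper: rank genres by how often they occur in the liked movies."""
    counts = Counter()
    for genres in genres_list:
        if genres:  # skip empty or missing genre strings
            counts.update(genres.split("|"))
    # most_common sorts by count; ties keep the order in which genres were first seen
    return [genre for genre, _ in counts.most_common()]


# Illustrative input: three liked movies, oldest first.
print(extractGenres(["Crime|Drama", "Crime|Mystery|Thriller", "Crime|Horror|Thriller"]))
```

The first five entries of the returned list then feed userGenre1 to userGenre5.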

Output of samplesWithUserFeatures.show(5):

+-------+------+------+---------+-----+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+---------------+---------------+---------------+---------------+---------------+---------------+------------------+---------------------+-------------+----------------+----------+----------+----------+----------+----------+
|movieId|userId|rating|timestamp|label|releaseYear|movieGenre1|movieGenre2|movieGenre3|movieRatingCount|movieAvgRating|movieRatingStddev|userRatedMovie1|userRatedMovie2|userRatedMovie3|userRatedMovie4|userRatedMovie5|userRatingCount|userAvgReleaseYear|userReleaseYearStddev|userAvgRating|userRatingStddev|userGenre1|userGenre2|userGenre3|userGenre4|userGenre5|
+-------+------+------+---------+-----+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+---------------+---------------+---------------+---------------+---------------+---------------+------------------+---------------------+-------------+----------------+----------+----------+----------+----------+----------+
|    514| 10096|   3.0|954365410|    0|       1994|     Comedy|       null|       null|            1038|          3.50|             0.86|            858|           null|           null|           null|           null|              2|              1982|                14.85|         3.50|            0.71|     Crime|     Drama|      null|      null|      null|
|    608| 10096|   3.0|954365515|    0|       1996|     Comedy|      Crime|      Drama|            9505|          4.09|             0.93|            858|           null|           null|           null|           null|              3|              1986|                12.42|         3.33|            0.58|     Crime|     Drama|      null|      null|      null|
|     50| 10096|   5.0|954365515|    1|       1995|      Crime|    Mystery|   Thriller|           10221|          4.35|             0.75|            858|           null|           null|           null|           null|              4|              1988|                11.24|         3.25|            0.50|     Crime|     Drama|      null|      null|      null|
|    593| 10096|   4.0|954365552|    1|       1991|      Crime|     Horror|   Thriller|           13692|          4.18|             0.85|             50|            858|           null|           null|           null|              5|              1990|                10.12|         3.60|            0.89|     Crime|     Drama|   Mystery|  Thriller|      null|
|     25| 10096|   2.0|954365571|    0|       1995|      Drama|    Romance|       null|            4684|          3.69|             1.04|            593|             50|            858|           null|           null|              6|              1990|                 9.06|         3.67|            0.82|     Crime|  Thriller|     Drama|   Mystery|    Horror|
+-------+------+------+---------+-----+-----------+-----------+-----------+-----------+----------------+--------------+-----------------+---------------+---------------+---------------+---------------+---------------+---------------+------------------+---------------------+-------------+----------------+----------+----------+----------+----------+----------+
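
The window used for every user feature, `rowsBetween(-100, -1)`, covers up to 100 earlier rows of the same user and excludes the current row, so a sample never sees its own label. The plain-Python sketch below mimics that window for the liked-movie history and the rating count. The function name `positive_history` is made up, and the input rows are illustrative: they follow the user 10096 rows in the output above, with the second entry (movie 1, label 0) invented to stand in for an earlier rating that is not shown.

```python
def positive_history(rows, window=100, top_k=5):
    """rows: list of (movie_id, label) for one user, sorted by timestamp ascending."""
    features = []
    for i, (movie_id, _label) in enumerate(rows):
        previous = rows[max(0, i - window):i]      # same idea as rowsBetween(-window, -1)
        liked = [m for m, l in previous if l == 1]  # keep only positive samples
        liked.reverse()                             # most recent first
        padded = (liked + [None] * top_k)[:top_k]   # userRatedMovie1..top_k
        features.append((movie_id, len(previous), padded))
    return features


rows = [(858, 1), (1, 0), (514, 0), (608, 0), (50, 1), (593, 1), (25, 0)]
history = positive_history(rows)
for movie_id, rating_count, liked in history:
    print(movie_id, rating_count, liked)
```

The first two rows have a rating count of 0 and 1, which is why the final `filter(F.col("userRatingCount") > 1)` in addUserFeatures drops them.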

5. Split train and test samples

Random split:

def splitAndSaveTrainingTestSamples(samplesWithUserFeatures, file_path):
    # keep a 10% sample of the data, then split it 80/20 at random
    smallSamples = samplesWithUserFeatures.sample(0.1)
    training, test = smallSamples.randomSplit((0.8, 0.2))
    trainingSavePath = file_path + '/trainingSamples'
    testSavePath = file_path + '/testSamples'
    # repartition(1) writes each set as a single CSV part file
    training.repartition(1).write.option("header", "true").mode('overwrite') \
        .csv(trainingSavePath)
    test.repartition(1).write.option("header", "true").mode('overwrite') \
        .csv(testSavePath)
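
Because neither `sample` nor `randomSplit` is given a seed above, each run produces different training and test sets; both PySpark methods accept an optional `seed` argument if reproducibility matters. The plain-Python sketch below only illustrates the idea of a seeded down-sample followed by a weighted random split. It is not how Spark implements these calls, and the function name and parameters are assumptions of this sketch.

```python
import random


def sample_and_split(rows, sample_fraction=0.1, train_weight=0.8, seed=42):
    """Keep about sample_fraction of rows, then send about train_weight of them to training."""
    rng = random.Random(seed)  # fixed seed makes the split repeatable
    training, test = [], []
    for row in rows:
        if rng.random() >= sample_fraction:
            continue  # row not kept by the down-sampling step
        if rng.random() < train_weight:
            training.append(row)
        else:
            test.append(row)
    return training, test


training, test = sample_and_split(range(100_000))
print(len(training), len(test))
```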

Split by time order:

def splitAndSaveTrainingTestSamplesByTimeStamp(samplesWithUserFeatures, file_path):
    # timestamp was read as a string, so cast it to long before computing the quantile
    smallSamples = samplesWithUserFeatures.sample(0.1).withColumn("timestampLong", F.col("timestamp").cast(LongType()))
    # approximate 80th percentile of the timestamps (relative error 0.05)
    quantile = smallSamples.stat.approxQuantile("timestampLong", [0.8], 0.05)
    splitTimestamp = quantile[0]
    # earlier samples go to training, later samples go to test
    training = smallSamples.where(F.col("timestampLong") <= splitTimestamp).drop("timestampLong")
    test = smallSamples.where(F.col("timestampLong") > splitTimestamp).drop("timestampLong")
    trainingSavePath = file_path + '/trainingSamples'
    testSavePath = file_path + '/testSamples'
    training.repartition(1).write.option("header", "true").mode('overwrite') \
        .csv(trainingSavePath)
    test.repartition(1).write.option("header", "true").mode('overwrite') \
        .csv(testSavePath)
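
The point of the time-based split is that every training sample is older than every test sample, so the model is evaluated only on behavior that happened after what it was trained on. A minimal plain-Python sketch of the same idea follows; it uses an exact cut-off taken from the sorted timestamps instead of Spark's approximate quantile, and the function name `split_by_time` and the sample records are hypothetical.

```python
def split_by_time(samples, train_fraction=0.8):
    """Split dict samples so that the oldest train_fraction go to training."""
    timestamps = sorted(s["timestamp"] for s in samples)
    # Exact cut-off: the timestamp sitting at the train_fraction position.
    cut_index = int(train_fraction * (len(timestamps) - 1))
    split_timestamp = timestamps[cut_index]
    training = [s for s in samples if s["timestamp"] <= split_timestamp]
    test = [s for s in samples if s["timestamp"] > split_timestamp]
    return training, test


samples = [{"movieId": i, "timestamp": 1000 + i} for i in range(10)]
training, test = split_by_time(samples)
print(len(training), len(test))
```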
