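Back in 2015 and 2016 I used Spark for machine learning. PySpark was still immature at the time, so feature engineering was mostly written in Scala. After I joined Alibaba, feature processing was mostly done with PAI's visual feature-engineering components and ODPS SQL, and I only wrote Python myself for the complex cases. I recently relearned PySpark, and these are my notes on how to use it for feature engineering.
We use the MovieLens data to demonstrate three kinds of feature processing: oneHotEncoder, multiHotEncoder, and numerical features.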
main
```python
from pyspark import SparkConf
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, QuantileDiscretizer, MinMaxScaler
from pyspark.ml.linalg import VectorUDT, Vectors
from pyspark.sql import SparkSession
from pyspark.sql.functions import *  # provides udf, explode, split (and shadows builtins such as max)
from pyspark.sql.types import *
from pyspark.sql import functions as F

# In the actual script, put this block at the end of the file, after
# oneHotEncoderExample, multiHotEncoderExample, array2vec and ratingFeatures are defined.
if __name__ == '__main__':
    conf = SparkConf().setAppName('featureEngineering').setMaster('local')
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    file_path = 'file:///path/to/resource/folder'  # placeholder: your resource folder path
    movieResourcesPath = file_path + "/webroot/sampledata/movies.csv"
    movieSamples = spark.read.format('csv').option('header', 'true').load(movieResourcesPath)
    print("Raw Movie Samples:")
    movieSamples.show(10)
    movieSamples.printSchema()
    print("OneHotEncoder Example:")
    oneHotEncoderExample(movieSamples)
    print("MultiHotEncoder Example:")
    multiHotEncoderExample(movieSamples)
    print("Numerical features Example:")
    ratingsResourcesPath = file_path + "/webroot/sampledata/ratings.csv"
    ratingSamples = spark.read.format('csv').option('header', 'true').load(ratingsResourcesPath)
    ratingFeatures(ratingSamples)
```
First, let's look at what "movies.csv" and "ratings.csv" contain:
```
movies samples:
+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Adventure|Animati...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|
|      5|Father of the Bri...|              Comedy|
+-------+--------------------+--------------------+
ratings samples:
+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     1|      2|   3.5|1112486027|
|     1|     29|   3.5|1112484676|
|     1|     32|   3.5|1112484819|
|     1|     47|   3.5|1112484727|
|     1|     50|   3.5|1112484580|
+------+-------+------+----------+
```
oneHotEncoder
We one-hot encode movieId:
```python
def oneHotEncoderExample(movieSamples):
    # Cast the string movieId to int and use it directly as the category index movieIdNumber
    samplesWithIdNumber = movieSamples.withColumn("movieIdNumber", F.col("movieId").cast(IntegerType()))
    # dropLast=False keeps a slot for every category, including the last one
    encoder = OneHotEncoder(inputCols=["movieIdNumber"], outputCols=['movieIdVector'], dropLast=False)
    oneHotEncoderSamples = encoder.fit(samplesWithIdNumber).transform(samplesWithIdNumber)
    oneHotEncoderSamples.printSchema()
    oneHotEncoderSamples.show(5)
```
```
OneHotEncoder Example:
root
 |-- movieId: string (nullable = true)
 |-- title: string (nullable = true)
 |-- genres: string (nullable = true)
 |-- movieIdNumber: integer (nullable = true)
 |-- movieIdVector: vector (nullable = true)
+-------+--------------------+--------------------+-------------+----------------+
|movieId|               title|              genres|movieIdNumber|   movieIdVector|
+-------+--------------------+--------------------+-------------+----------------+
|      1|    Toy Story (1995)|Adventure|Animati...|            1|(1001,[1],[1.0])|
|      2|      Jumanji (1995)|Adventure|Childre...|            2|(1001,[2],[1.0])|
|      3|Grumpier Old Men ...|      Comedy|Romance|            3|(1001,[3],[1.0])|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|            4|(1001,[4],[1.0])|
|      5|Father of the Bri...|              Comedy|            5|(1001,[5],[1.0])|
+-------+--------------------+--------------------+-------------+----------------+
only showing top 5 rows
```
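Spark prints a SparseVector as `(size, [indices], [values])`. As a minimal plain-Python sketch (not Spark's implementation), assuming the input is already a non-negative integer index and that, without column metadata, the number of categories is taken as the largest index plus one, the encoding works like this; `one_hot_sparse` is a hypothetical helper. Under that assumption, a size of 1001 is consistent with the largest movieIdNumber in the data being 1000.

```python
from typing import List, Tuple


def one_hot_sparse(index: int, num_categories: int,
                   drop_last: bool = False) -> Tuple[int, List[int], List[float]]:
    """Return (size, indices, values), mirroring how Spark prints a SparseVector."""
    if not 0 <= index < num_categories:
        raise ValueError("index out of range")
    size = num_categories - 1 if drop_last else num_categories
    if index == size:  # with drop_last, the last category maps to the all-zero vector
        return (size, [], [])
    return (size, [index], [1.0])


# Assumption: number of categories = largest index + 1 (hypothetical ids below)
ids = [1, 2, 3, 4, 5, 1000]
num_categories = max(ids) + 1
print([one_hot_sparse(i, num_categories) for i in ids[:2]])
# [(1001, [1], [1.0]), (1001, [2], [1.0])]
```

Because the raw movieId is used as the index, ids that never occur still occupy slots in the vector; running StringIndexer on movieId first would give a compact index instead.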
multiHotEncoder
Next, we multi-hot encode the movie genres ('genres'):
```python
def multiHotEncoderExample(movieSamples):
    # Split genres on "|" and explode: one movie row becomes one row per genre
    samplesWithGenre = movieSamples.select(
        "movieId", "title",
        explode(split(F.col("genres"), r"\|")).alias('genre'))
    print("samplesWithGenre Samples:")
    samplesWithGenre.printSchema()
    samplesWithGenre.show(5)
```
```
samplesWithGenre Samples:
root
 |-- movieId: string (nullable = true)
 |-- title: string (nullable = true)
 |-- genre: string (nullable = true)
+-------+----------------+---------+
|movieId|           title|    genre|
+-------+----------------+---------+
|      1|Toy Story (1995)|Adventure|
|      1|Toy Story (1995)|Animation|
|      1|Toy Story (1995)| Children|
|      1|Toy Story (1995)|   Comedy|
|      1|Toy Story (1995)|  Fantasy|
+-------+----------------+---------+
only showing top 5 rows
```
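Spark's `split` treats its pattern as a Java regular expression, which is why the pipe has to be escaped as `\|`. The explode step itself can be sketched in plain Python, where `str.split` takes a literal separator; `explode_genres` is a hypothetical helper, and the input row is illustrative, built from the genres visible in the output above.

```python
from typing import List, Tuple


def explode_genres(rows: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
    """Turn (movieId, title, 'A|B|C') rows into one (movieId, title, genre) row per genre."""
    out = []
    for movie_id, title, genres in rows:
        # str.split uses a literal separator, so no regex escaping is needed here
        for genre in genres.split("|"):
            out.append((movie_id, title, genre))
    return out


rows = [("1", "Toy Story (1995)", "Adventure|Animation|Children|Comedy|Fantasy")]
for r in explode_genres(rows):
    print(r)
```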
```python
    # Map each genre string to a numeric index with StringIndexer
    genreIndexer = StringIndexer(inputCol="genre", outputCol="genreIndex")
    genreIndexerModel = genreIndexer.fit(samplesWithGenre)
    genreIndexSamples = genreIndexerModel.transform(samplesWithGenre) \
        .withColumn("genreIndexInt", F.col("genreIndex").cast(IntegerType()))
    # Vector size = largest index + 1 (the number of distinct genres)
    indexSize = genreIndexSamples.agg(F.max(F.col("genreIndexInt"))).head()[0] + 1
    # Collect each movie's genreIndexInt values into a list
    processedSamples = genreIndexSamples.groupBy('movieId').agg(
        F.collect_list('genreIndexInt').alias('genreIndexes')).withColumn("indexSize", F.lit(indexSize))
    # Turn the index list into a sparse multi-hot vector with the array2vec UDF
    finalSample = processedSamples.withColumn(
        "vector", udf(array2vec, VectorUDT())(F.col("genreIndexes"), F.col("indexSize")))
    print("finalSample Samples:")
    finalSample.printSchema()
    finalSample.show(5)
```
```
finalSample Samples:
root
 |-- movieId: string (nullable = true)
 |-- genreIndexes: array (nullable = true)
 |    |-- element: integer (containsNull = false)
 |-- indexSize: integer (nullable = false)
 |-- vector: vector (nullable = true)
+-------+------------+---------+--------------------+
|movieId|genreIndexes|indexSize|              vector|
+-------+------------+---------+--------------------+
|    296|[1, 5, 0, 3]|       19|(19,[0,1,3,5],[1....|
|    467|         [1]|       19|      (19,[1],[1.0])|
|    675|   [4, 0, 3]|       19|(19,[0,3,4],[1.0,...|
|    691|      [1, 2]|       19|(19,[1,2],[1.0,1.0])|
|    829| [1, 10, 14]|       19|(19,[1,10,14],[1....|
+-------+------------+---------+--------------------+
only showing top 5 rows
```
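StringIndexer's default `stringOrderType` is `frequencyDesc`: the most frequent genre gets index 0, the next one index 1, and so on, with ties broken alphabetically in recent Spark versions. A plain-Python sketch of that ordering, using a made-up genre list and a hypothetical `frequency_desc_index` helper, also shows why `max(index) + 1` equals the number of distinct genres (19 in the output above).

```python
from collections import Counter
from typing import Dict, List


def frequency_desc_index(labels: List[str]) -> Dict[str, int]:
    """Assign indices like StringIndexer with frequencyDesc ordering:
    most frequent label -> 0, ties broken alphabetically."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: i for i, label in enumerate(ordered)}


# Made-up genre occurrences, one entry per exploded (movie, genre) row
genres = ["Comedy", "Drama", "Comedy", "Romance", "Drama", "Comedy", "Action"]
mapping = frequency_desc_index(genres)
print(mapping)  # {'Comedy': 0, 'Drama': 1, 'Action': 2, 'Romance': 3}
index_size = max(mapping.values()) + 1  # equals the number of distinct labels
```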
The UDF array2vec that builds the vector:
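```python
def array2vec(genreIndexes, indexSize):
    # SparseVector indices must be strictly increasing, so sort (and de-duplicate) them
    indexes = sorted(set(genreIndexes))
    fill_list = [1.0] * len(indexes)
    # Sparse storage: the vector size, the indexes that hold values, and the value at each index
    return Vectors.sparse(indexSize, indexes, fill_list)
```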
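`Vectors.sparse` expects strictly increasing indices, which is why array2vec sorts the list before building the vector. The sketch below reproduces the same encoding in plain Python for movie 296 (`[1, 5, 0, 3]`, size 19) and also builds the equivalent dense list, to show what the truncated `(19,[0,1,3,5],[1....` value stands for. `multi_hot` is a hypothetical helper; its de-duplication step is a safeguard against a genre appearing twice.

```python
from typing import List, Tuple

SparseTriple = Tuple[int, List[int], List[float]]


def multi_hot(indexes: List[int], size: int) -> Tuple[SparseTriple, List[float]]:
    """Return (sparse, dense) multi-hot encodings of a list of category indexes."""
    unique_sorted = sorted(set(indexes))  # sparse indices must be strictly increasing
    sparse = (size, unique_sorted, [1.0] * len(unique_sorted))
    dense = [0.0] * size
    for i in unique_sorted:
        dense[i] = 1.0
    return sparse, dense


sparse, dense = multi_hot([1, 5, 0, 3], 19)
print(sparse)     # (19, [0, 1, 3, 5], [1.0, 1.0, 1.0, 1.0])
print(dense[:6])  # [1.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

The sparse form only stores the positions that are set, which is why it is the natural choice when most of the 19 slots are zero.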
Numerical features
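For numerical features, we can apply bucketing or normalization. Here we first read "ratings.csv" and compute, for each movie, the number of ratings it received and its average rating: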
```python
def ratingFeatures(ratingSamples):
    # calculate average movie rating score and rating count
    # avgRating is wrapped in a one-element dense vector because MinMaxScaler expects vector input
    movieFeatures = ratingSamples.groupBy('movieId') \
        .agg(F.count(F.lit(1)).alias('ratingCount'), F.avg("rating").alias("avgRating")) \
        .withColumn('avgRatingVec', udf(lambda x: Vectors.dense(x), VectorUDT())('avgRating'))
    print("movieFeatures:")
    movieFeatures.show(5)
```
```
movieFeatures:
+-------+-----------+------------------+--------------------+
|movieId|ratingCount|         avgRating|        avgRatingVec|
+-------+-----------+------------------+--------------------+
|    296|      14616| 4.165606185002737| [4.165606185002737]|
|    467|        174|3.4367816091954024|[3.4367816091954024]|
|    829|        402|2.6243781094527363|[2.6243781094527363]|
|    691|        254|3.1161417322834644|[3.1161417322834644]|
|    675|          6|2.3333333333333335|[2.3333333333333335]|
+-------+-----------+------------------+--------------------+
only showing top 5 rows
```
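The groupBy/agg step computes, per movieId, a row count and an arithmetic mean. A plain-Python equivalent over a few toy (movieId, rating) pairs (made-up values, hypothetical `rating_features` helper) looks like this. Note that ratings.csv is read without `inferSchema`, so `rating` is a string column in Spark; the output above shows `F.avg` still producing a mean through an implicit cast, and adding `.option('inferSchema', 'true')` to the reader is one way to get numeric columns up front.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def rating_features(ratings: Iterable[Tuple[str, float]]) -> Dict[str, Tuple[int, float]]:
    """Per movieId: (ratingCount, avgRating), like groupBy('movieId').agg(count, avg)."""
    totals: Dict[str, List[float]] = defaultdict(lambda: [0, 0.0])  # movieId -> [count, sum]
    for movie_id, rating in ratings:
        totals[movie_id][0] += 1
        totals[movie_id][1] += rating
    return {m: (int(c), s / c) for m, (c, s) in totals.items()}


# Toy (movieId, rating) pairs, not taken from the real dataset
ratings = [("2", 3.5), ("29", 3.5), ("2", 4.0), ("2", 2.5)]
print(rating_features(ratings))
```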
Next, we bucket the rating count and min-max normalize the average rating:
```python
    # bucketing: equal-frequency buckets on the rating count
    ratingCountDiscretizer = QuantileDiscretizer(numBuckets=100, inputCol="ratingCount", outputCol="ratingCountBucket")
    # Normalization: min-max scale the average rating into [0, 1]
    ratingScaler = MinMaxScaler(inputCol="avgRatingVec", outputCol="scaleAvgRating")
    pipelineStage = [ratingCountDiscretizer, ratingScaler]
    featurePipeline = Pipeline(stages=pipelineStage)
    movieProcessedFeatures = featurePipeline.fit(movieFeatures).transform(movieFeatures)
    movieProcessedFeatures.show(5)
```
```
+-------+-----------+------------------+--------------------+-----------------+--------------------+
|movieId|ratingCount|         avgRating|        avgRatingVec|ratingCountBucket|      scaleAvgRating|
+-------+-----------+------------------+--------------------+-----------------+--------------------+
|    296|      14616| 4.165606185002737| [4.165606185002737]|             57.0|[0.9170998054196596]|
|    467|        174|3.4367816091954024|[3.4367816091954024]|             21.0|[0.7059538707722662]|
|    829|        402|2.6243781094527363|[2.6243781094527363]|             32.0|[0.4705944962973248]|
|    691|        254|3.1161417322834644|[3.1161417322834644]|             26.0|[0.6130620985364005]|
|    675|          6|2.3333333333333335|[2.3333333333333335]|              3.0|[0.38627664627161...|
+-------+-----------+------------------+--------------------+-----------------+--------------------+
only showing top 5 rows
```
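For reference, MinMaxScaler rescales each value as (x - E_min) / (E_max - E_min) * (max - min) + min, with a default output range of [0, 1], and QuantileDiscretizer picks split points at quantiles so that each bucket holds roughly the same number of rows, labeling the buckets 0.0, 1.0, and so on. The sketch below implements both in plain Python with exact quantiles; the helper names are hypothetical, and it is an illustration rather than Spark's algorithm (Spark uses approximate quantiles controlled by `relativeError`, and can end up with fewer buckets than `numBuckets` when there are too few distinct values).

```python
import bisect
from typing import List


def min_max_scale(values: List[float], new_min: float = 0.0, new_max: float = 1.0) -> List[float]:
    """Rescale like MinMaxScaler: (x - E_min) / (E_max - E_min) * (max - min) + min."""
    e_min, e_max = min(values), max(values)
    if e_max == e_min:  # constant column: every value maps to the middle of the range
        return [0.5 * (new_max + new_min)] * len(values)
    return [(x - e_min) / (e_max - e_min) * (new_max - new_min) + new_min for x in values]


def quantile_buckets(values: List[float], num_buckets: int) -> List[float]:
    """Equal-frequency bucketing using exact quantiles (Spark uses approximate ones)."""
    ordered = sorted(values)
    n = len(ordered)
    # interior split points at the 1/k, 2/k, ... quantiles; duplicate splits collapse buckets
    splits = sorted({ordered[(i * n) // num_buckets] for i in range(1, num_buckets)})
    # bucket i covers [splits[i-1], splits[i]), so count the splits <= x
    return [float(bisect.bisect_right(splits, x)) for x in values]


counts = [6, 174, 254, 402, 14616]  # rating counts of the five movies shown above
print(quantile_buckets(counts, 5))
print(min_max_scale([1.0, 2.0, 3.0]))
```

The buckets here are computed over those five movies only, so they are not comparable with the 100-bucket split that the pipeline computes over all movies.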