Scala Spark ML study notes: a serializable custom transformer with Spark (Scala)

2019-08-31 19:25:02

Copyright notice: this is an original article by the author, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.

Original link: https://blog.csdn.net/u014365862/article/details/100146543

Sometimes, when building a pipeline, Spark ML does not provide the functionality you need and you have to define a transformer yourself. The following sample can serve as a reference:

src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/OwnMLlibPipeline.scala

/*
-------------------------------------------------
   Description :  Serializable Custom Transformer with Spark 2.0 (Scala)
   Author :       liupeng
   Date :         2019/08/29
-------------------------------------------------
 */

package ml.dmlc.xgboost4j.scala.example.spark

import org.apache.spark.ml.{Pipeline, PipelineModel, Transformer}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.StructType


// A Transformer that renames a column. Mixing in DefaultParamsWritable (together with
// the DefaultParamsReadable companion object below) is what makes the stage
// serializable inside a saved pipeline.
class ColRenameTransformer(override val uid: String) extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("ColRenameTransformer"))
  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def getOutputCol: String = getOrDefault(outputCol)

  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")

  override def transform(dataset: Dataset[_]): DataFrame = {
    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val inCol = extractParamMap.getOrElse(inputCol, "input")

    // Drop any pre-existing column with the output name, then rename input -> output.
    dataset.drop(outCol).withColumnRenamed(inCol, outCol)
  }

  override def copy(extra: ParamMap): ColRenameTransformer = defaultCopy(extra)

  // The output schema is the input schema with the input column renamed (and any
  // pre-existing column of the output name dropped), mirroring transform above.
  override def transformSchema(schema: StructType): StructType = {
    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val inCol = extractParamMap.getOrElse(inputCol, "input")
    StructType(schema.filter(_.name != outCol)
      .map(f => if (f.name == inCol) f.copy(name = outCol) else f))
  }
}

// The companion object provides load(), so a persisted pipeline containing
// this stage can be read back.
object ColRenameTransformer extends DefaultParamsReadable[ColRenameTransformer] {
  override def load(path: String): ColRenameTransformer = super.load(path)
}


object OwnMLlibPipeline {

  def main(args: Array[String]): Unit = {

    // Where the fitted PipelineModel will be persisted.
    val pipelineModelPath = args(0)

    val spark = SparkSession.builder().getOrCreate()
    val data = spark.createDataFrame(Seq(
                     ("hi,there", 1),
                     ("a,b,c", 2),
                     ("no", 3))).toDF("myInputCol", "id")
    data.show(false)

    val myTransformer = new ColRenameTransformer().setInputCol("id").setOutputCol("lpid")
    println(s"Original data has ${data.count()} rows.")
    // The transformer can also be applied directly, outside a pipeline:
    // val output = myTransformer.transform(data)
    // println(s"Output data has ${output.count()} rows.")
    // output.show(false)

    val pipeline = new Pipeline().setStages(Array(myTransformer))
    val model = pipeline.fit(data)
    // ML pipeline persistence
    model.write.overwrite().save(pipelineModelPath)
    // Load the saved model and use it for serving
    val model2 = PipelineModel.load(pipelineModelPath)
    model2.transform(data).show(false)

    spark.stop()
  }
}
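
Because ColRenameTransformer mixes in DefaultParamsWritable and its companion object extends DefaultParamsReadable, it can also be saved and reloaded on its own, outside any pipeline. A minimal sketch (the save path is illustrative, not from the original post):

// Persist the custom transformer by itself; write comes from DefaultParamsWritable.
val t = new ColRenameTransformer().setInputCol("id").setOutputCol("lpid")
t.write.overwrite().save("/tmp/colRenameTransformer")

// load comes from the DefaultParamsReadable companion object.
val loaded = ColRenameTransformer.load("/tmp/colRenameTransformer")
assert(loaded.getOutputCol == "lpid")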

Output:

input:
+----------+---+
|myInputCol|id |
+----------+---+
|hi,there  |1  |
|a,b,c     |2  |
|no        |3  |
+----------+---+

res:
+----------+----+
|myInputCol|lpid|
+----------+----+
|hi,there  |1   |
|a,b,c     |2   |
|no        |3   |
+----------+----+
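
In a real job the custom transformer is usually one stage among several. The sketch below is a hedged illustration, not part of the original post: it chains ColRenameTransformer with the standard VectorAssembler and LogisticRegression stages, and assumes a training DataFrame trainDF that has "id" and "label" columns:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.VectorAssembler

// Rename "id" to "lpid", assemble it into a feature vector, then fit a classifier.
val rename = new ColRenameTransformer().setInputCol("id").setOutputCol("lpid")
val assembler = new VectorAssembler()
  .setInputCols(Array("lpid"))
  .setOutputCol("features")
val lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("features")

val fullPipeline = new Pipeline().setStages(Array(rename, assembler, lr))
// val fullModel = fullPipeline.fit(trainDF)   // trainDF is assumed, not defined here
// fullModel.write.overwrite().save("/tmp/fullPipelineModel")  // persists every stage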
