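The code of the XGBoost example `SparkMLlibPipeline.scala` is as follows. To build and run it, the file must sit in the directory layout that matches its package: `src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala`.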
```scala
package ml.dmlc.xgboost4j.scala.example.spark

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.tuning._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}

// this example works with Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)
object SparkMLlibPipeline {

  def main(args: Array[String]): Unit = {
    if (args.length != 3) {
      println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
      sys.exit(1)
    }
    val inputPath = args(0)
    val nativeModelPath = args(1)
    val pipelineModelPath = args(2)

    val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

    // Load dataset
    val schema = new StructType(Array(
      StructField("sepal length", DoubleType, true),
      StructField("sepal width", DoubleType, true),
      StructField("petal length", DoubleType, true),
      StructField("petal width", DoubleType, true),
      StructField("class", StringType, true)))
    val rawInput = spark.read.schema(schema).csv(inputPath)

    // Split training and test dataset
    val Array(training, test) = rawInput.randomSplit(Array(0.8, 0.2), 123)

    // Build ML pipeline, it includes 4 stages:
    // 1, Assemble all features into a single vector column.
    // 2, From string label to indexed double label.
    // 3, Use XGBoostClassifier to train classification model.
    // 4, Convert indexed double label back to original string label.
    val assembler = new VectorAssembler()
      .setInputCols(Array("sepal length", "sepal width", "petal length", "petal width"))
      .setOutputCol("features")
    // Fitted up front so that its labels can be handed to IndexToString below
    val labelIndexer = new StringIndexer()
      .setInputCol("class")
      .setOutputCol("classIndex")
      .fit(training)
    val booster = new XGBoostClassifier(
      Map("eta" -> 0.1f,
        "max_depth" -> 2,
        "objective" -> "multi:softprob",
        "num_class" -> 3,
        "num_round" -> 100,
        "num_workers" -> 2
      )
    )
    booster.setFeaturesCol("features")
    booster.setLabelCol("classIndex")
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("realLabel")
      .setLabels(labelIndexer.labels)
    val pipeline = new Pipeline()
      .setStages(Array(assembler, labelIndexer, booster, labelConverter))
    val model = pipeline.fit(training)

    // Batch prediction
    val prediction = model.transform(test)
    prediction.show(false)

    // Model evaluation; the evaluator's default metric is "f1", so request accuracy explicitly
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("classIndex")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(prediction)
    println("The model accuracy is : " + accuracy)

    // Tune model using cross validation
    val paramGrid = new ParamGridBuilder()
      .addGrid(booster.maxDepth, Array(3, 8))
      .addGrid(booster.eta, Array(0.2, 0.6))
      .build()
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)
    val cvModel = cv.fit(training)
    // Stage 2 (0-based) of the best pipeline is the trained XGBoost model
    val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel].stages(2)
      .asInstanceOf[XGBoostClassificationModel]
    println("The params of best XGBoostClassification model : " +
      bestModel.extractParamMap())
    println("The training summary of best XGBoostClassificationModel : " +
      bestModel.summary)

    // Export the XGBoostClassificationModel as local XGBoost model,
    // then you can load it back in local Python environment.
    bestModel.nativeBooster.saveModel(nativeModelPath)

    // ML pipeline persistence
    model.write.overwrite().save(pipelineModelPath)

    // Load a saved model and serving
    val model2 = PipelineModel.load(pipelineModelPath)
    model2.transform(test).show(false)

    spark.stop()
  }
}
```
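Stages 2 and 4 of the pipeline above are inverse mappings. `StringIndexer` turns each class name into a double index, and `IndexToString` turns a numeric prediction back into a class name using the same ordered list of labels. The following plain-Scala sketch imitates that round trip without any Spark dependency. Ordering labels by descending frequency mirrors `StringIndexer`'s default `stringOrderType` (`frequencyDesc`). Breaking ties alphabetically is a choice made here for determinism. `indexLabels`, `toIndex` and `toLabel` are hypothetical helpers, not Spark APIs.

```scala
object LabelIndexingSketch {
  // Hypothetical helper: distinct labels ordered by descending frequency,
  // ties broken alphabetically (a determinism choice made for this sketch)
  def indexLabels(labels: Seq[String]): Array[String] =
    labels.groupBy(identity)
      .map { case (label, occurrences) => (label, occurrences.size) }
      .toSeq
      .sortBy { case (label, count) => (-count, label) }
      .map(_._1)
      .toArray

  // Hypothetical helper for stage 2: class name -> double index.
  // Unseen labels are rejected, similar to StringIndexer's default handleInvalid = "error"
  def toIndex(ordered: Array[String], label: String): Double = {
    val i = ordered.indexOf(label)
    require(i >= 0, s"unseen label: $label")
    i.toDouble
  }

  // Hypothetical helper for stage 4: double prediction -> class name
  def toLabel(ordered: Array[String], prediction: Double): String =
    ordered(prediction.toInt)

  def main(args: Array[String]): Unit = {
    // Toy labels using the Iris class names (illustrative only)
    val classes = Seq("Iris-setosa", "Iris-versicolor", "Iris-setosa",
      "Iris-virginica", "Iris-versicolor", "Iris-setosa")
    val ordered = indexLabels(classes)
    val indices = classes.map(c => toIndex(ordered, c))
    val restored = indices.map(i => toLabel(ordered, i))
    println(s"labels in index order: ${ordered.mkString(", ")}")
    println(s"round trip preserved: ${restored == classes}")
  }
}
```

Because the label order is learned from the data it is fitted on, the indexer and the converter must share the same `labels` array. Otherwise, the restored class names would not match the original ones.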
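Stage 3 uses `"objective" -> "multi:softprob"`. Under this objective, XGBoost outputs one probability per class for every row (a softmax over the per-class margin scores). The `prediction` column then holds the index of the most probable class. The `accuracy` metric of `MulticlassClassificationEvaluator` (selected with `setMetricName("accuracy")`) is the fraction of rows whose prediction equals the label. The plain-Scala sketch below reproduces that arithmetic without Spark or XGBoost. `softmax`, `argmax` and `accuracy` are hypothetical helpers, and the margins are made-up toy values.

```scala
object SoftprobSketch {
  // Softmax over one row's raw per-class margins; subtracting the max avoids overflow in exp()
  def softmax(margins: Array[Double]): Array[Double] = {
    val maxMargin = margins.reduce((a, b) => math.max(a, b))
    val exps = margins.map(m => math.exp(m - maxMargin))
    val total = exps.sum
    exps.map(_ / total)
  }

  // Index of the largest probability (first one wins on ties): the predicted class index
  def argmax(probs: Array[Double]): Int =
    probs.indices.foldLeft(0)((best, i) => if (probs(i) > probs(best)) i else best)

  // Fraction of rows whose predicted index equals the true class index
  def accuracy(predicted: Seq[Int], actual: Seq[Int]): Double = {
    require(predicted.nonEmpty && predicted.size == actual.size,
      "inputs must be non-empty and equally sized")
    predicted.zip(actual).count { case (p, a) => p == a }.toDouble / predicted.size
  }

  def main(args: Array[String]): Unit = {
    // Toy margins for three rows and three classes (illustrative values, not model output)
    val margins = Seq(
      Array(2.0, 0.5, -1.0),
      Array(-0.3, 1.2, 0.9),
      Array(0.1, 0.4, 1.5))
    val trueIndex = Seq(0, 2, 2)
    val probabilities = margins.map(softmax)
    val predicted = probabilities.map(argmax)
    probabilities.foreach(p => println(p.map(x => f"$x%.3f").mkString("[", ", ", "]")))
    println(s"predicted = ${predicted.mkString(", ")}, accuracy = ${accuracy(predicted, trueIndex)}")
  }
}
```

Since softmax is monotonic, taking the argmax of the probabilities picks the same class as taking the argmax of the raw margins. The probabilities matter only when a calibrated score is needed.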
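The tuning step relies on `ParamGridBuilder` producing the Cartesian product of the listed values. Two `maxDepth` values times two `eta` values give four parameter maps. `CrossValidator` fits the pipeline once per parameter map per fold, then refits the best map on the full training set. With 3 folds (Spark's default `numFolds`), that is 4 × 3 + 1 = 13 pipeline fits, each of which trains an XGBoost model. The plain-Scala sketch below enumerates the grid and counts the fits. `grid` and `fitCount` are hypothetical helpers, not Spark APIs.

```scala
object ParamGridSketch {
  // Cartesian product of named value lists, conceptually what ParamGridBuilder.build() yields
  def grid(axes: Seq[(String, Seq[Any])]): Seq[Map[String, Any]] =
    axes.foldLeft(Seq(Map.empty[String, Any])) { case (partial, (name, values)) =>
      for (m <- partial; v <- values) yield m.updated(name, v)
    }

  // One fit per (parameter map, fold) pair, plus one refit of the best map on all training data
  def fitCount(numMaps: Int, numFolds: Int): Int = numMaps * numFolds + 1

  def main(args: Array[String]): Unit = {
    val maps = grid(Seq("maxDepth" -> Seq(3, 8), "eta" -> Seq(0.2, 0.6)))
    maps.foreach(println)
    println(s"fits with 3 folds: ${fitCount(maps.size, 3)}")
  }
}
```

The count grows multiplicatively with every `addGrid` axis. Adding a third axis with two values, for example, would double the number of XGBoost trainings.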
The Maven `pom.xml` used to build the project:

```xml
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <description>This is a boilerplate maven project to start using Spark in Scala</description>
  <!-- Put the Scala version of the cluster -->
  <!-- repository to add org.apache.spark -->
  <id>GitHub Repo</id>
  <name>GitHub Repo</name>
  <!-- see http://davidb.github.com/scala-maven-plugin -->
  <!-- https://mvnrepository.com/artifact/net.alchim31.maven/scala-maven-plugin -->
  <!-- <version>3.1.3</version> -->
  <!-- <version>3.4.6</version> -->
  <!-- If you have classpath issue like NoClassDefFoundError,... -->
  <!-- useManifestOnlyJar>false</useManifestOnlyJar -->
  <!-- "package" command plugin -->
  <!-- Scala and Spark dependencies -->
  <!-- https://mvnrepository.com/artifact/org.scala-lang/scala-library -->
  <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
  <!-- https://mvnrepository.com/artifact/com.github.scopt/scopt_2.11 -->
</project>
```
Build the project with Maven:

```shell
mvn clean package
```
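Then submit the job with `spark-submit` (Spark 2.4.3 here). The three trailing arguments are the input data path, the native model output path and the pipeline model output path:

```shell
spark-2.4.3-bin-hadoop2.7/bin/spark-submit \
  --class ml.dmlc.xgboost4j.scala.example.spark.SparkMLlibPipeline \
  --jars /***/scala_workSpace/test/xgboost4j-example_2.11-1.0.0-jar-with-dependencies.jar \
  /***/scala_workSpace/test/xgboost4j-example_2.11-1.0.0.jar \
  /tmp/rd/lp/iris.data \
  /***/scala_workSpace/test/nativeModel \
  /tmp/rd/lp/pipelineModel
```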