There are generally two ways to build a jar for a Java/Scala project: sbt and Maven. This post walks through the Maven approach; the complete project is also available on GitHub: https://github.com/MachineLP/Spark-/tree/master/scala-xgboost.
The xgboost SparkMLlibPipeline.scala example is shown below. Note that the source file must sit in the directory matching its package declaration: src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala.
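For orientation, the project layout implied by the package declaration and the pom.xml's source directory settings looks roughly like this (the top-level directory name is illustrative):

scala-xgboost/
├── pom.xml
└── src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/
    └── SparkMLlibPipeline.scala

SparkMLlibPipeline.scala in full: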
package ml.dmlc.xgboost4j.scala.example.spark

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.tuning._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}

// this example works with the Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)
object SparkMLlibPipeline {

  def main(args: Array[String]): Unit = {
    if (args.length != 3) {
      println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
      sys.exit(1)
    }

    val inputPath = args(0)
    val nativeModelPath = args(1)
    val pipelineModelPath = args(2)

    val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

    // Load dataset
    val schema = new StructType(Array(
      StructField("sepal length", DoubleType, true),
      StructField("sepal width", DoubleType, true),
      StructField("petal length", DoubleType, true),
      StructField("petal width", DoubleType, true),
      StructField("class", StringType, true)))
    val rawInput = spark.read.schema(schema).csv(inputPath)

    // Split training and test dataset
    val Array(training, test) = rawInput.randomSplit(Array(0.8, 0.2), 123)

    // Build ML pipeline; it includes 4 stages:
    // 1. Assemble all features into a single vector column.
    // 2. Index the string label into a double label.
    // 3. Use XGBoostClassifier to train a classification model.
    // 4. Convert the indexed double label back to the original string label.
    val assembler = new VectorAssembler()
      .setInputCols(Array("sepal length", "sepal width", "petal length", "petal width"))
      .setOutputCol("features")
    val labelIndexer = new StringIndexer()
      .setInputCol("class")
      .setOutputCol("classIndex")
      .fit(training)
    val booster = new XGBoostClassifier(
      Map("eta" -> 0.1f,
        "max_depth" -> 2,
        "objective" -> "multi:softprob",
        "num_class" -> 3,
        "num_round" -> 100,
        "num_workers" -> 2
      )
    )
    booster.setFeaturesCol("features")
    booster.setLabelCol("classIndex")
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("realLabel")
      .setLabels(labelIndexer.labels)
    val pipeline = new Pipeline()
      .setStages(Array(assembler, labelIndexer, booster, labelConverter))
    val model = pipeline.fit(training)

    // Batch prediction
    val prediction = model.transform(test)
    prediction.show(false)

    // Model evaluation (the evaluator's default metric is f1, so request accuracy explicitly)
    val evaluator = new MulticlassClassificationEvaluator()
    evaluator.setLabelCol("classIndex")
    evaluator.setPredictionCol("prediction")
    evaluator.setMetricName("accuracy")
    val accuracy = evaluator.evaluate(prediction)
    println("The model accuracy is : " + accuracy)

    // Tune model using cross validation
    val paramGrid = new ParamGridBuilder()
      .addGrid(booster.maxDepth, Array(3, 8))
      .addGrid(booster.eta, Array(0.2, 0.6))
      .build()
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)
    val cvModel = cv.fit(training)

    val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel].stages(2)
      .asInstanceOf[XGBoostClassificationModel]
    println("The params of best XGBoostClassification model : " +
      bestModel.extractParamMap())
    println("The training summary of best XGBoostClassificationModel : " +
      bestModel.summary)

    // Export the XGBoostClassificationModel as a local XGBoost model,
    // then you can load it back in a local Python environment.
    bestModel.nativeBooster.saveModel(nativeModelPath)

    // ML pipeline persistence
    model.write.overwrite().save(pipelineModelPath)

    // Load the saved pipeline model back and use it for serving
    val model2 = PipelineModel.load(pipelineModelPath)
    model2.transform(test).show(false)
  }
}
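The comment above notes that the exported native model can be loaded back in a local Python environment; the same file can also be read on a single machine with the xgboost4j Scala API, without Spark. A minimal sketch (the model path is a placeholder, and the DMatrix constructor shown assumes xgboost4j 0.90):

import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}

object LoadNativeModel {
  def main(args: Array[String]): Unit = {
    // load the booster written by bestModel.nativeBooster.saveModel(...)
    val booster = XGBoost.loadModel("/path/to/nativeModel") // placeholder path
    // one Iris row: sepal length, sepal width, petal length, petal width
    val row = Array(5.1f, 3.5f, 1.4f, 0.2f)
    val dmat = new DMatrix(row, 1, 4) // dense matrix: 1 row, 4 columns
    // with objective multi:softprob, predict returns one probability per class
    val probs = booster.predict(dmat)
    println(probs(0).mkString(", "))
  }
}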
The pom.xml is as follows. Make sure the dependencies are correct, and in particular that the Scala and Spark versions match your cluster:
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>ml.dmlc</groupId>
  <artifactId>xgboost4j-example_2.11</artifactId>
  <version>1.0.0</version>
  <packaging>jar</packaging>
  <name>${project.artifactId}</name>
  <description>This is a boilerplate maven project to start using Spark in Scala</description>
  <inceptionYear>2010</inceptionYear>

  <properties>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
    <encoding>UTF-8</encoding>
    <scala.tools.version>2.11</scala.tools.version>
    <!-- Use the Scala version of the cluster -->
    <scala.version>2.11.12</scala.version>
    <scala.binary.version>2.11</scala.binary.version>
    <spark.version>2.4.3</spark.version>
  </properties>

  <!-- repositories for org.apache.spark and xgboost4j -->
  <repositories>
    <repository>
      <id>cloudera-repo-releases</id>
      <url>https://repository.cloudera.com/artifactory/repo/</url>
    </repository>
    <repository>
      <id>GitHub Repo</id>
      <name>GitHub Repo</name>
      <url>https://raw.githubusercontent.com/CodingCat/xgboost/maven-repo/</url>
    </repository>
  </repositories>

  <build>
    <sourceDirectory>src/main/scala</sourceDirectory>
    <testSourceDirectory>src/test/scala</testSourceDirectory>
    <plugins>
      <plugin>
        <!-- see http://davidb.github.com/scala-maven-plugin -->
        <!-- https://mvnrepository.com/artifact/net.alchim31.maven/scala-maven-plugin -->
        <groupId>net.alchim31.maven</groupId>
        <artifactId>scala-maven-plugin</artifactId>
        <version>4.0.2</version>
        <executions>
          <execution>
            <goals>
              <goal>compile</goal>
              <goal>testCompile</goal>
            </goals>
            <configuration>
              <args>
                <arg>-dependencyfile</arg>
                <arg>${project.build.directory}/.scala_dependencies</arg>
              </args>
            </configuration>
          </execution>
        </executions>
      </plugin>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-surefire-plugin</artifactId>
        <version>2.13</version>
        <configuration>
          <useFile>false</useFile>
          <disableXmlReport>true</disableXmlReport>
          <!-- If you have classpath issues like NoClassDefFoundError, ... -->
          <!-- useManifestOnlyJar>false</useManifestOnlyJar -->
          <includes>
            <include>**/*Test.*</include>
            <include>**/*Suite.*</include>
          </includes>
        </configuration>
      </plugin>
      <!-- "package" command plugin -->
      <plugin>
        <artifactId>maven-assembly-plugin</artifactId>
        <version>2.4.1</version>
        <configuration>
          <descriptorRefs>
            <descriptorRef>jar-with-dependencies</descriptorRef>
          </descriptorRefs>
        </configuration>
        <executions>
          <execution>
            <id>make-assembly</id>
            <phase>package</phase>
            <goals>
              <goal>single</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>

  <dependencies>
    <dependency>
      <groupId>ml.dmlc</groupId>
      <artifactId>xgboost4j-spark</artifactId>
      <version>0.90</version>
    </dependency>
    <dependency>
      <groupId>ml.dmlc</groupId>
      <artifactId>xgboost4j</artifactId>
      <version>0.90</version>
    </dependency>
    <!-- Scala and Spark dependencies -->
    <!-- https://mvnrepository.com/artifact/org.scala-lang/scala-library -->
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>${scala.version}</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-lang3</artifactId>
      <version>3.4</version>
    </dependency>
    <dependency>
      <groupId>org.apache.velocity</groupId>
      <artifactId>velocity</artifactId>
      <version>1.7</version>
    </dependency>
    <dependency>
      <groupId>commons-logging</groupId>
      <artifactId>commons-logging</artifactId>
      <version>1.2</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/com.github.scopt/scopt_2.11 -->
    <dependency>
      <groupId>com.github.scopt</groupId>
      <artifactId>scopt_2.11</artifactId>
      <version>3.5.0</version>
    </dependency>
  </dependencies>
</project>
Then build the jar:
mvn clean package
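If the build succeeds, target/ should contain two jars: the thin application jar and the assembly jar produced by the maven-assembly-plugin. The names below follow the artifactId and version declared in the pom.xml above:

target/xgboost4j-example_2.11-1.0.0.jar
target/xgboost4j-example_2.11-1.0.0-jar-with-dependencies.jar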
Finally, submit the job to the cluster. The assembly jar is passed via --jars so the executors get the xgboost4j dependencies, the thin jar is the application jar, and the three trailing arguments are the input_path, native_model_path, and pipeline_model_path expected by main:
spark-2.4.3-bin-hadoop2.7/bin/spark-submit --class ml.dmlc.xgboost4j.scala.example.spark.SparkMLlibPipeline --jars /***/scala_workSpace/test/xgboost4j-example_2.11-1.0.0-jar-with-dependencies.jar /***/scala_workSpace/test/xgboost4j-example_2.11-1.0.0.jar /tmp/rd/lp/iris.data /***/scala_workSpace/test/nativeModel /tmp/rd/lp/pipelineModel