Scala Spark ML study notes: distributed training with XGBoost

2019-08-31 19:25:47

Copyright notice: this is an original article by the blogger, licensed under CC 4.0 BY-SA. Please include the original source link and this notice when reposting.

Original link: https://blog.csdn.net/u014365862/article/details/100146395

There are generally two ways to build a jar for a Java/Scala project: sbt and Maven. This post uses Maven; the complete project is also available on GitHub: https://github.com/MachineLP/Spark-/tree/master/scala-xgboost.

The XGBoost SparkMLlibPipeline.scala code is shown below. (Note that the source file must follow the package directory structure: src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkMLlibPipeline.scala.)
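
For reference, the Maven project is expected to look roughly like this (a sketch; the root directory name scala-xgboost is taken from the GitHub repo, and pom.xml sits at the project root as Maven expects):

scala-xgboost/
├── pom.xml
└── src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/
    └── SparkMLlibPipeline.scala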

package ml.dmlc.xgboost4j.scala.example.spark

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.tuning._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassifier, XGBoostClassificationModel}

// this example works with Iris dataset (https://archive.ics.uci.edu/ml/datasets/iris)

object SparkMLlibPipeline {

  def main(args: Array[String]): Unit = {

    if (args.length != 3) {
      println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
      sys.exit(1)
    }

    val inputPath = args(0)
    val nativeModelPath = args(1)
    val pipelineModelPath = args(2)

    val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

    // Load dataset
    val schema = new StructType(Array(
      StructField("sepal length", DoubleType, true),
      StructField("sepal width", DoubleType, true),
      StructField("petal length", DoubleType, true),
      StructField("petal width", DoubleType, true),
      StructField("class", StringType, true)))

    val rawInput = spark.read.schema(schema).csv(inputPath)

    // Split training and test dataset
    val Array(training, test) = rawInput.randomSplit(Array(0.8, 0.2), 123)

    // Build ML pipeline, it includes 4 stages:
    // 1, Assemble all features into a single vector column.
    // 2, From string label to indexed double label.
    // 3, Use XGBoostClassifier to train classification model.
    // 4, Convert indexed double label back to original string label.
    val assembler = new VectorAssembler()
      .setInputCols(Array("sepal length", "sepal width", "petal length", "petal width"))
      .setOutputCol("features")
    val labelIndexer = new StringIndexer()
      .setInputCol("class")
      .setOutputCol("classIndex")
      .fit(training)
    val booster = new XGBoostClassifier(
      Map("eta" -> 0.1f,
        "max_depth" -> 2,
        "objective" -> "multi:softprob",
        "num_class" -> 3,
        "num_round" -> 100,
        "num_workers" -> 2
      )
    )
    booster.setFeaturesCol("features")
    booster.setLabelCol("classIndex")
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("realLabel")
      .setLabels(labelIndexer.labels)

    val pipeline = new Pipeline()
      .setStages(Array(assembler, labelIndexer, booster, labelConverter))
    val model = pipeline.fit(training)

    // Batch prediction
    val prediction = model.transform(test)
    prediction.show(false)

    // Model evaluation
    val evaluator = new MulticlassClassificationEvaluator()
    evaluator.setLabelCol("classIndex")
    evaluator.setPredictionCol("prediction")
    val accuracy = evaluator.evaluate(prediction)
    println("The model accuracy is : "   accuracy)

    // Tune model using cross validation
    val paramGrid = new ParamGridBuilder()
      .addGrid(booster.maxDepth, Array(3, 8))
      .addGrid(booster.eta, Array(0.2, 0.6))
      .build()
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)

    val cvModel = cv.fit(training)

    val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel].stages(2)
      .asInstanceOf[XGBoostClassificationModel]
    println("The params of best XGBoostClassification model : "  
      bestModel.extractParamMap())
    println("The training summary of best XGBoostClassificationModel : "  
      bestModel.summary)

    // Export the XGBoostClassificationModel as local XGBoost model,
    // then you can load it back in local Python environment.
    bestModel.nativeBooster.saveModel(nativeModelPath)

    // ML pipeline persistence
    model.write.overwrite().save(pipelineModelPath)

    // Load a saved model and serving
    val model2 = PipelineModel.load(pipelineModelPath)
    model2.transform(test).show(false)
  }
}
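
The exported native model (nativeModelPath) can also be loaded back outside of Spark with the single-machine xgboost4j API (the plain xgboost4j artifact is already listed in the pom below). The snippet that follows is a minimal sketch assuming xgboost4j 0.90; the LocalScoringSketch object name and the feature values are purely illustrative.

package ml.dmlc.xgboost4j.scala.example.spark

import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}

// Hypothetical helper: score one iris sample with the exported native model.
object LocalScoringSketch {
  def main(args: Array[String]): Unit = {
    // args(0) should point at the file written via nativeBooster.saveModel above.
    val booster = XGBoost.loadModel(args(0))
    // One row with the four iris features (illustrative values only).
    val dmat = new DMatrix(Array(5.1f, 3.5f, 1.4f, 0.2f), 1, 4)
    // With objective multi:softprob, predict returns one probability per class per row.
    val probs = booster.predict(dmat)
    println(probs.head.mkString(", "))
  }
}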

The pom.xml file is as follows (make sure the dependencies are correct):

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>ml.dmlc</groupId>
  <artifactId>xgboost4j-example_2.11</artifactId>
  <version>1.0.0</version>
  <packaging>jar</packaging>
  <name>${project.artifactId}</name>
  <description>This is a boilerplate maven project to start using Spark in Scala</description>
  <inceptionYear>2010</inceptionYear>

  <properties>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
    <encoding>UTF-8</encoding>
    <scala.tools.version>2.11</scala.tools.version>
    <!-- Put the Scala version of the cluster --> 
    <scala.version>2.11.12</scala.version> 
    <scala.binary.version>2.11</scala.binary.version> 
    <spark.version>2.4.3</spark.version> 
  </properties>
  
  <!-- repository to add org.apache.spark -->
  <repositories>
    <repository>
      <id>cloudera-repo-releases</id>
      <url>https://repository.cloudera.com/artifactory/repo/</url>
    </repository>
    <repository>
      <id>GitHub Repo</id>
      <name>GitHub Repo</name>
      <url>https://raw.githubusercontent.com/CodingCat/xgboost/maven-repo/</url>
    </repository>
  </repositories>

  <build>
    <sourceDirectory>src/main/scala</sourceDirectory>
    <testSourceDirectory>src/test/scala</testSourceDirectory>
    <plugins>
      <plugin>
        <!-- see http://davidb.github.com/scala-maven-plugin -->
        <!-- https://mvnrepository.com/artifact/net.alchim31.maven/scala-maven-plugin -->
        <groupId>net.alchim31.maven</groupId>
        <artifactId>scala-maven-plugin</artifactId>
        <!-- <version>3.1.3</version> -->
        <version>4.0.2</version>
        <!-- <version>3.4.6</version> -->
        <executions>
          <execution>
            <goals>
              <goal>compile</goal>
              <goal>testCompile</goal>
            </goals>
            <configuration>
              <args>
                <arg>-dependencyfile</arg>
                <arg>${project.build.directory}/.scala_dependencies</arg>
              </args>
            </configuration>
          </execution>
        </executions>
      </plugin>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-surefire-plugin</artifactId>
        <version>2.13</version>
        <configuration>
          <useFile>false</useFile>
          <disableXmlReport>true</disableXmlReport>
          <!-- If you have classpath issue like NoDefClassError,... -->
          <!-- useManifestOnlyJar>false</useManifestOnlyJar -->
          <includes>
            <include>**/*Test.*</include>
            <include>**/*Suite.*</include>
          </includes>
        </configuration>
      </plugin>

      <!-- "package" command plugin -->
      <plugin>
        <artifactId>maven-assembly-plugin</artifactId>
        <version>2.4.1</version>
        <configuration>
          <descriptorRefs>
            <descriptorRef>jar-with-dependencies</descriptorRef>
          </descriptorRefs>
        </configuration>
        <executions>
          <execution>
            <id>make-assembly</id>
            <phase>package</phase>
            <goals>
              <goal>single</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>

  <dependencies>
    <dependency>
        <groupId>ml.dmlc</groupId>
        <artifactId>xgboost4j-spark</artifactId> 
        <version>0.90</version> 
    </dependency>
    <dependency>
        <groupId>ml.dmlc</groupId>
        <artifactId>xgboost4j</artifactId> 
        <version>0.90</version> 
    </dependency>
    <!-- Scala and Spark dependencies -->
    <!-- https://mvnrepository.com/artifact/org.scala-lang/scala-library -->
    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>${scala.version}</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core -->
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_${scala.binary.version}</artifactId>
      <version>${spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-lang3</artifactId>
      <version>3.4</version>
    </dependency>
    <dependency>
      <groupId>org.apache.velocity</groupId>
      <artifactId>velocity</artifactId>
      <version>1.7</version>
    </dependency>
    <dependency>
      <groupId>commons-logging</groupId>
      <artifactId>commons-logging</artifactId>
      <version>1.2</version>
    </dependency>
    <!-- https://mvnrepository.com/artifact/com.github.scopt/scopt_2.11 -->
    <dependency>
      <groupId>com.github.scopt</groupId>
      <artifactId>scopt_2.11</artifactId>
      <version>3.5.0</version>
    </dependency>
  </dependencies>
</project>
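
Note on versions: the Scala version (2.11.x) and Spark version (2.4.3) declared in the properties should match your cluster, and xgboost4j/xgboost4j-spark 0.90 is built against Scala 2.11 and Spark 2.4.x, so the three have to be kept in step with each other.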

Then build the jar:

mvn clean package
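
After a successful build, Maven writes both xgboost4j-example_2.11-1.0.0.jar and the assembly jar xgboost4j-example_2.11-1.0.0-jar-with-dependencies.jar under target/; both are referenced in the spark-submit command below.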

Finally, submit the job to the cluster:

spark-2.4.3-bin-hadoop2.7/bin/spark-submit  --class ml.dmlc.xgboost4j.scala.example.spark.SparkMLlibPipeline --jars /***/scala_workSpace/test/xgboost4j-example_2.11-1.0.0-jar-with-dependencies.jar /***/scala_workSpace/test/xgboost4j-example_2.11-1.0.0.jar /tmp/rd/lp/iris.data /***/scala_workSpace/test/nativeModel /tmp/rd/lp/pipelineModel
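
The three trailing arguments correspond to the positional parameters expected by SparkMLlibPipeline: input_path (the iris.data file), native_model_path, and pipeline_model_path.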
