Spark ML Regularization, Standardization, Normalization ---- Regularization in Spark

2021-12-06 15:45:58

Outline

  • Regularization in Spark
    • Normalizer
    • Source code
  • References

Regularization in Spark

Normalizer

Normalizer documentation:

  • http://spark.apache.org/docs/latest/api/scala/org/apache/spark/ml/feature/Normalizer.html

Normalizer source code:

  • https://github.com/apache/spark/blob/v3.1.2/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala

The documentation gives just this one sentence:

Normalize a vector to have unit norm using the given p-norm.

That is, each vector is divided by its p-norm, so the result has norm 1.
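To make that one sentence concrete: the p-norm of a vector x is (Σ|xᵢ|^p)^(1/p), and normalizing means dividing every component by that value. A minimal pure-Scala sketch (illustrative only, not Spark's implementation):

```scala
// Sketch of unit-norm normalization on a dense vector (Array[Double]).
def pNorm(v: Array[Double], p: Double): Double =
  math.pow(v.map(x => math.pow(math.abs(x), p)).sum, 1.0 / p)

def normalize(v: Array[Double], p: Double): Array[Double] = {
  val norm = pNorm(v, p)
  // A zero vector has norm 0 and is passed through unchanged.
  if (norm == 0.0) v else v.map(_ / norm)
}

// With p = 2, the vector (3, 4) has norm 5, so the result is (0.6, 0.8).
val unit = normalize(Array(3.0, 4.0), 2.0)
```

After normalization the result always has p-norm 1, which is why this is useful for algorithms that compare vectors by direction (e.g. cosine similarity) rather than magnitude.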

Source code

Code language: Scala
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.sql.types._

/**
 * Normalize a vector to have unit norm using the given p-norm.
 */
@Since("1.4.0")
class Normalizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable {

  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("normalizer"))

  /**
   * Normalization in L^p^ space. Must be greater than equal to 1.
   * (default: p = 2)
   * @group param
   */
  @Since("1.4.0")
  val p = new DoubleParam(this, "p", "the p norm value", ParamValidators.gtEq(1))

  setDefault(p -> 2.0)

  /** @group getParam */
  @Since("1.4.0")
  def getP: Double = $(p)

  /** @group setParam */
  @Since("1.4.0")
  def setP(value: Double): this.type = set(p, value)

  override protected def createTransformFunc: Vector => Vector = {
    val normalizer = new feature.Normalizer($(p))
    vector => normalizer.transform(OldVectors.fromML(vector)).asML
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType.isInstanceOf[VectorUDT],
      s"Input type must be ${(new VectorUDT).catalogString} but got ${inputType.catalogString}.")
  }

  override protected def outputDataType: DataType = new VectorUDT()

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    var outputSchema = super.transformSchema(schema)
    if ($(inputCol).nonEmpty && $(outputCol).nonEmpty) {
      val size = AttributeGroup.fromStructField(schema($(inputCol))).size
      if (size >= 0) {
        outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
          $(outputCol), size)
      }
    }
    outputSchema
  }

  @Since("3.0.0")
  override def toString: String = {
    s"Normalizer: uid=$uid, p=${$(p)}"
  }
}

@Since("1.6.0")
object Normalizer extends DefaultParamsReadable[Normalizer] {

  @Since("1.6.0")
  override def load(path: String): Normalizer = super.load(path)
}
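The core logic is in createTransformFunc, which delegates to the older mllib feature.Normalizer. A hedged pure-Scala sketch of that per-vector function (the real implementation also handles sparse vectors and has optimized paths; this version only assumes dense vectors):

```scala
// Illustrative stand-in for the Vector => Vector function returned by
// createTransformFunc; not Spark's actual code.
def makeNormalizer(p: Double): Array[Double] => Array[Double] = {
  // Mirrors the ParamValidators.gtEq(1) constraint on the p param.
  require(p >= 1.0, "p must be >= 1")
  vector => {
    val norm =
      if (p.isPosInfinity) vector.map(math.abs).max // max (infinity) norm
      else math.pow(vector.map(x => math.pow(math.abs(x), p)).sum, 1.0 / p)
    if (norm == 0.0) vector else vector.map(_ / norm)
  }
}

// p = 1: absolute values sum to 1 after normalization.
val l1 = makeNormalizer(1.0)(Array(1.0, 3.0)) // (0.25, 0.75)
// p = Infinity (allowed, since Infinity >= 1): the largest |component| becomes 1.
val linf = makeNormalizer(Double.PositiveInfinity)(Array(2.0, -4.0)) // (0.5, -1.0)
```

In an actual Spark job you would get the same effect by configuring the transformer, e.g. `new Normalizer().setInputCol("features").setOutputCol("normFeatures").setP(1.0)`, and calling `transform` on a DataFrame.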

References

Articles in this series:

  • Basic concepts of regularization, standardization, and normalization
  • Regularization in Spark
  • Standardization in Spark
  • Normalization in Spark
  • Extending Spark's normalization functions

Spark documentation on feature processing:

  • http://spark.apache.org/docs/latest/api/scala/org/apache/spark/ml/feature/index.html

Concept overview:

  • https://blog.csdn.net/u014381464/article/details/81101551

Other references:

  • https://segmentfault.com/a/1190000014042959
  • https://www.cnblogs.com/nucdy/p/7994542.html
  • https://blog.csdn.net/weixin_34117522/article/details/88875270
  • https://blog.csdn.net/xuejianbest/article/details/85779029
