Spark SQL 快速入门系列(6) | 一文教你如何自定义 SparkSQL 函数

2020-10-28 17:42:29 浏览数 (1)

一. 自定义 UDF 函数

  在Shell窗口中可以通过spark.udf功能用户可以自定义函数。

代码语言:javascript复制
scala> val df = spark.read.json("examples/src/main/resources/people.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]

scala> df.show
 ---- ------- 
| age|   name|
 ---- ------- 
|null|Michael|
|  30|   Andy|
|  19| Justin|
 ---- ------- 
// 注册一个 udf 函数: toUpper是函数名, 第二个参数是函数的具体实现
scala> spark.udf.register("toUpper", (s: String) => s.toUpperCase)
res1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))

scala> df.createOrReplaceTempView("people")

scala> spark.sql("select toUpper(name), age from people").show
 ----------------- ---- 
|UDF:toUpper(name)| age|
 ----------------- ---- 
|          MICHAEL|null|
|             ANDY|  30|
|           JUSTIN|  19|
 ----------------- ---- 

二. 用户自定义聚合函数

强类型的Dataset弱类型的DataFrame都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数

2.1 弱类型UDF(求和)

  • 1.源码
代码语言:javascript复制
package com.buwenbuhuo.spark.sql.day01.udf

import com.buwenbuhuo.spark.sql.day01.Pelple
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

import scala.collection.immutable.Nil

/**
 **
 *
 * @author 不温卜火
 *         *
 * @create 2020-08-03 12:13
 **
 *         MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
object UDAFDemo {
  def main(args: Array[String]): Unit = {
    // 在sql中,聚合函数如何使用
    val spark: SparkSession = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._
    val df: DataFrame = spark.read.json("d:/users.json")
    df.createOrReplaceTempView("user")
    // 注册聚合函数
    spark.udf.register("mySum",new MySum)
    spark.sql("select mySum(age) from user").show

    spark.close()
  }
}

class MySum extends UserDefinedAggregateFunction {

  // 用来定义输入的数据类型  10.1 12.2 100
  override def inputSchema: StructType = StructType(StructField("ele",DoubleType)::Nil)

  // 缓冲区的类型
  override def bufferSchema: StructType = StructType(StructField("sum",DoubleType)::Nil)

  // 最终聚合结果的类型
  override def dataType: DataType = DoubleType

  // 相同的输入是否返回相同的输出
  override def deterministic: Boolean = true

  // 对缓冲区初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // 在缓冲区集合中初始化和
    buffer(0) = 0D  // 等价于buffer.update(0,0D)

  }

  // 分区内聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // input是指的使用聚合函数的时候,缓过来的参数封装到了Row
    if(!input.isNullAt(0)){
      // 考虑到传字段可能是null
      val v: Double = input.getAs[Double](0)  // getDouble(0)
      buffer(0) = buffer.getDouble(0)   v
    }
  }

  // 分区间的聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // 把buffer1和buffer2 的缓冲聚合到一起,然后再把值写回到buffer1
    buffer1(0) = buffer1.getDouble(0)   buffer2.getDouble(0)
  }

  // 返回最初的输出值
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
  • 2. 运行结果

2.2 弱类型UDF(求均值)

  • 1. 源码
代码语言:javascript复制
package com.buwenbuhuo.spark.sql.day01.udf

import com.buwenbuhuo.spark.sql.day01.Pelple
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

import scala.collection.immutable.Nil

/**
 **
 *
 * @author 不温卜火
 *         *
 * @create 2020-08-03 12:13
 **
 *         MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
object UDAFDemo1 {
  def main(args: Array[String]): Unit = {
    // 在sql中,聚合函数如何使用
    val spark: SparkSession = SparkSession.builder()
      .appName("UDAFDemo1")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._
    val df: DataFrame = spark.read.json("d:/users.json")
    df.createOrReplaceTempView("user")
    // 注册聚合函数
    spark.udf.register("myAvg",new MyAvg)
    spark.sql("select myAvg(age) from user").show

    spark.close()
  }
}

class MyAvg extends UserDefinedAggregateFunction {

  // 用来定义输入的数据类型  10.1 12.2 100
  override def inputSchema: StructType = StructType(StructField("ele",DoubleType)::Nil)

  // 缓冲区的类型
  override def bufferSchema: StructType =
    StructType(StructField("sum",DoubleType)::StructField("count",LongType)::Nil)

  // 最终聚合结果的类型
  override def dataType: DataType = DoubleType

  // 相同的输入是否返回相同的输出
  override def deterministic: Boolean = true

  // 对缓冲区初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // 在缓冲区集合中初始化和
    buffer(0) = 0D  // 等价于buffer.update(0,0D)
    buffer(1) = 0L

  }

  // 分区内聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // input是指的使用聚合函数的时候,缓过来的参数封装到了Row
    if(!input.isNullAt(0)){
      // 考虑到传字段可能是null
      val v: Double = input.getAs[Double](0)  // getDouble(0)
      buffer(0) = buffer.getDouble(0)   v
      buffer(1) = buffer.getLong(1)   1L
    }
  }

  // 分区间的聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // 把buffer1和buffer2 的缓冲聚合到一起,然后再把值写回到buffer1
    buffer1(0) = buffer1.getDouble(0)   buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1)   buffer2.getLong(1)
  }

  // 返回最初的输出值
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)/buffer.getLong(1)
}
  • 2. 运行结果

2.3 强类型UDF(求均值)

  • 1. 源码
代码语言:javascript复制
package com.buwenbuhuo.spark.sql.day01.udf


import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator


/**
 **
 *
 * @author 不温卜火
 *         *
 * @create 2020-08-03 12:43
 **
 *         MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
case class Dog(name:String,age:Int)

case class AgeAvg(sum:Int,count:Int){
  def avg = sum.toDouble / count
}

object UDAFDemo3 {
  def main(args: Array[String]): Unit = {
    // 在sql中,聚合函数如何使用
    val spark: SparkSession = SparkSession.builder()
      .appName("UDAFDemo3")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._
    val ds: Dataset[Dog] = List(Dog("大黄", 6), Dog("小黄", 2), Dog("中黄", 4)).toDS()
    // 强类型的使用方式
    val avg: TypedColumn[Dog, Double] = new MyAvg2().toColumn.name("avg")
    val result: Dataset[Double] = ds.select(avg)
    result.show()

    spark.close()

  }
}
class MyAvg2 extends Aggregator[Dog,AgeAvg,Double]{

  // 对缓冲区进行初始化
  override def zero: AgeAvg = AgeAvg(0,0)

  // 聚合(分区内聚合)
  override def reduce(b: AgeAvg, a: Dog): AgeAvg = a match {
    // 如果是dog对象,则把年龄相加,个数加1
    case Dog(name,age) => AgeAvg(b.sum   age , b.count   1)
      // 如果是null,则原封不动返回
    case _ => b
  }

  // 分区间的聚合
  override def merge(b1: AgeAvg, b2: AgeAvg): AgeAvg = {
    AgeAvg(b1.sum   b2.sum,b1.count   b2.count)
  }

  // 返回最终的值
  override def finish(reduction: AgeAvg): Double = reduction.avg

  // 对缓冲区进行编码
  override def bufferEncoder: Encoder[AgeAvg] = Encoders.product // 如果是样例,就直接返回这个编码器就行了

  //对返回值进行编码
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
/*
强类型UDF
 */
  • 2. 运行结果

  本次的分享就到这里了

0 人点赞