1. Custom UDF Functions
In the spark-shell, users can define and register their own functions through spark.udf.
scala> val df = spark.read.json("examples/src/main/resources/people.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]
scala> df.show
+----+-------+
| age|   name|
+----+-------+
|null|Michael|
|  30|   Andy|
|  19| Justin|
+----+-------+
// Register a UDF: toUpper is the function name, the second argument is its implementation
scala> spark.udf.register("toUpper", (s: String) => s.toUpperCase)
res1: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
scala> df.createOrReplaceTempView("people")
scala> spark.sql("select toUpper(name), age from people").show
+-----------------+----+
|UDF:toUpper(name)| age|
+-----------------+----+
|          MICHAEL|null|
|             ANDY|  30|
|           JUSTIN|  19|
+-----------------+----+
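Besides registering a UDF for SQL, the same logic can be applied directly to DataFrame columns through org.apache.spark.sql.functions.udf. The snippet below is a minimal sketch of that column-based style, assuming the same people DataFrame as above; the names udfToUpper and upper_name are only illustrations:

import org.apache.spark.sql.functions.{col, udf}

// Column-based usage of the same logic (handles null names defensively)
val udfToUpper = udf((s: String) => if (s == null) null else s.toUpperCase)
df.select(udfToUpper(col("name")).as("upper_name"), col("age")).show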
2. User-Defined Aggregate Functions (UDAF)
Both the strongly-typed Dataset and the weakly-typed DataFrame provide built-in aggregate functions such as count(), countDistinct(), avg(), max(), and min(). Beyond these, users can define their own aggregate functions; a quick sketch of the built-in aggregates follows, and custom ones are covered in the rest of this section.
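As a quick refresher before the custom versions, here is a minimal sketch of the built-in aggregates on the weakly-typed API, assuming the people view registered in section 1:

import org.apache.spark.sql.functions.{avg, count, max, min}

// Built-in aggregates on the DataFrame API
df.select(count("name"), avg("age"), max("age"), min("age")).show
// The same functions are available from SQL as well:
// spark.sql("select count(name), avg(age), max(age), min(age) from people").show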
2.1 Weakly-Typed UDAF (Sum)
- 1. Source code
package com.buwenbuhuo.spark.sql.day01.udf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
 * @author 不温卜火
 * @create 2020-08-03 12:13
 * MyCSDN : https://buwenbuhuo.blog.csdn.net/
 */
object UDAFDemo {
  def main(args: Array[String]): Unit = {
    // How to use the aggregate function from SQL
    val spark: SparkSession = SparkSession.builder()
      .appName("UDAFDemo")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._
    val df: DataFrame = spark.read.json("d:/users.json")
    df.createOrReplaceTempView("user")
    // Register the aggregate function
    spark.udf.register("mySum", new MySum)
    spark.sql("select mySum(age) from user").show
    spark.close()
  }
}
class MySum extends UserDefinedAggregateFunction {
  // The type of the input data, e.g. 10.1, 12.2, 100
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)
  // The type of the aggregation buffer
  override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)
  // The type of the final result
  override def dataType: DataType = DoubleType
  // Whether the same input always produces the same output
  override def deterministic: Boolean = true
  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Initialize the running sum in the buffer
    buffer(0) = 0D // equivalent to buffer.update(0, 0D)
  }
  // Aggregation within a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // input is a Row wrapping the arguments passed to the aggregate function
    if (!input.isNullAt(0)) {
      // The incoming field may be null
      val v: Double = input.getAs[Double](0) // getDouble(0)
      buffer(0) = buffer.getDouble(0) + v
    }
  }
  // Aggregation across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Combine buffer1 and buffer2, then write the result back into buffer1
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }
  // Return the final output value
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
- 2. Run result
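The example reads d:/users.json, whose contents are not shown here. To try the UDAF without that file, a hypothetical in-memory DataFrame can be substituted inside the same main method (the names and ages below are made-up sample data):

// Hypothetical test data standing in for d:/users.json
val testDf = Seq(("Michael", 29.0), ("Andy", 30.0), ("Justin", 19.0)).toDF("name", "age")
testDf.createOrReplaceTempView("user")
spark.udf.register("mySum", new MySum)
spark.sql("select mySum(age) from user").show // 29.0 + 30.0 + 19.0 = 78.0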
2.2 Weakly-Typed UDAF (Average)
- 1. Source code
package com.buwenbuhuo.spark.sql.day01.udf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
 * @author 不温卜火
 * @create 2020-08-03 12:13
 * MyCSDN : https://buwenbuhuo.blog.csdn.net/
 */
object UDAFDemo1 {
  def main(args: Array[String]): Unit = {
    // How to use the aggregate function from SQL
    val spark: SparkSession = SparkSession.builder()
      .appName("UDAFDemo1")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._
    val df: DataFrame = spark.read.json("d:/users.json")
    df.createOrReplaceTempView("user")
    // Register the aggregate function
    spark.udf.register("myAvg", new MyAvg)
    spark.sql("select myAvg(age) from user").show
    spark.close()
  }
}
class MyAvg extends UserDefinedAggregateFunction {
  // The type of the input data, e.g. 10.1, 12.2, 100
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)
  // The type of the aggregation buffer: a running sum and a count
  override def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
  // The type of the final result
  override def dataType: DataType = DoubleType
  // Whether the same input always produces the same output
  override def deterministic: Boolean = true
  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Initialize the running sum and count in the buffer
    buffer(0) = 0D // equivalent to buffer.update(0, 0D)
    buffer(1) = 0L
  }
  // Aggregation within a partition
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // input is a Row wrapping the arguments passed to the aggregate function
    if (!input.isNullAt(0)) {
      // The incoming field may be null
      val v: Double = input.getAs[Double](0) // getDouble(0)
      buffer(0) = buffer.getDouble(0) + v
      buffer(1) = buffer.getLong(1) + 1L
    }
  }
  // Aggregation across partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Combine buffer1 and buffer2, then write the result back into buffer1
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Return the final output value
  override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)
}
- 2. Run result
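With the hypothetical three-row test data sketched in section 2.1, myAvg(age) would return (29.0 + 30.0 + 19.0) / 3 = 26.0, since evaluate divides the accumulated sum by the accumulated count.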
2.3 Strongly-Typed UDAF (Average)
- 1. Source code
package com.buwenbuhuo.spark.sql.day01.udf
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator
/**
 * @author 不温卜火
 * @create 2020-08-03 12:43
 * MyCSDN : https://buwenbuhuo.blog.csdn.net/
 */
case class Dog(name: String, age: Int)
case class AgeAvg(sum: Int, count: Int) {
  def avg = sum.toDouble / count
}
object UDAFDemo3 {
  def main(args: Array[String]): Unit = {
    // How to use a strongly-typed aggregate function
    val spark: SparkSession = SparkSession.builder()
      .appName("UDAFDemo3")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._
    val ds: Dataset[Dog] = List(Dog("大黄", 6), Dog("小黄", 2), Dog("中黄", 4)).toDS()
    // Strongly-typed usage: turn the Aggregator into a TypedColumn
    val avg: TypedColumn[Dog, Double] = new MyAvg2().toColumn.name("avg")
    val result: Dataset[Double] = ds.select(avg)
    result.show()
    spark.close()
  }
}
class MyAvg2 extends Aggregator[Dog, AgeAvg, Double] {
  // Initialize the buffer
  override def zero: AgeAvg = AgeAvg(0, 0)
  // Aggregation within a partition
  override def reduce(b: AgeAvg, a: Dog): AgeAvg = a match {
    // For a Dog object, add its age to the sum and increment the count
    case Dog(name, age) => AgeAvg(b.sum + age, b.count + 1)
    // For null, return the buffer unchanged
    case _ => b
  }
  // Aggregation across partitions
  override def merge(b1: AgeAvg, b2: AgeAvg): AgeAvg = {
    AgeAvg(b1.sum + b2.sum, b1.count + b2.count)
  }
  // Return the final value
  override def finish(reduction: AgeAvg): Double = reduction.avg
  // Encoder for the buffer type
  override def bufferEncoder: Encoder[AgeAvg] = Encoders.product // for a case class, Encoders.product is enough
  // Encoder for the output type
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
/* Strongly-typed UDAF */
- 2. Run result
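Note that UserDefinedAggregateFunction is deprecated starting with Spark 3.0; on Spark 3.x the recommended path is to write an Aggregator and expose it to SQL through org.apache.spark.sql.functions.udaf. The sketch below is a hypothetical single-column variant of the average (MyAvg2 itself consumes a whole Dog, so a column-oriented Aggregator is used here purely for illustration):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Hypothetical single-column average Aggregator, buffer = (sum, count)
object ColumnAvg extends Aggregator[Double, (Double, Long), Double] {
  override def zero: (Double, Long) = (0.0, 0L)
  override def reduce(b: (Double, Long), a: Double): (Double, Long) = (b._1 + a, b._2 + 1)
  override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = (b1._1 + b2._1, b1._2 + b2._2)
  override def finish(r: (Double, Long)): Double = r._1 / r._2
  override def bufferEncoder: Encoder[(Double, Long)] = Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong)
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Spark 3.0+ only: register the Aggregator as an untyped UDAF for SQL use
// spark.udf.register("columnAvg", udaf(ColumnAvg))
// spark.sql("select columnAvg(age) from user").show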
That's all for this post.