Spark SQL 项目实战 | 计算各区域热门商品 Top3

2020-10-28 17:41:20 浏览数 (1)

一. 需求

1.1 需求简介

这里的热门商品是从点击量的维度来看的.

计算各个区域前三大热门商品,并备注上每个商品在主要城市中的分布比例,超过两个城市用其他显示。

1.2 思路分析

使用 sql 来完成. 碰到复杂的需求, 可以使用 udf 或 udaf

  1. 查询出来所有的点击记录, 并与 city_info 表连接, 得到每个城市所在的地区. 与 Product_info 表连接得到产品名称
  2. 按照地区和商品 id 分组, 统计出每个商品在每个地区的总点击次数
  3. 每个地区内按照点击次数降序排列
  4. 只取前三名. 并把结果保存在数据库中
  5. 城市备注需要自定义 UDAF 函数

二. 实际操作

1. 准备数据

  我们这次 Spark-sql 操作中所有的数据均来自 Hive.

  首先在 Hive 中创建表, 并导入数据.

  一共有 3 张表: 1 张用户行为表, 1 张城市表, 1 张产品表

  • 1. 打开Hive
  • 2. 创建三个表
代码语言:javascript复制
CREATE TABLE `user_visit_action`(
  `date` string,
  `user_id` bigint,
  `session_id` string,
  `page_id` bigint,
  `action_time` string,
  `search_keyword` string,
  `click_category_id` bigint,
  `click_product_id` bigint,
  `order_category_ids` string,
  `order_product_ids` string,
  `pay_category_ids` string,
  `pay_product_ids` string,
  `city_id` bigint)
row format delimited fields terminated by 't';

CREATE TABLE `product_info`(
  `product_id` bigint,
  `product_name` string,
  `extend_info` string)
row format delimited fields terminated by 't';

CREATE TABLE `city_info`(
  `city_id` bigint,
  `city_name` string,
  `area` string)
row format delimited fields terminated by 't';
  • 3. 上传数据
代码语言:javascript复制
load data local inpath '/opt/module/datas/user_visit_action.txt' into table spark0806.user_visit_action;
load data local inpath '/opt/module/datas/product_info.txt' into table spark0806.product_info;
load data local inpath '/opt/module/datas/city_info.txt' into table spark0806.city_info;
  • 4. 测试是否上传成功
代码语言:javascript复制
hive> select * from city_info;

2. 显示各区域热门商品 Top3

  • 1. 源码
代码语言:javascript复制
// user_visit_action  product_info  city_info

1. 先把需要的字段查出来   t1
select
    ci.*,
    pi.product_name,
    click_product_id
from user_visit_action uva
join product_info pi on uva.click_product_id=pi.product_id
join city_info ci on uva.city_id=ci.city_id

2. 按照地区和商品名称聚合
select
    area,
    product_name,
    count(*)  count
from t1
group by area , product_name

3. 按照地区进行分组开窗 排序 开窗函数 t3 // (rank(1 2 2 4 5...) row_number(1 2 3 4...) dense_rank(1 2 2 3 4...))
select
    area,
    product_name,
    count,
    rank() over(partition by area order by count desc)
from  t2


4. 过滤出来名次小于等于3的
select
    area,
    product_name,
    count
from  t3
where rk <=3
  • 2. 运行结果

3. 定义udaf函数 得到需求结果

  • 1. 源码
代码语言:javascript复制
package com.buwenbuhuo.spark.sql.project

import java.text.DecimalFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


/**
 **
 *
 * @author 不温卜火
 *         *
 * @create 2020-08-06 13:24
 **
 *         MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {
  // 输入数据的类型:  北京  String
  override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
  }

  // 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000  Map,  总的点击量  1000/?
  override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
  }

  // 输出的数据类型  "北京21.2%,天津13.2%,其他65.6%"  String
  override def dataType: DataType = StringType

  // 相同的输入是否应用有相同的输出.
  override def deterministic: Boolean = true

  // 给存储数据初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //初始化map缓存
    buffer(0) = Map[String, Long]()
    // 初始化总的点击量
    buffer(1) = 0L
  }

  // 分区内合并 Map[城市名, 点击量]
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // 1. 总的点击量   1
        buffer(1) = buffer.getLong(1)   1L
        // 2. 给这个城市的点击量  1 =>   找到缓冲区的map,取出来这个城市原来的点击   1 ,再复制过去
        val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
        buffer(0) = map   (cityName -> (map.getOrElse(cityName, 0L)   1L))
      case _ =>
    }
  }

  // 分区间的合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. 总数的聚合
    buffer1(1) = total1   total2
    // 2. map的聚合
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
        map   (cityName -> (map.getOrElse(cityName, 0L)   count))
    }

  }

  // 最终的输出结果
  override def evaluate(buffer: Row): Any = {
    // "北京21.2%,天津13.2%,其他65.6%"
    val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val total: Long = buffer.getLong(1)

    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List[CityRemark] = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
//    CityRemark("其他",1 - cityremarks.foldLeft(0D)(_ _.cityRatio))
    cityRemarks : = CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
  }
}

case class CityRemark(cityName: String, cityRatio: Double) {
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = s"$cityName:${formatter.format(cityRatio)}"

}
  • 运行结果

4 .保存到Mysql

  • 1. 源码
代码语言:javascript复制
    val props: Properties = new Properties()
    props.put("user","root")
    props.put("password","199712")
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark
        |from t3
        |where rk<=3
        |""".stripMargin)
      .coalesce(1)
      .write
      .mode("overwrite")
      .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)
  • 2.运行结果

三. 完整代码

  • 1. udaf
代码语言:javascript复制
package com.buwenbuhuo.spark.sql.project

import java.text.DecimalFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


/**
 **
 *
 * @author 不温卜火
 *         *
 * @create 2020-08-06 13:24
 **
 *         MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
class CityRemarkUDAF extends UserDefinedAggregateFunction {
  // 输入数据的类型:  北京  String
  override def inputSchema: StructType = {
    StructType(Array(StructField("city", StringType)))
  }

  // 缓存的数据的类型 每个地区的每个商品 缓冲所有城市的点击量 北京->1000, 天津->5000  Map,  总的点击量  1000/?
  override def bufferSchema: StructType = {
    StructType(Array(StructField("map", MapType(StringType, LongType)), StructField("total", LongType)))
  }

  // 输出的数据类型  "北京21.2%,天津13.2%,其他65.6%"  String
  override def dataType: DataType = StringType

  // 相同的输入是否应用有相同的输出.
  override def deterministic: Boolean = true

  // 给存储数据初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    //初始化map缓存
    buffer(0) = Map[String, Long]()
    // 初始化总的点击量
    buffer(1) = 0L
  }

  // 分区内合并 Map[城市名, 点击量]
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    input match {
      case Row(cityName: String) =>
        // 1. 总的点击量   1
        buffer(1) = buffer.getLong(1)   1L
        // 2. 给这个城市的点击量  1 =>   找到缓冲区的map,取出来这个城市原来的点击   1 ,再复制过去
        val map: collection.Map[String, Long] = buffer.getMap[String, Long](0)
        buffer(0) = map   (cityName -> (map.getOrElse(cityName, 0L)   1L))
      case _ =>
    }
  }

  // 分区间的合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1 = buffer1.getAs[Map[String, Long]](0)
    val map2 = buffer2.getAs[Map[String, Long]](0)

    val total1: Long = buffer1.getLong(1)
    val total2: Long = buffer2.getLong(1)

    // 1. 总数的聚合
    buffer1(1) = total1   total2
    // 2. map的聚合
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (cityName, count)) =>
        map   (cityName -> (map.getOrElse(cityName, 0L)   count))
    }

  }

  // 最终的输出结果
  override def evaluate(buffer: Row): Any = {
    // "北京21.2%,天津13.2%,其他65.6%"
    val cityAndCount: collection.Map[String, Long] = buffer.getMap[String, Long](0)
    val total: Long = buffer.getLong(1)

    val cityCountTop2: List[(String, Long)] = cityAndCount.toList.sortBy(-_._2).take(2)
    var cityRemarks: List[CityRemark] = cityCountTop2.map {
      case (cityName, count) => CityRemark(cityName, count.toDouble / total)
    }
//    CityRemark("其他",1 - cityremarks.foldLeft(0D)(_ _.cityRatio))
    cityRemarks : = CityRemark("其他",cityRemarks.foldLeft(1D)(_ - _.cityRatio))
    cityRemarks.mkString(",")
  }
}

case class CityRemark(cityName: String, cityRatio: Double) {
  val formatter = new DecimalFormat("0.00%")

  override def toString: String = s"$cityName:${formatter.format(cityRatio)}"

}
  • 2. 主程序(具体实现)
代码语言:javascript复制
package com.buwenbuhuo.spark.sql.project

import java.util.Properties

import org.apache.spark.sql.SparkSession

/**
 **
 *
 * @author 不温卜火
 *         *
 * @create 2020-08-05 19:01
 **
 *         MyCSDN :  https://buwenbuhuo.blog.csdn.net/
 *
 */
object SqlApp {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .master("local[*]")
      .appName("SqlApp")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._
    spark.udf.register("remark",new CityRemarkUDAF)

    // 去执行sql,从hive查询数据
    spark.sql("use spark0806")
    spark.sql(
      """
        |select
        |    ci.*,
        |    pi.product_name,
        |    uva.click_product_id
        |from user_visit_action uva
        |join product_info pi on uva.click_product_id=pi.product_id
        |join city_info ci on uva.city_id=ci.city_id
        |
        |""".stripMargin).createOrReplaceTempView("t1")

    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count(*) count,
        |    remark(city_name) remark
        |from t1
        |group by area, product_name
        |""".stripMargin).createOrReplaceTempView("t2")

    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark,
        |    rank() over(partition by area order by count desc) rk
        |from t2
        |""".stripMargin).createOrReplaceTempView("t3")

    val props: Properties = new Properties()
    props.put("user","root")
    props.put("password","199712")
    spark.sql(
      """
        |select
        |    area,
        |    product_name,
        |    count,
        |    remark
        |from t3
        |where rk<=3
        |""".stripMargin)
      .coalesce(1)
      .write
      .mode("overwrite")
      .jdbc("jdbc:mysql://hadoop002:3306/rdd?useUnicode=true&characterEncoding=utf8", "spark0806", props)


    // 把结果写入到mysql中

    spark.close()
  }
}

  本次的分享就到这里了

0 人点赞