代码语言:javascript复制
import java.sql.{Connection, DriverManager, PreparedStatement}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
/**
* 电影评分数据分析,需求如下:
* 需求1:查找电影评分个数超过50,且平均评分较高的前十部电影名称及其对应的平均评分
* 电影ID 评分个数 电影名称 平均评分 更新时间
* movie_id、rating_num、title、rating_avg、update_time
* 需求2:查找每个电影类别及其对应的平均评分
* 电影类别 电影类别平均评分 更新时间
* genre、 rating_avg 、update_time
* 需求3:查找被评分次数较多的前十部电影
* 电影ID 电影名称 电影被评分的次数 更新时间
* movie_id、title、rating_num、 update_time
*/
object MetricsAppMain {
// 文件路径
private val RATINGS_CSV_FILE_PATH = "J:\t4\FlinkCommodityRecommendationSystem-main\FlinkCommodityRecommendationSystem-main\recommendation\src\main\resources\ratings.csv"
// private val MOVIES_CSV_FILE_PATH = "D:\Users\Administrator\Desktop\exam0601\datas\movies.csv"
def main(args: Array[String]): Unit = {
// step1、创建SparkSession实例对象
val spark: SparkSession = createSparkSession(this.getClass)
import spark.implicits._
/*
分析需求可知,三个需求最终结果,需要使用事实表数据和维度表数据关联,所以先数据拉宽,再指标计算
TODO: 按照数据仓库分层理论管理数据和开发指标
- 第一层(最底层):ODS层
直接加CSV文件数据为DataFrame
- 第二层(中间层):DW层
将加载业务数据(电影评分数据)和维度数据(电影基本信息数据)进行Join关联,拉宽操作
- 第三层(最上层):DA层/APP层
依据需求开发程序,计算指标,进行存储到MySQL表
*/
// step2、【ODS层】:加载数据,CSV格式数据,文件首行为列名称
val ratingDF: DataFrame = readCsvFile(spark, RATINGS_CSV_FILE_PATH, verbose = false)
// val movieDF: DataFrame = readCsvFile(spark, MOVIES_CSV_FILE_PATH, verbose = false)
// step3、【DW层】:将电影评分数据与电影信息数据进行关联,数据拉宽操作
// val detailDF: DataFrame = joinDetail(ratingDF, movieDF)
printConsole(ratingDF)
// step4、【DA层】:按照业务需求,进行指标统计分析
computeMetric(ratingDF)
Thread.sleep(1000000)
// 应用结束,关闭资源
spark.stop()
}
/**
* 构建SparkSession实例对象,默认情况下本地模式运行
*/
def createSparkSession(clazz: Class[_], master: String = "local[2]"): SparkSession = {
SparkSession.builder()
.appName(clazz.getSimpleName.stripSuffix("$"))
.master(master)
.config("spark.sql.shuffle.partitions", "2")
.getOrCreate()
}
/**
* 读取CSV格式文本文件数据,封装到DataFrame数据集
*/
def readCsvFile(spark: SparkSession, path: String, verbose: Boolean = true): DataFrame = {
val dataframe: DataFrame = spark.read
// 设置分隔符为逗号
.option("sep", ",")
// 文件首行为列名称
.option("header", "true")
// 依据数值自动推断数据类型
.option("inferSchema", "true")
.csv(path)
if(verbose){
printConsole(dataframe)
}
// 返回数据集
dataframe
}
/**
* 按照业务需求,进行指标统计,默认情况下,结果数据打印控制台
*/
def computeMetric(dataframe: DataFrame): Unit = {
// TODO: 缓存数据
dataframe.persist(StorageLevel.MEMORY_AND_DISK)
// 需求1:查找电影评分个数超过50,且平均评分较高的前十部电影名称及其对应的平均评分
val top10FilesDF: DataFrame = top10Films(dataframe)
//printConsole(top10FilesDF)
upsertToMySQL(
top10FilesDF, //
"replace into test.rating (id, userId, productId, score, timestamp) values (null, ?, ?, ?, ?)", //
(pstmt: PreparedStatement, row: Row) => {
pstmt.setInt(1, row.getAs[Int]("userId"))
pstmt.setInt(2, row.getAs[Int]("productId"))
pstmt.setDouble(3, row.getAs[Double]("score"))
pstmt.setInt(4, row.getAs[Int]("timestamp"))
}
)
// 释放资源
dataframe.unpersist()
}
/**
* 需求:查找电影评分个数超过50,且平均评分较高的前十部电影名称及其对应的平均评分
* 电影ID 评分个数 电影名称 平均评分 更新时间
* movie_id、rating_num、title、rating_avg、update_time
*/
def top10Films(dataframe: DataFrame): DataFrame = {
import dataframe.sparkSession.implicits._
dataframe
// 添加日期字段
// .withColumn("update_time", current_timestamp())
}
/**
* 将DataFrame数据集打印控制台,显示Schema信息和前10条数据
*/
def printConsole(dataframe: DataFrame): Unit = {
// 显示Schema信息
dataframe.printSchema()
// 显示前10条数据
dataframe.show(10, truncate = false)
}
/**
* 将数据保存至MySQL表中,采用replace方式,当主键存在时,更新数据;不存在时,插入数据
* @param dataframe 数据集
* @param sql 插入数据SQL语句
* @param accept 函数,如何设置Row中每列数据到SQL语句中占位符值
*/
def upsertToMySQL(dataframe: DataFrame, sql: String,
accept: (PreparedStatement, Row) => Unit): Unit = {
// 降低分区数目,对每个分区进行操作
dataframe.coalesce(1).foreachPartition{iter =>
// step1. 加载驱动类
Class.forName("com.mysql.cj.jdbc.Driver")
// 声明变量
var conn: Connection = null
var pstmt: PreparedStatement = null
try{
// step2. 创建连接
conn = DriverManager.getConnection(
"jdbc:mysql://120.26.162.238:33306/?serverTimezone=UTC&characterEncoding=utf8&useUnicode=true",
"root",
"123456"
)
pstmt = conn.prepareStatement(sql)
// step3. 插入数据
iter.foreach{row =>
// 设置SQL语句中占位符的值
accept(pstmt, row)
// 加入批次中
pstmt.addBatch()
}
// 批量执行批次
pstmt.executeBatch()
}catch {
case e: Exception => e.printStackTrace()
}finally {
// step4. 关闭连接
if(null != pstmt) pstmt.close()
if(null != conn) conn.close()
}
}
}
}