Spark 中的 Shuffle 是什么?
Apache Spark 通过将数据分布在多个节点并在每个节点上单独计算值来处理查询。然而有时节点需要交换数据。毕竟这就是 Spark 的目的——处理单台机器无法容纳的数据。
Shuffle 是分区之间交换数据的过程。因此,当源分区和目标分区驻留在不同的计算机上时,数据行可以在工作节点之间移动。
Spark 不会在节点之间随机移动数据。Shuffle 是一项耗时的操作,因此只有在没有其他选择的情况下才会发生。
性能影响
Shuffle是一项昂贵的操作,因为它涉及磁盘I / O、数据序列化和网络 I/O。为了 Shuffle ,Spark 生成一组 map 任务来组织数据,以及一组 reduce 任务来聚合数据。这个命名来自 MapReduce,与 Spark 的 map 和 reduce 操作没有直接关系。
各个 map 任务的结果都会保存在内存中,直到它们无法容纳为止。然后根据目标分区对它们进行排序并写入单个文件。在 reduce 端,任务读取相关的排序块。
某些 Shuffle 操作可能会消耗大量堆内存,因为它们在传输之前或之后使用内存中数据结构来组织记录。Shuffle 还会在磁盘上生成大量中间文件。
最重要的部分→ 如何避免 Spark Shuffle?
- 使用适当的分区:确保您的数据从一开始就进行了适当的分区。如果您的数据已经根据您正在执行的操作进行分区,Spark 可以完全避免 Shuffle 。使用 repartition() 或 coalesce() 来控制数据的分区。
# Sample data
data = [(1, "A"), (2, "B"), (3, "C"), (4, "D"), (5, "E")]
# Create a DataFrame
df = spark.createDataFrame(data, ["id", "name"])
# Bad - Shuffling involved due to default partitioning (200 partitions)
result_bad = df.groupBy("id").count()
# Good - Avoids shuffling by explicitly repartitioning (2 partitions)
df_repartitioned = df.repartition(2, "id")
result_good = df_repartitioned.groupBy("id").count()
- 尽早过滤:在转换中尽早对数据应用过滤器或条件。这样,您可以减少后续阶段需要打乱的数据量。
# Sample data
sales_data = [(101, "Product A", 100), (102, "Product B", 150), (103, "Product C", 200)]
categories_data = [(101, "Category X"), (102, "Category Y"), (103, "Category Z")]
# Create DataFrames
sales_df = spark.createDataFrame(sales_data, ["product_id", "product_name", "price"])
categories_df = spark.createDataFrame(categories_data, ["product_id", "category"])
# Bad - Shuffling involved due to regular join
result_bad = sales_df.join(categories_df, on="product_id")
# Good - Avoids shuffling using broadcast variable
# Filter the small DataFrame early and broadcast it for efficient join
filtered_categories_df = categories_df.filter("category = 'Category X'")
result_good = sales_df.join(broadcast(filtered_categories_df), on="product_id")
- 使用广播变量:如果您有较小的查找数据想要与较大的数据集连接,请考虑使用广播变量。将小数据集广播到所有节点比混洗较大数据集更有效。
# Sample data
products_data = [(101, "Product A", 100), (102, "Product B", 150), (103, "Product C", 200)]
categories_data = [(101, "Category X"), (102, "Category Y"), (103, "Category Z")]
# Create DataFrames
products_df = spark.createDataFrame(products_data, ["product_id", "product_name", "price"])
categories_df = spark.createDataFrame(categories_data, ["category_id", "category_name"])
# Bad - Shuffling involved due to regular join
result_bad = products_df.join(categories_df, products_df.product_id == categories_df.category_id)
# Good - Avoids shuffling using broadcast variable
# Create a broadcast variable from the categories DataFrame
broadcast_categories = broadcast(categories_df)
# Join the DataFrames using the broadcast variable
result_good = products_df.join(broadcast_categories, products_df.product_id == broadcast_categories.category_id)
- 避免使用groupByKey():首选reduceByKey()或aggregateByKey(),而不是groupByKey(),因为前者在打乱数据之前在本地执行部分聚合,从而获得更好的性能。
# Sample data
data = [(1, "click"), (2, "like"), (1, "share"), (3, "click"), (2, "share")]
# Create an RDD
rdd = sc.parallelize(data)
# Bad - Shuffling involved due to groupByKey
result_bad = rdd.groupByKey().mapValues(len)
# Good - Avoids shuffling by using reduceByKey
result_good = rdd.map(lambda x: (x[0], 1)).reduceByKey(lambda a, b: a b)
- 使用数据局部性:只要有可能,尝试处理已存储在进行计算的同一节点上的数据。这减少了网络通信和Shuffle。
# Sample data
data = [(1, 10), (2, 20), (1, 5), (3, 15), (2, 25)]
# Create a DataFrame
df = spark.createDataFrame(data, ["key", "value"])
# Bad - Shuffling involved due to default data locality
result_bad = df.groupBy("key").max("value")
# Good - Avoids shuffling by repartitioning and using data locality
df_repartitioned = df.repartition("key") # Repartition to align data by key
result_good = df_repartitioned.groupBy("key").max("value")
- 使用内存和磁盘缓存:缓存将在多个阶段重用的中间数据可以帮助避免重新计算并减少Shuffle的需要。
# Sample data
data = [(1, 10), (2, 20), (1, 5), (3, 15), (2, 25)]
# Create a DataFrame
df = spark.createDataFrame(data, ["key", "value"])
# Bad - Shuffling involved due to recomputation of the filter condition
result_bad = df.filter("value > 10").groupBy("key").sum("value")
# Good - Avoids shuffling by caching the filtered data
df_filtered = df.filter("value > 10").cache()
result_good = df_filtered.groupBy("key").sum("value")
- 优化数据序列化:选择 Avro 或 Kryo 等高效的序列化格式,以减少 Shuffle过程中的数据大小。
# Create a Spark session with KryoSerializer
spark = SparkSession.builder
.appName("AvoidShuffleExample")
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.getOrCreate()
- 调整Spark配置:调整Spark的配置参数,如Spark.shuffle.departitions、Spark.reducer.maxSizeInFlight和Spark.shuzzle.file.buffer。
- 监控和分析:使用Spark的监控工具,如Spark UI和Spark History Server来分析作业的性能,并确定可以优化shuffle的区域。
通过遵循这些最佳实践并优化 Spark 作业,可以显着减少 shuffle 的需要,从而提高性能和资源利用率。然而在某些情况下,shuffle 可能仍然不可避免,特别是对于复杂的操作或处理大型数据集时。在这种情况下,应重点优化而不是完全避免 shuffle 。 原文作者:Sushil Kumar