Arrow HashJoin限制

2023-09-02 10:42:09 浏览数 (1)

Arrow HashJoin限制

0.背景

最近在测试一些大数据量的HashJoin计算,例如:用户层设置batch数为600000,那么会导致crash。本节将会通过调试,来一步步学习SwissJoin(HashJoin的内部实现)的分区逻辑。

1.crash 点

本次crash点为SwissJoin分区排序。

代码语言:javascript复制
// crash point
ARROW_DCHECK(num_rows > 0 && num_rows <= (1 << 15));

为了研究这个值,于是写下个这篇文章。

2.布隆过滤器

自Arrow 9.0版本后,新增HashJoin的BF功能,在执行计划上新增PrepareToProduce接口,该接口会在每个节点开始Produce之前做一些准备工作,目前只用在HashJoin节点,这里会对BloomFilter的上下文做一些初始化,例如:线程数、调度器、注册build端finish任务、probe端任务等等。

值得注意的是,对于HashJoin的线程数为:CPU IO 1

代码语言:javascript复制
size_t num_threads = (GetCpuThreadPoolCapacity()   io::GetIOThreadPoolCapacity()   1);

随后build端StartProducing->BuildHashTable,依次继续下面的任务,BuildTask->MergeTask->ScanTask。

代码语言:javascript复制
  void InitTaskGroups() {
    task_group_build_ = scheduler_->RegisterTaskGroup(
        [this](size_t thread_index, int64_t task_id) -> Status {
          return BuildTask(thread_index, task_id);
        },
        [this](size_t thread_index) -> Status { return BuildFinished(thread_index); });
    task_group_merge_ = scheduler_->RegisterTaskGroup(
        [this](size_t thread_index, int64_t task_id) -> Status {
          return MergeTask(thread_index, task_id);
        },
        [this](size_t thread_index) -> Status { return MergeFinished(thread_index); });
    task_group_scan_ = scheduler_->RegisterTaskGroup(
        [this](size_t thread_index, int64_t task_id) -> Status {
          return ScanTask(thread_index, task_id);
        },
        [this](size_t thread_index) -> Status { return ScanFinished(thread_index); });
  }

3.BuildTask

  • 哈希计算

在这个里面会去做一些指令集的Hash。

  • 数据分区:

在数据Join时,通常会将具有相同哈希值的元素分配到同一个分区中,以便进行后续的连接操作。对于每个哈希值,这段代码会计算其所属的分区,并将其对应的行索引保存在 locals.batch_prtn_row_ids 中。分区的范围和索引信息保存在 locals.batch_prtn_ranges 中。

  • 更新哈希

在完成数据分区后,代码会对 locals.batch_hashes 中的哈希值进行更新,以清除已经用于分区的高位比特位。这样做是为了后续将哈希值用于建立哈希表时能够直接使用这些低位比特位,从而减少哈希冲突的可能性。

取高位计算分区:

代码语言:javascript复制
return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1;

移除已经参与分区计算的高位:

代码语言:javascript复制
 locals.batch_hashes[i] <<= log_num_prtns_;
  • 数据插入

它会根据分区信息从 key_batchpayload_batch_maybe_null 中选择相应的数据,并将其插入到该分区对应的哈希表中。

4.数据分区

4.1 分区逻辑

以代码中注释为例:

代码语言:javascript复制
 输入数组: [5, 7, 2, 3, 5, 4]
 分区数: 3
 分区算法: [&in_arr] (int row_id) { return in_arr[row_id] / 3; }
 输出位置映射算法: [&out_arr] (int row_id, int pos) { out_arr[pos] = row_id; }

执行分区操作后,我们的到:out_arr: [2, 5, 3, 5, 4, 7] prtn_ranges: [0, 1, 5, 6]

  • out_arr执行的流程为:

按照分区算法,对数组每个元素除以3的到分区数组[1, 2, 0, 1, 1, 1],排序的到[0, 1, 1, 1, 1, 2],随后映射到输入数组的每一条记录上,便得到了[2, 5, 3, 5, 4, 7] 。

  • prtn_ranges执行的流程为:

我们在上面已经得到了[0, 1, 1, 1, 1 2],那么就可以计算每个分区在排序后数组中的起始位置,便得到了[0, 1, 5, 6]。

SwissJoin里面的实际分区算法:右移31-log_num_prtns_,以获取哈希值的高 log_num_prtns_ 位作为分区 ID,最后右移1位向下取整。

代码语言:javascript复制
return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1;

4.2 分区crash调试流程

为了方便查看上下文,这里通过gdb详细的打印了一些必要信息。

应用层传入13个batch,总行数1500036,依次batch数量为120254、120414、119563等等。

代码语言:javascript复制
(gdb) p build_side_batches_
$21 = {row_count_ = 1500036, batches_ = std::vector of length 13, capacity 16 = {{values = std::vector of length 2, capacity 2 = {{
          static kUnknownLength = -1, 
          value = ...
(gdb) p build_side_batches_.batches_[0].ToString ()
$24 = "ExecBatchn    # Rows: 120254n    0: Array[1,8,11,13,18,27,30,32,36,40,...,599964,599968,599972,599974,599984,599985,599990,599991,599992,599999]n    1: Array[1,8,11,13,18,27,30,32,36,40,...,599964,599"...
(gdb) p build_side_batches_.batches_[1].ToString ()
$25 = "ExecBatchn    # Rows: 120414n    0: Array[600011,600021,600028,600029,600032,600035,600037,600044,600048,600051,...,1199968,1199970,1199972,1199973,1199978,1199988,1199990,1199996,1199997,1199999]n   "...
(gdb) p build_side_batches_.batches_[2].ToString ()
$26 = "ExecBatchn    # Rows: 119563n    0: Array[1200001,1200012,1200019,1200023,1200026,1200029,1200030,1200043,1200053,1200065,...,1799973,1799976,1799986,1799988,1799989,1799991,1799994,1799996,1799998,18"..
(gdb) p num_rows_
$35 = 1500036
(gdb) p dop_
$38 = 105
(gdb) p bit_util::Log2(105)
$39 = 7
(gdb) p bit_util::Log2(num_rows_ / (1 << 18))
$41 = 3
(gdb) p num_prtns_
$42 = 8
(gdb) p thread_states_[0]
$47 = {batch_hashes = std::vector of length 120254, capacity 120254 = {2977511326, 2362293233, 2688180940, 53399816, 2072962591, 3117799855, 
    3443687562, 808906694, 4129082687, 3154357176, 536287733, 2539073292, 1238460074, 1634524813, 2679361821, 1704636310, 729910543, 
    2390031435, 1774747807, 4120198289, 3145472522, 1828082088, 3830867903, 2530254429, 1555528918, 1881416625, 3900979400, 1266198276, 
    ......

dop_可以理解为线程数,这里计算出来是105,计算来自PrepareToProduce

代码语言:javascript复制
  // TODO(ARROW-15732)
  // Each side of join might have an IO thread being called from. Once this is fixed
  // we will change it back to just the CPU's thread pool capacity.
  // 105 = 96   8   1
  size_t num_threads = (GetCpuThreadPoolCapacity()   io::GetIOThreadPoolCapacity()   1);

分区初始化:

代码语言:javascript复制
constexpr int64_t min_num_rows_per_prtn = 1 << 18;
log_num_prtns_ =
    std::min(bit_util::Log2(dop_),
             bit_util::Log2(bit_util::CeilDiv(num_rows, min_num_rows_per_prtn)));
num_prtns_ = 1 << log_num_prtns_;

所以这里是

代码语言:javascript复制
log_num_prtns_ = min(7, 3) = 3;
num_prtns_ = 1 << 3 = 8;

划分为8个分区,每个分区里面的行数限制为1 << 15,分区数也是1 << 15。

在这里我们第一个batch行数为120254,超过了1 << 15,就crash了,所以我们应该限制每个batch的数量不超过1 << 15。

5.总结

理解三个控制参数:

  • 控制分区数量不要太多 1 << 18 = 262144
  • 控制分区内的行数不会超过1 << 15 = 32768
  • 分区数不得超过 1 << 15 = 32768

也就是整表限制 <= 2^30

0 人点赞