[源码解析] 快手八卦 --- 机器学习分布式训练新思路(2)

2022-05-09 16:26:58 浏览数 (1)

源码解析 快手八卦 --- 机器学习分布式训练新思路(2)

目录

  • [源码解析] 快手八卦 --- 机器学习分布式训练新思路(2)
    • 0x00 摘要
    • 0x01 优化
      • 1.1 重叠通信和计算
      • 1.2 分桶通信和扁平化
      • 1.3 分层化通信
    • 0x02 Generic Fused Optimizer
      • 2.1 背景知识
        • 2.1.1 Tensor
        • 2.1.2 Storage
        • 2.1.3 内部实现
      • 2.2 定义
      • 2.3 打平
      • 2.4 优化
        • 2.4.1 按照存储分组
        • 2.4.2 重新排序
    • 0x03 分层化 --- 进程组
      • 3.1 设计思路
      • 3.2 生成进程组
      • 3.3 Ranks
      • 3.4 BaguaProcessGroup 定义
      • 3.5 生成 communicator
      • 3.6 使用
    • 0xFF 参考

0x00 摘要

“Bagua“ 是快手和苏黎世理工(ETH Zürich)联合开发的分布式训练框架。其专门针对分布式的场景设计特定的优化算法,实现算法和系统层面的联合优化,力图极致化分布式训练的效率。其特点是:

  • 并行性能显著提高;
  • 对网络环境更鲁棒;
  • “一键式”使用;
  • 分布式通讯算法易拓展性;
  • 可用于工业级场景大规模使用;
  • 安全、故障易排查;

本文以:

  • 快手官方公共号文章 快手八卦!突破 TensorFlow、PyTorch 并行瓶颈的开源分布式训练框架来了!
  • “bagua"论文 https://arxiv.org/pdf/2107.01499.pdf
  • “bagua"官方网站 https://tutorials.baguasys.com/
  • “bagua" 演示文档
  • 项目 GitHub 地址:https://github.com/BaguaSys/bagua

为基础来分析学习。本文介绍优化方案之中的 Fused Optimizer 和 分层通信。

前一篇链接为:

[源码解析] 快手八卦 --- 机器学习分布式训练新思路(1)

0x01 优化

现有其他框架都是针对某一个具体算法或者场景进行优化,下图是DP-SG的通信模式以及Horovod、BytePS和PyTorch-DDP如何针对这种通信模式进行优化。

八卦希望设计一种针对所有通信算法的优化方式。BAGUA的核心部分是它的执行优化器(execution optimizer)。给定一个神经网络作为输入,一个训练算法(例如QSGD)将在每个层的计算过程中利用一系列的通信原语来实现。BAGUA的执行优化器的目标是自动安排和优化这些计算和通信。在BAGUA中探索了以下技术。

1.1 重叠通信和计算

该项优化的目的是将通讯时间隐藏在计算时间中。

把通信和计算重叠起来是加速分布式DP-SG的一个核心操作。不仅限于DP-SG算法,BAGUA能够以一种灵活和自动的方式将通信原语与其他算法的计算重叠起来,因此能够将部分通信时间隐藏在计算时间中,这样可以降低通信开销。

具体来讲,在反向梯度的计算过程中,部分已经完成的梯度可以在剩余梯度的计算过程中同时进行通信——通过这种流水的处理方式,部分通信时间可以被有效地“隐藏”在反向梯度的计算过程中,从而减小数据并行带来的通信开销。BAGUA自动分析计算图,包括in-place张量操作和十个通信原语。尽管人们可以通过静态分析来构建这个图,但BAGUA利用动态分析方法,在第一次迭代中就可以收集到张量操作和通信基元的调用依赖。

与现有系统相比,BAGUA考虑了更复杂的调度。在vanilla DP-SG中,优化只能将Allreduce通信隐藏在反向传播的计算中;相比之下,BAGUA可以调度额外的元素,如使用低精度的压缩/解压缩和优化算法对于指定的模型进行更新。

1.2 分桶通信和扁平化

频繁的传输碎片化数据,会降低通信的效率,不利于充分利用网络带宽。为了有效地将通信和计算重叠起来,将各层型参数划分为若干个桶进行通信是一个必要的步骤,这样通讯的单位就变成了桶,从而能够更高效地利用通信模型。

因此,Horovod和PyTorch-DDP都采用了桶的技巧。然而,他们的bucketing方案只是简单地把Allreduce通信硬编码,用启发式的思路来减少成本,并使用神经网络之中层的倒序来确定buckets。相比之下,由于BAGUA支持更多通信方式,而且这些通信方式可以指定优化算法,并且使用BAGUA的通信原语,因此bucketing是根据在分析(profiling)阶段收集的相关性信息来确定。

一旦我们将计算图分割成桶,BAGUA就在这些桶上进行融合。这使得BAGUA有可能实现一个更有效的流水线。在确定反向传播的第一次运行中的桶的分区后,BAGUA会仔细地将桶内的参数(如模型参数、梯度和优化器状态)对齐到一个连续的内存空间。然后在所有的流水线执行中利用这种参数的扁平化视图。

此外,由于支持了信息压缩算法,对于压缩和解压的函数,其操作的基本单位也是桶,这样也能使得这些操作的开销降低。例如,低精度压缩/解压缩lambda会直接应用于桶的扁平化视图,而不是单个参数;用于模型更新的基于SG的优化器也在桶的层面上进行(NVIDIA的Apex也使用类似的优化)。请注意,这种扁平化视图可以更有效地利用计算单元所提供的并行性。

1.3 分层化通信

由于工业级别的分布式训练往往需要多机多卡,而不同物理连接方式所带来的延时和带宽也有较大差异,因此,通讯的有效抽象也对性能的提升至关重要。

BAGUA的通信可以分层进行。这在处理异构网络连接时特别有用,例如,服务器内GPU之间的带宽要比服务器之间的带宽高得多。Bagua 将涉及多机的通信抽象成:“机内”和“机间”,在此抽象的基础上优化通信基元的实现,并对于相应的通信抽象做了优化。

例如,对于信息压缩传输,分层化通讯将会把这一算法解读成“机内”完整精度,“机间”信息压缩,从而为不同的物理链接提供最合适的通信算法。集中式低精度原语(CLPS)可以被优化为首先在每个节点内部的本地工作者上聚合张量,不压缩,然后在每个节点选出的领导worker上进行节点间聚合,压缩。最后让每个领导worker在节点内广播聚合的数据。请注意,这种优化可能会改变通信原语的语义。对于去中心化的原语,节点内的工作者将总是被改变为中心化的Allreduce方式。

接下来,我们就看看两种优化手段:融合和分层化。

0x02 Generic Fused Optimizer

八卦提供了通用的融合优化器,通过在多层上融合优化器.step()操作(fusing the optimizer .step() operation on multiple layers)来提高优化器的性能。它可以应用于任意 PyTorch 优化器。代码位于 bagua/torch_api/contrib/fused_optimizer.py。

2.1 背景知识

我们首先介绍一下背景知识。

2.1.1 Tensor

我们一般印象中的 Tensor 如下:

实际上,张量分为元信息区(Tensor) 和 存储区(Storage)。信息区保存张量的形状(size),步长(stride),数据类型(type)等信息,真正数据则在 Storage 之中保存成连续数组。

代码语言:javascript复制
 ------------------          ----------------- 
| Tensor           |        | Storage         |
|                  |        |                 |
|                  |        |                 |
|    stride        |        |      data       |
|                  |        |                 |
|    size          |        |      size       |
|                  |        |                 |
|    type          |        |                 |
|                  |        |                 |
|    shape         |        |                 |
|                  |        |                 |
|    dimention     |        |                 |
|                  |        |                 |
|    storage   -----------> |                 |
|                  |        |                 |
|                  |        |                 |
 ------------------          ----------------- 
2.1.2 Storage

我们也可以这么理解,Storage 是连续的内存块,Tensor 是一个视图,该视图把Storage单条内存区域映射到了n维的空间视图。

所以涉及到几个概念。

  • Size 是张量的维度。
  • Storage offset 是数据在storage中的索引。是张量第一个元素与storage第一个元素的偏移量。
  • Stride 是storage中对应于张量相邻维度间第一个索引的跨度,是在指定维度中从一个元素跳到下一个元素所必需的步长。

比如:

代码语言:javascript复制
import torch

a = torch.arange(6)
print("Tensor a : ", a)
print("a storage : " , a.storage())
print("a size : " , a.size())
print("a stride : " , a.stride())
print("a.data.storage().data_ptr() : " , a.data.storage().data_ptr())

b = a.view(2,3) # 换一种view方式
print("Tensor b : ", b)
print("b storage : " , b.storage())
print("b size : " , b.size())
print("b stride : " , b.stride())
print("b.data.storage().data_ptr() : " , b.data.storage().data_ptr())

c = a.view(3,2) # 再换一种view方式
print("Tensor c : ", c)
print("c storage : " , c.storage())
print("c size : " , c.size())
print("c stride : " , c.stride())
print("c.data.storage().data_ptr() : " , c.data.storage().data_ptr())

输出,可以看出来,同样的存储,但是视图不同,就是不同的张量:

代码语言:javascript复制
# 张量 a
Tensor a :  tensor([0, 1, 2, 3, 4, 5])
a storage :   
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]
a size :  torch.Size([6])
a stride :  (1,)
a.data.storage().data_ptr() :  140266160612352
  
# 张量 b  
Tensor b :  tensor([[0, 1, 2],
        [3, 4, 5]])
b storage :   
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]
b size :  torch.Size([2, 3])
b stride :  (3, 1)
b.data.storage().data_ptr() :  140266160612352
  
# 张量 c  
Tensor c :  tensor([[0, 1],
        [2, 3],
        [4, 5]])
c storage :   
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]
c size :  torch.Size([3, 2])
c stride :  (2, 1)
c.data.storage().data_ptr() :  140266160612352

我们单独看看 offset

代码语言:javascript复制
d = a[3:]
print(d.storage())
print(a.storage_offset())
print(b.storage_offset())
print(c.storage_offset())
print(d.storage_offset())

输出如下,可以看出来,d 的 storage 不变,但是d 的 torage_offset 是 3 :

代码语言:javascript复制
# d的storae
 0
 1
 2
 3
 4
 5
[torch.LongStorage of size 6]
0 # a.storage_offset()
0 # b.storage_offset()
0 # c.storage_offset()
3 # d.storage_offset() ---- 变化了

另外,一个对象的id值可以认为是其在内存中的地址,比如 id(b.storage()) 。

2.1.3 内部实现

我们接下来看看内部实现。

TensorImpl 是 Tensor 内部实现。

代码语言:javascript复制
struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  c10::impl::SizesAndStrides sizes_and_strides_;

  int64_t storage_offset_ = 0;
  caffe2::TypeMeta data_type_;  
  Storage storage_;

StorageImpl 则是 storage 的内部实现,可以看出来,storage是在DataPtr之上封装的接口。

代码语言:javascript复制
struct C10_API StorageImpl final : public c10::intrusive_ptr_target {

  DataPtr data_ptr_;
  size_t size_bytes_;
  bool resizable_;
  // Identifies that Storage was received from another process and doesn't have
  // local to process cuda memory allocation
  bool received_cuda_;
  Allocator* allocator_;  

2.2 定义

FusedOptimizer 通过将参数张量展平到一个或多个连续桶之中,就可以将多个模块参数更新内核融合为一个或少数几个。这里最主要的是对于 16位,32位参数来分别调用 flatten_module_params 做 flatten。

代码语言:javascript复制
class FusedOptimizer(torch.optim.Optimizer):
    """Convert any optimizer into a fused optimizer.

    This fused optimizer fuses multiple module parameter update kernel launches
    into one or a few, by flattening parameter tensors into one or more
    contiguous buckets.

    It can be used in conjunction with :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua` method. In this case,
    Bagua will do the fusions automatically, otherwise, you need to explicitly
    set :attr:`do_flatten=True`.

    Args:
        optimizer (torch.optim.Optimizer): Any PyTorch optimizer.
        do_flatten (bool): Whether to flatten the parameters. Default: ``False``.

    Returns:
        Fused optimizer.


    Example::
        To use in conjunction with :meth:`~bagua.torch_api.distributed.BaguaModule.with_bagua` method:

        >>> optimizer = torch.optim.Adadelta(model.parameters(), ....)
        >>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer)
        >>> model = model.with_bagua([optimizer], GradientAllReduceAlgorithm())

        To use alone or with `torch.nn.parallel.DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=distributeddataparallel#torch.nn.parallel.DistributedDataParallel>`_,
        set :attr:`do_flatten=True`:

        >>> optimizer = torch.optim.Adadelta(model.parameters(), ....)
        >>> optimizer = bagua.torch_api.contrib.FusedOptimizer(optimizer, do_flatten=True)
    """

    def __init__(self, optimizer: torch.optim.Optimizer, do_flatten: bool = False):
        self.optimizer = copy.copy(optimizer)
        super(FusedOptimizer, self).__init__(optimizer.param_groups, optimizer.defaults)

        if do_flatten:
            f32_params = [ # 提取优化器参数之中32位参数
                param
                for group in self.optimizer.param_groups
                for param in group["params"]
                if param.type() == "torch.cuda.FloatTensor"
            ]
            f16_params = [ # 提取优化器参数之中16位参数
                param
                for group in self.optimizer.param_groups
                for param in group["params"]
                if param.type() == "torch.cuda.HalfTensor"
            ]

            # 然后分别打平
            flatten_module_params(f32_params, align_bytes=1)
            flatten_module_params(f16_params, align_bytes=1)

2.3 打平

把所有的 16 位 "params" 拷贝到一起,所有32位 "params" 拷贝到一起,逻辑是:

  • 初始化打平的权重张量 flatten_weights_tensor,并且指定了之前的设备。
  • 初始化打平的梯度张量 flatten_grads_tensor,并且指定了之前的设备。
  • 获取打平张量的storage。
  • 遍历参数列表:
    • 把权重拷贝到flatten张量,p.numel() 是元素个数,reshape(-1) 就是展平了,设置了存储信息。
    • 把梯度拷贝到flatten张量,p.numel() 是元素个数,reshape(-1) 就是展平了,设置了存储信息。
    • 设置底层的storage,size 和 strides,其实就是设置元信息。
  • 返回聚合打平之后的参数。
代码语言:javascript复制
def flatten_module_params(params_list, align_bytes: int):
    if len(params_list) == 0:
        return
    if not isinstance(params_list[0], list):
        params_list = [params_list]

    total_size = 0
    for params in params_list: # 计算参数总大小
        total_size  = _get_params_flattened_aligned_size(params, align_bytes)

    # 初始化打平的权重张量,并且指定了之前的设备    
    flatten_weights_tensor = torch.zeros(total_size, dtype=params_list[0][0].dtype).to(
        params_list[0][0].device
    )
    # 初始化打平的梯度张量,并且指定了之前的设备 
    flatten_grads_tensor = torch.zeros(total_size, dtype=params_list[0][0].dtype).to(
        params_list[0][0].device
    )

    # 获取打平张量的storage
    flatten_weights_storage = flatten_weights_tensor.storage()
    flatten_grads_storage = flatten_grads_tensor.storage()

    # 设置底层的storage,size 和 strides,其实就是设置元信息
    def set_storage(param, weight_storage, grad_storage, storage_offset):
        with torch.no_grad():
            z = torch.zeros_like(param.data)
            z.set_(weight_storage, storage_offset, param.shape)
            param.data = z

            t = torch.zeros_like(param.data)
            t.set_(grad_storage, storage_offset, param.shape)
            param.grad = t

    offset = 0
    for params in params_list: # 遍历参数列表
        for p in params:
            # copy data
            # 把权重拷贝到flatten,p.numel() 是元素个数,reshape(-1) 就是展平了,设置了存储信息
            flatten_weights_tensor[offset : offset   p.numel()] = p.data.reshape(-1)

            # 把梯度拷贝到flatten
            if p.grad is not None:
                flatten_grads_tensor[offset : offset   p.numel()] = p.grad.data.reshape(
                    -1
                )

            # flatten
            # 设置底层的storage,size 和 strides,其实就是设置元信息
            set_storage(p, flatten_weights_storage, flatten_grads_storage, offset)
            offset  = p.allocated_size

    # check
    for params in params_list:
        weight_tensors = [p.data for p in params]
        grad_tensors = [p.grad.data for p in params]

        assert check_contiguous(weight_tensors)
        assert check_contiguous(grad_tensors)

    # 返回聚合打平之后的参数    
    return new_param(flatten_weights_tensor, flatten_grads_tensor)

具体如下,这里假设都是32位的张量,就都被聚合到 f32_params 之中。flatten_module_params 就是处理之后的,属于被打平的张量。其中 group_1 的两个权重 param_wg11, param_wg12 被排列在一起。

代码语言:javascript复制
  --------------------------     --------------------------     --------------------------- 
 | group_1["params"]        |   | group_2["params"]        |   | group_3["params"]         |
 |                          |   |                          |   |                           |
 |  param_wg11 , param_gg11 |   |  param_wg21 , param_gg21 |   |  param_wg31 , param_gg31  |
 |  param_wg12 , param_gg12 |   |  param_wg22 , param_gg22 |   |  param_wg32 , param_gg32  |
 |                          |   |                          |   |                           |
  ------- ------------------     ---------- ---------------     ---------------- ---------- 
         |                                 |                                    |
         |                                 |                                    |
          --------------- ----------------- ------------------- ---------------- 
                         |                                     |
                         | f32_params                          | f16_params
                         |                                     |
                         v                                     v
 ------------------------ ---------------       --------------- --------------------------- 
| flatten_module_params                  |     | flatten_module_params                     |
|                                        |     |                                           |
|  ------------------------------------  |     |   -------------------------------------   |
| |flatten_weights_tensor              | |     |  |flatten_weights_tensor               |  |
| |                                    | |     |  |                                     |  |
| | param_wg11, param_wg12, param_wg21 | |     |  |               ......                |  |
| |                                    | |     |  |                                     |  |
| | param_wg22, param_wg31, param_wg32 | |     |  |                                     |  |
|  ------------------------------------  |     |   -------------------------------------   |
|  ------------------------------------  |     |   -------------------------------------   |
| |flatten_grads_tensor                | |     |  |  flatten_grads_tensor               |  |
| |                                    | |     |  |                                     |  |
| | param_gg11, param_gg12, param_gg21 | |     |  |                ......               |  |
| |                                    | |     |  |                                     |  |
| | param_gg22, param_gg31, param_gg32 | |     |  |                                     |  |
|  ------------------------------------  |     |   -------------------------------------   |
 ----------------------------------------       ------------------------------------------- 

2.4 优化

优化代码如下,具体是按照group遍历参数,对于每组参数:

  • 按照存储把参数分组。
  • 重新排序。
  • 再把融合的赋值回去。
代码语言:javascript复制
def step(self, closure=None):
    r"""Performs a single optimization step (parameter update).

    Args:
        closure (Callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.

    .. note::
        Unless otherwise specified, this function should not modify the
        ``.grad`` field of the parameters.
    """
    for group in self.optimizer.param_groups: # 按照group遍历参数
        params = group["params"]
        grouped_params = group_params_by_storage(params) # 按照存储把参数分组

        fused_params = []

        for _, group_p in grouped_params.items():
            fused_params.extend(reorder_params(group_p)) # 重新排序

        group["params"] = fused_params # 再把融合的赋值回去

    return self.optimizer.step(closure)
2.4.1 按照存储分组

其实,就是 32 位,16位,weight,grad 一共四种组合。

比如针对 group_1 拿到了 32 位的权重 param_wg11, param_wg12,因为他们的 p.data.storage().data_ptr() 一致,所以把这个数值作为key,把这些权重放在同样 key 对应的位置。

代码语言:javascript复制
def group_params_by_storage(params):
    grouped_params = {}
    for p in params:
        weight_storage = p.data.storage().data_ptr() # 拿到key
        param_list = grouped_params.get(weight_storage, [])
        param_list.append(p) 
        grouped_params[weight_storage] = param_list # 放进value

    return grouped_params
2.4.2 重新排序

对于同样key 的参数,按照 storage offset 进行排序。

代码语言:javascript复制
def reorder_params(params):
    """Input params share same storage, reorder them by their storage offset"""

    sorted_params = sorted(params, key=lambda x: x.storage_offset())

    grouped = []
    tmp_params = []

    for p in sorted_params:
        if len(tmp_params) > 0 and not is_contiguous_param(p, tmp_params[-1]):
            grouped.append(collocate_params(tmp_params))
            tmp_params = []

        tmp_params.append(p)

    if len(tmp_params) > 0:
        grouped.append(collocate_params(tmp_params))  # FIXME: potential OOM

    return grouped

整个优化大致如下:

最开始时候是 group'params' = list(param_wg11, param_wg12) ,两个item 的list,两次CUDA操作。

结束时候是 group'params' = list(param_wg11 param_wg12) ,一个 item 的list,这里就融合了,缩减为一次CUDA操作。

代码语言:javascript复制
 ----------------------------------------       ------------------------------------------- 
| flatten_module_params                  |     | flatten_module_params                     |
|                                        |     |                                           |
|  ------------------------------------  |     |   -------------------------------------   |
| |flatten_weights_tensor              | |     |  |flatten_weights_tensor               |  |
| |                                    | |     |  |                                     |  |
| | param_wg11, param_wg12, param_wg21 | |     |  |               ......                |  |
| |                                    | |     |  |                                     |  |
| | param_wg22, param_wg31, param_wg32 | |     |  |                                     |  |
|  ------------------------------------  |     |   -------------------------------------   |
|  ------------------------------------  |     |   -------------------------------------   |
| |flatten_grads_tensor                | |     |  |  flatten_grads_tensor               |  |
| |                                    | |     |  |                                     |  |
| | param_gg11, param_gg12, param_gg21 | |     |  |                ......               |  |
| |                                    | |     |  |                                     |  |
| | param_gg22, param_gg31, param_gg32 | |     |  |                                     |  |
|  ------------------------------------  |     |   -------------------------------------   |
 ----------------------------------------       ------------------------------------------- 

 ------------------------------------------- ---------------------------------------------- 
                                            |
                                            |
                                            v
            ------------------------------------------------------------------------ 
           | step()                         |                                       |
           |                                |                                       |
           |                                |                                       |
           |                                v            2 items list               |
           |                                                                        |
           |                   group['params'] = list(param_wg11, param_wg12)       |
           |                                                                        |
           |                                |                                       |
           |                                |                                       |
           |                                v                                       |
           |                     group_params_by_storage / reorder_params           |
           |                                                                        |
           |                                |                                       |
           |                                |                                       |
           |                                v                                       |
           |          grouped_params] = {140266160612352 : param_wg11, param_wg12}  |
           |                                                                        |
           |                                |                                       |
           |                                |            1 item list                |
           |                                v                                       |
           |                group['params'] = list(param_wg11   param_wg12)         |
           |                                                                        |
           |                                                                        |
           |                                |                                       |
           |                                v                                       |
           |                    self.optimizer.step(closure)                        |
           |                                                                        |
           |                                                                        |
            ------------------------------------------------------------------------ 

0x03 分层化 --- 进程组

3.1 设计思路

Bagua的设计思路如下:

分层化的通信实现:由于工业级别的分布式训练往往需要多机多卡,而不同物理连接方式所带来的延时和带宽也有较大差异,因此,通讯的有效抽象也对性能的提升至关重要。Bagua 将涉及多机的通信抽象成:“机内”和“机间”,并对于相应的通信抽象做了优化。例如,对于信息压缩传输,分层化通讯将会把这一算法解读成“机内”完整精度,“机间”信息压缩,从而为不同的物理链接提供最合适的通信算法。 我们想要强调的是,这些系统实现层面的优化是对于各种算法组合广泛适用,而非局限在某一特定的算法设置上。因此,所有的系统优化都可以被灵活的复用到各种算法实现中去,这在保证“端到端”的性能提升的同时,也为开发新的分布式算法提供了良好的平台。

我们接下来就看看如何通过进程组实现分层化通信。分析思路就是:

  • 分层通信是不是有多个对应的进程组?
  • 如何得到节点内通信进程组的ranks?
  • 如何得到节点间通信进程组使用的ranks?
  • 每个进程组都有自己独立的通信方法吗?
  • 通信时候如何进行分层通信?

3.2 生成进程组

我们可以从源码之中的测试文件之中找到如何生成一个新进程组。

代码语言:javascript复制
all_ranks = list(range(nprocs))
odd_ranks = list(filter(lambda r: r % 2 == 1, all_ranks))
g = bagua.communication.new_group(ranks=odd_ranks)

new_group 此功能要求默认组中的所有进程(即作为分布式作业一部分的所有进程)都执行这个函数,即使它们不是组的成员。其参数是:

  • ranks :组成员的ranks列表。
  • stream : 执行NCCL操作的CUDA流。
代码语言:javascript复制
def new_group(
    ranks: Optional[List[int]] = None, stream: Optional[torch.cuda.Stream] = None
):
    """
    Creates a new process group.

    This function requires that all processes in the default group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group. Additionally, groups
    should be created in the same order in all processes.

    Each process group will create three communicators on request, a global communicator,
    a inter-node communicator and a intra-node communicator. Users can access them through
    ``group.get_global_communicator()``, ``group.get_inter_node_communicator()``
    and ``group.get_intra_node_communicator()`` respectively.

    Args:
        ranks: List of ranks of group members. If ``None``, will be
            set to all ranks. Default is ``None``.
        stream: A CUDA stream used to execute NCCL operations. If ``None``,
            CUDA stream of the default group will be used. See
            `CUDA semantics <https://pytorch.org/docs/stable/notes/cuda.html?highlight=stream>`_
            for details.

    Returns:
        A handle of process group that can be given to collective calls.

    .. note::
        The global communicator is used for global communications involving all ranks in the process group.
        The inter-node communicator and the intra-node communicator is used for hierarchical communications
        in this process group.

    .. note::
        For a specific communicator ``comm``, ``comm.rank()`` returns the rank of current process and
        ``comm.nranks()`` returns the size of the communicator.
    """
    global _group_count
    global _pg_group_ranks
    global _pg_map

    _group_count  = 1

    if ranks is None:
        ranks = list(range(get_world_size()))
    else:
        ranks = sorted(ranks) # 排序

    if stream is None:
        _check_default_pg()
        stream = _get_default_group().stream

    group_name = str(_group_count)
    pg = BaguaProcessGroup(ranks, stream, str(_group_count)) # 生成进程组
    
    # Create the global rank to group rank mapping
    _pg_group_ranks[pg] = {
        global_rank: group_rank for group_rank, global_rank in enumerate(ranks)
    }
    _pg_map[group_name] = pg

    return pg

3.3 Ranks

我们接着看看两个全局变量如何计算,一个是层内的ranks,一个是层间的ranks。

代码语言:javascript复制
intra_ranks = list(
    filter(
        lambda rank: rank // get_local_size() == get_rank() // get_local_size(),
        ranks,
    )
)
inter_ranks = list(
    filter(
        lambda rank: rank % get_local_size() == ranks[0] % get_local_size(),
        ranks,
    )
)

Python 的操作符如下:

//

取整除 - 返回商的整数部分(向下取整)

9//2 是 4 , -9//2 是 -5

%

取模 - 返回除法的余数

b % a 输出结果 0

实验一下

代码语言:javascript复制
def get_rank() -> int:
    return 5
def get_local_size():
    return 3
    
nprocs = 10 # 10个进程
ranks = list(range(nprocs)) # rank是0~9
print(intra_ranks) # rank 5 所在的intra_ranks。
print(inter_ranks) # 总的inter_ranks,能看出来是在 local size 的边缘。

输出
[3, 4, 5] # intra_ranks
[0, 3, 6, 9] # inter_ranks,在 local size 3 的边缘

具体用到的几个函数如下:

代码语言:javascript复制
def get_rank() -> int:
    """
    Get the rank of current process group.

    Rank is a unique identifier assigned to each process within a distributed
    process group. They are always consecutive integers ranging from 0 to
    ``world_size``.

    Returns:
        The rank of the process group.
    """
    return int(os.environ.get("RANK", 0))


def get_local_rank() -> int:
    """
    Get the rank of current node.

    Local rank is a unique identifier assigned to each process within a node.
    They are always consecutive integers ranging from 0 to ``local_size``.

    Returns:
        The local rank of the node.
    """
    return int(os.environ.get("LOCAL_RANK", 0))
  
  
def get_local_size() -> int:
    """
    Get the number of processes in the node.

    Returns:
        The local size of the node.
    """
    return int(os.environ.get("LOCAL_WORLD_SIZE", 1))  

现在我们知道了,不同进程组内部的ranks如何得到。

3.4 BaguaProcessGroup 定义

我们接下来看看 BaguaProcessGroup 如何定义,从定义上看,每个进程组都建立了三个 communicators,分别是:

  • a global communicator,使用 group.get_global_communicator() 可以得到。
  • a inter-node communicator,使用 group.get_inter_node_communicator() 可以得到。
  • a intra-node communicator,使用 group.get_intra_node_communicator() 可以得到。

全局通讯器用于进程组中所有ranks的全局通讯。节点间(inter-node)通讯器和节点内(intra-node)通讯器用于此过程组中的分层(hierarchical)通讯。

启用分层通信(hierarchical communication)。这意味着同一台机器上的GPU将首先相互通信。之后,机器进行节点间通信。这可以在节点间通信成本较高时提高性能。

代码语言:javascript复制
class BaguaProcessGroup:
    def __init__(self, ranks, stream, group_name):
        self.ranks = ranks
        self.stream = stream
        self.group_name = group_name

        self.intra_ranks = list(
            filter(
                lambda rank: rank // get_local_size() == get_rank() // get_local_size(),
                ranks,
            )
        )
        self.inter_ranks = list(
            filter(
                lambda rank: rank % get_local_size() == ranks[0] % get_local_size(),
                ranks,
            )
        )

    def get_global_communicator(self):
        return get_communicator(self.group_name, "global")

    def get_inter_node_communicator(self):
        return get_communicator(self.group_name, "inter")

    def get_intra_node_communicator(self):
        return get_communicator(self.group_name, "intra")

3.5 生成 communicator

具体就是生成了 BaguaSingleCommunicatorPy。这里使用了 lru_cache 来保证只生成一次。BaguaSingleCommunicatorPy 定义在 rust/bagua-core/bagua-core-py/src/lib.rs,在 rust/bagua-core/bagua-core-internal/src/communicators/mod.rs 之中也有 BaguaHierarchicalCommunicator 和 HierarchicalCommunicator 这样的实现 ,这就不是我们重点了,有兴趣的读者可以深入研究。

代码语言:javascript复制
@lru_cache(maxsize=None)
def get_communicator(group_name: str, comm_name: str):
    global _pg_map

    pg = _pg_map[group_name]
    if comm_name == "global":
        ranks = pg.ranks
    elif comm_name == "inter":
        ranks = pg.inter_ranks
    elif comm_name == "intra":
        ranks = pg.intra_ranks
    else:
        raise ValueError("comm_name should be one of ['global', 'inter', 'intra']")

    comm_key = "{}_{}_{}".format(group_name, comm_name, ",".join(map(str, ranks)))

    nccl_unique_id = broadcast_nccl_unique_id(comm_key, root=ranks[0])

    if get_rank() not in ranks:
        return CommMember.NON_COMM_MEMBER

    rank = ranks.index(get_rank())
    nranks = len(ranks)

    comm = B.BaguaSingleCommunicatorPy(
        rank=rank,
        nranks=nranks,
        device_id=get_local_rank(),
        stream_ptr=pg.stream.cuda_stream,
        nccl_unique_id_str=nccl_unique_id,
    )

    comm.cuda_stream = pg.stream
    return comm

具体如下:

代码语言:javascript复制
 ----------------------------------- 
| BaguaProcessGroup                 |
|                                   |        --------------------------- 
|                                   |       | BaguaSingleCommunicatorPy |
|                                   |       |                           |
|    get_global_communicator   -----------> |      ranks                |
|                                   |       |                           |
|                                   |        --------------------------- 
|                                   |
|                                   |        --------------------------- 
|                                   |       | BaguaSingleCommunicatorPy |
|    get_inter_node_communicator  --------> |                           |
|                                   |       |      inter_ranks          |
|                                   |       |                           |
|                                   |        --------------------------- 
|                                   |
|                                   |        --------------------------- 
|    get_intra_node_communicator  --------> | BaguaSingleCommunicatorPy |
|                                   |       |                           |
|                                   |       |      intra_ranks          |
|    ranks                          |       |                           |
|                                   |        --------------------------- 
|    inter_ranks                    |
|                                   |
|    intra_ranks                    |
|                                   |
|                                   |
 ----------------------------------- 

3.6 使用

具体代码在:rust/bagua-core/bagua-core-internal/src/communicators/mod.rs

可以看到,如果没有设置hierarchical,就正常通信,如果设置hierarchical,就用intra 和 inter 混合着来,先试验 intra,再节点间通信。

代码语言:javascript复制
impl BaguaCommunicator {
    pub fn new(
        communicator_internode: Option<&BaguaSingleCommunicator>,
        communicator_intranode: Option<&BaguaSingleCommunicator>,
        hierarchical: bool,
    ) -> Result<Self, BaguaCoreError> {
        match hierarchical {
            false => Ok(BaguaCommunicator::SingleCommunicator( // 不是 hierarchical,就正常通信
                communicator_internode
                    .expect("inter node communicator must be given in non-hierarchical mode")
                    .clone(),
            )),
            true => { // 是 hierarchical,就用intra 和 inter 混合着来,先试验 intra
                let intranode_rank = communicator_intranode.as_ref().unwrap().rank();
                if intranode_rank == 0 {
                    let intra = communicator_intranode.expect("intra node communicator must be given in worker GPU in hierarchical mode").clone();
                    let inter = communicator_internode.unwrap().clone();
                    {
                        if intra.inner.stream_ptr != inter.inner.stream_ptr {
                            return Err(BaguaCoreError::CommunicatorError("intra node communicator should use the same stream as the inter node communicator".into()));
                        }
                    }
                    Ok(BaguaCommunicator::HierarchicalCommunicator(
                        BaguaHierarchicalCommunicator::Leader(
                            BaguaHierarchicalCommunicatorLeader::new(inter, intra),
                        ),
                    ))
                } else {
                    Ok(BaguaCommunicator::HierarchicalCommunicator(BaguaHierarchicalCommunicator::Worker(BaguaHierarchicalCommunicatorWorker {
                        intranode: communicator_intranode.expect("intra node communicator must be given in worker GPU in hierarchical mode").clone()
                    })))
                }
            }
        }
    }

    pub fn execute_communication(
        &self,
        tensor: &mut BaguaCommunicationTensor,
        intranode_average: bool,
        hierarchical_pre: bool,
        hierarchical_post: bool,
        communication_hook: &mut dyn FnMut(
            &BaguaCommunicatorInner,
            &mut BaguaCommunicationTensor,
        ) -> (),
    ) {
        match &self {
            BaguaCommunicator::SingleCommunicator(communicator) => {
                let communicator = communicator.inner.clone();
                communication_hook(&communicator, tensor);
            }
            BaguaCommunicator::HierarchicalCommunicator(communicator) => match communicator {
                BaguaHierarchicalCommunicator::Leader(communicator) => {
                    let internode_communicator = communicator.internode.inner.clone();
                    if hierarchical_pre { // 先节点内部
                        communicator.hierarchical_pre(tensor, intranode_average);
                    }
                    communication_hook(&internode_communicator, tensor); // 再节点间
                    if hierarchical_post {
                        communicator.hierarchical_post(tensor);
                    }
                }
                BaguaHierarchicalCommunicator::Worker(communicator) => {
                    if hierarchical_pre {
                        communicator.hierarchical_worker_pre(tensor, intranode_average);
                    }
                    if hierarchical_post {
                        communicator.hierarchical_worker_post(tensor);
                    }
                }
            },
        }
    }
}

0xFF 参考

PyTorch internals

快手八卦!突破 TensorFlow、PyTorch 并行瓶颈的开源分布式训练框架来了!

https://arxiv.org/pdf/2107.01499.pdf

1 Dean, Jeffrey, Greg S. Corrado, Rajat Monga, Kai Chen, Matthieu Devin, Quoc V. Le, Mark Z. Mao et al. “Large scale distributed deep networks.” (2012).

2 Zhengyuan Zhou, Panayotis Mertikopoulos, Nicholas Bambos, Peter Glynn, Yinyu Ye, Li-Jia Li, and Li Fei-Fei. 2018. Distributed asynchronous optimization with unbounded delays: How slow can you go?. In International Conference on Machine Learning. PMLR, 5970–5979.

3 DanAlistarh, DemjanGrubic, JerryLi, RyotaTomioka, and MilanVojnovic. 2016. QSGD: Communication-efficient SGD via gradient quantization and encoding. arXiv preprint arXiv:1610.02132 (2016).

4 Dan Alistarh, Torsten Hoefler, Mikael Johansson, Sarit Khirirat, Nikola Konstanti- nov, and Cédric Renggli. 2018. The convergence of sparsified gradient methods. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 5977–5987.

5 Anastasia Koloskova, Sebastian Stich, and Martin Jaggi. 2019. Decentralized stochastic optimization and gossip algorithms with compressed communication. In International Conference on Machine Learning. PMLR, 3478–3487.

6 Xiangru Lian, Ce Zhang, Huan Zhang, Cho-Jui Hsieh, Wei Zhang, and Ji Liu. 2017. Can decentralized algorithms outperform centralized algorithms? a case study for decentralized parallel stochastic gradient descent. In Proceedings of the 31st International Conference on Neural Information Processing Systems. 5336–5346.

7 Christopher De Sa, Matthew Feldman, Christopher Ré, and Kunle Olukotun. 2017. Understanding and optimizing asynchronous low-precision stochastic gradient descent. In Proceedings of the 44th Annual International Symposium on Computer Architecture. 561–574.

8 Xiangru Lian, Wei Zhang, Ce Zhang, and Ji Liu. 2018. Asynchronous decentral- ized parallel stochastic gradient descent. In International Conference on Machine Learning. PMLR, 3043–3052.

9 Hanlin Tang, Shaoduo Gan, Ce Zhang, Tong Zhang, and Ji Liu. 2018. Com- munication compression for decentralized training. In Proceedings of the 32nd International Conference on Neural Information Processing Systems. 7663–7673.

10 Ji Liu, Ce Zhang, et al. 2020. Distributed Learning Systems with First-Order Methods. Foundations and Trends® in Databases 9, 1 (2020), 1–100.

0 人点赞