[源码解析] PyTorch 分布式(17) --- 结合DDP和分布式 RPC 框架

2021-12-17 16:35:53 浏览数 (1)

[源码解析] PyTorch 分布式(17) --- 结合DDP和分布式 RPC 框架

目录

  • [源码解析] PyTorch 分布式(17) --- 结合DDP和分布式 RPC 框架
    • 0x00 摘要
    • 0x00 综述
    • 0x01 启动
    • 0x03 支撑系统
      • 3.1 功能
      • 3.2 使用
        • 3.2.1 混合模型
        • 3.2.2 使用
      • 3.3 定义
      • 3.4 主要函数
    • 0x04 HybridModel
    • 0x05 训练
      • 5.1 初始化
      • 5.2 训练循环
    • 0x06 比对
    • 0xFF 参考

0x00 摘要

在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,接下来我们通过几篇文章来看看如何把这些模块应用到实践之中,顺便把PyTorch分布式逻辑整体梳理一下。本文介绍如何把DDP和RPC framework结合起来。

本文以 COMBINING DISTRIBUTED DATAPARALLEL WITH DISTRIBUTED RPC FRAMEWORK 的翻译为基础,加入了自己的理解

0x00 综述

本教程使用一个简单的示例来演示如何将 DistributedDataParallel (DDP) 与分布式 RPC 框架 相结合,将分布式数据并行性与分布式模型并行性相结合,以训练一个简单的模型。该示例的源代码可以在这里找到。

前面的教程 入门分布式数据并行 和入门分布式RPC框架 分别描述了如何执行分布式数据并行和分布式模型平行训练。尽管如此,您可能希望在多种训练范式中结合这两种技术。例如:

  1. 如果我们有一个包含稀疏部分(大型嵌入表)和密集部分(FC 层)的模型,我们可能希望将嵌入表放在参数服务器上,并使用DistributedDataParallel在多个trainer之间复制 FC 层。分布式RPC框架 就可被用于在参数服务器上执行嵌入查找。
  2. 如PipeDream论文中所述启用混合并行性。我们可以使用分布式 RPC 框架 将模型的各个阶段跨多个worker 进行流水线化,并使用DistributedDataParallel 对每个阶段进行数据并行(如果需要)。

在本教程中,我们将介绍上述案例 1。我们的设置中共有 4 个 worker,如下所示:

  • 1 个Master,负责在参数服务器上创建嵌入表(nn.EmbeddingBag)。master 还负责驱动两个trainer上的训练循环。
  • 1 个Parameter Server,它将嵌入表保存在内存中,并响应来自 Master 和 Trainer 的 RPC 请求。
  • 2 个trainer,它存储一个 FC 层 (nn.Linear),其使用DistributedDataParallel 进行数据并行。trainer还负责执行前向传播、后向传播和优化器步骤。

整个训练过程执行如下:

  1. Master 创建一个RemoteModule ,在参数服务器上保存一个嵌入表。
  2. Master 在trainer上启动训练循环,并将远程模块(remote module)传播给trainer。
  3. Trainer 创建一个HybridModel,其首先使用 master 提供的远程模块执行嵌入查找(embedding lookup),然后执行封装在 DDP 中的 FC 层。
  4. Trainer 执行模型的前向传播,并使用Distributed Autograd 对损失执行后向传播。
  5. 作为反向传播的一部分,首先计算 FC 层的梯度,并通过 DDP 中的 allreduce 同步到所有trainer。
  6. 接下来,分布式 Autograd 将梯度传播到参数服务器,在那里更新嵌入表的梯度。
  7. 最后,分布式优化器被用于更新所有参数。

注意:如果您将 DDP 和 RPC 结合使用,则应始终使用Distributed Autograd进行反向传播。

0x01 启动

我们看看系统如何启动。首先,在进行训练之前,需要设置所有worker。我们创建了 4 个进程,其中 rank 0 和 rank 1 是我们的trainer,rank 2是master,rank 3是参数服务器。

初始化逻辑如下:

  • 我们使用 TCP init_method 在所有 4 个 worker 上初始化 RPC 框架。
  • 对于 Master,代码做了如下操作:
    • 完成 RPC 初始化后,master 创建一个远程模块RemoteModule,该模块指向一个在参数服务器上保存的EmbeddingBag层。
    • 然后 master 遍历每个trainer,并通过使用rpc_async调用_run_trainer 在每个trainer之上启动训练循环。
    • 最后,master 在退出之前等待所有训练完成。
  • Trainers做了如下操作:
    • Trainers 首先使用 init_process_group为DDP初始化一个world_size = 2(对于两个trainer)的ProcessGroup
    • 接下来,Trainers 使用 TCP init_method 初始化 RPC 框架。注意RPC初始化和ProcessGroup初始化的端口是不同的。这是为了避免两个框架的初始化之间的端口冲突。
    • 初始化完成后,trainer只需等待来自 master的_run_trainer RPC。
  • 参数服务器只是初始化 RPC 框架并等待来自trainer和master的 RPC。

具体代码如下:

代码语言:javascript复制
def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """

    # We need to use different port numbers in TCP init_method for init_rpc and
    # init_process_group to avoid port conflicts.
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = "tcp://localhost:29501"

    # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
    if rank == 2: # Master代码
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        remote_emb_module = RemoteModule( # 指向一个在参数服务器上保存的EmbeddingBag层
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"},
        )

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            fut = rpc.rpc_async( # 启动 trainer循环
                trainer_name, _run_trainer, args=(remote_emb_module, trainer_rank)
            )
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize process group for Distributed DataParallel on trainers.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
        )

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )

        # 只需等待来自 master的 _run_trainer RPC
        # Trainer just waits for RPCs from master.
    else:
        rpc.init_rpc( # 参数服务器
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options,
        )
        # parameter server do nothing
        pass # 啥也不干,只是等待来自trainer和master的 RPC

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    # 2 trainers, 1 parameter server, 1 master.
    world_size = 4
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)

目前逻辑如下,我们后续会继续拓展:

代码语言:javascript复制
                               torch.multiprocessing.spawn
                                           
                                          |
                                          |
               ---------------------------------------------------------------- ---------------------------------- 
              |                           |                                    |                                  |
              |                           |                                    |                                  |
              v                           v                                    v                                  v
 ------------- -------------    ---------- ---------------   ------------------ ------------------   ------------- -------- 
|trainer 0         rank = 0 |  |trainer 1     rank = 1    | | master                     rank = 2 | |ps          rank = 3  |
|                           |  |                          | |                                     | |                      |
|                           |  |                          | |   rpc.init_rpc                      | |     rpc.init_rpc     |
|                           |  |                          | |                                     | |                      |
|   dist.init_process_group |  |  dist.init_process_group | |   remote_emb_module =  RemoteModule | |                      |
|                           |  |                          | |                                     | |                      |
|                           |  |                          | |                                     | |                      |
|   rpc.init_rpc            |  |  rpc.init_rpc            | |   fut = rpc.rpc_async(_run_trainer) | |                      |
|                           |  |                          | |                                     | |                      |
|                           |  |                          | |                                     | |                      |
 ---------------------------    --------------------------   -------------------------------------   ---------------------- 

手机如下:

0x03 支撑系统

支撑系统主要指的就是 _RemoteModule,其作用是在异地建立一个模型,具体代码在:torch/distributed/nn/api/remote_module.py。

3.1 功能

RemoteModule实例只能在RPC初始化之后创建,它可以在指定的远程节点上创建用户指定的模块,其行为类似于常规的nn.Module方法,但不同之处是 RemoteModule 在远程节点上执行forward方法。RemoteModule 负责autograd recording,以确保向后传播可以将梯度传播回相应的远程模块。

RemoteModule 可以使用RPC framework <https://pytorch.org/docs/stable/rpc.html> 在处理器之间共享,且不会产生复制实际模块的任何开销,这相当于使用一个~torch.distributed.rpc.RRef指向远程模块。

3.2 使用

3.2.1 混合模型

要创建混合模型,通常应该在远程模块之外创建本地模块,而不是作为任何远程模块的子模块。如果远程模块放置在cuda设备上,那么任何输入CPU张量将自动移动到同一cuda设备之上。混合模型例子如下:

代码语言:javascript复制
            >>> class HybridModel(nn.Module):
            >>>     def __init__(self):
            >>>         nn.Module.__init__(self)
            >>>         self.remote_embedding = RemoteModule(...) # 在远端创建嵌入层
            >>>         self.local_linear = nn.Linear(...)
3.2.2 使用

使用例子如下,需要在两个不同进程上运行如下代码,例子之中,RemoteModule 创建时候,传入了一个"worker1/cpu"参数,意思是在 worker1 的 cpu 设备上运行这个RemoteModule。具体格式是: <workername> / <device>,其中 <device> 是torch.device类型。

代码语言:javascript复制
    Example::
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1/cpu", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

3.3 定义

_RemoteModule定义如下,具体初始化逻辑是:

  • (1). 准备参数。
  • (2). 设置运行的远端worker和远端设备。
  • (3). 如果设置了_module_interface_cls
    • (3.1) 使用 _module_interface_cls 来在远端构建模块。_
    • (3.2) 在本地构建函数代理生成器。
    • (3.3) 等待创建完成。
    • (3.4) 在本地构建句柄。
  • (4) 没有设置_module_interface_cls。
    • (4.1) 在本地构建函数代理生成器。
    • (4.2) 在远端创建模块。
  • (5). 在本地创建远端函数代理。
代码语言:javascript复制
class _RemoteModule(nn.Module):
    def __init__(
        self,
        remote_device: str,
        module_cls: nn.Module,
        args: Tuple = None,
        kwargs: Dict[str, Any] = None,
        _module_interface_cls: Any = None,
    ):
        """
        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
                E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                In addition, the device field can be optional and the default value is "cpu".

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_cls``, it has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.
        """
        super().__init__()

        # NOTE: if a new attribute is added to this class, also need to add it
        # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling.

        # Default arguments preperation.
        # 1. 准备参数
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        # 2. 设置运行的远端worker和远端设备
        self.on, self.device = _parse_remote_device(remote_device)
        agent = rpc._get_current_rpc_agent()
        # If the device map of the remote worker is set,
        # then enable moving any input CPU tensors to the same cuda device.
        self.is_device_map_set = bool(
            agent._get_device_map(agent.get_worker_info(self.on))
        )
        # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``:
        # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set,
        # then any CPU tensors can still be moved to a cuda device to run forward,
        # but the output must be moved back to CPU before being sent over the wire.
        enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda"

        # 3. 如果设置了_module_interface_cls
        if _module_interface_cls is not None:
            # Users reply on this field to know if this generated RemoteModule is TorchScript-able.
            self.is_scriptable = True

            # 3.1 使用 _module_interface_cls 来在远端构建模块
            # Instantiate template on remote side.
            fut = rpc.rpc_async(
                self.on,
                _instantiate_template,
                (_module_interface_cls, enable_moving_cpu_tensors_to_cuda),
            )

            # 3.2 在本地构建函数代理生成器
            # Instantiate template on local side.
            generated_module = (
                instantiator.instantiate_scriptable_remote_module_template(
                    _module_interface_cls, enable_moving_cpu_tensors_to_cuda
                )
            )
            self.generated_methods = generated_module._generated_methods

            # 3.3 等待创建完成
            # Create the module on the remote side.
            fut.wait()  # Ensure remote_module_cls is available on remote side.

            # 3.4 在本地构建句柄
            self.module_rref = rpc.rpc_sync(
                self.on,
                _create_module_with_interface,
                (module_cls, args, kwargs, self.device, _module_interface_cls),
            )
        else: # 4 没有设置_module_interface_cls
            self.is_scriptable = False
            # 4.1 在本地构建函数代理生成器
            self.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
            # 4.2 在远端创建模块
            # Create the module on the remote side.
            self.module_rref = rpc.remote(
                self.on,
                _create_module,
                (module_cls, args, kwargs, self.device),
            )

        # Install generated methods.
        # 5. 在本地创建远端函数代理
        for method in self.generated_methods:
            method_name = method.__name__
            method = torch.jit.export(method)
            setattr(self, method_name, types.MethodType(method, self))

3.4 主要函数

其主要函数如下:

  • rpc.rpc_sync 返回指向远程模块参数的~torch.distributed.rpc.RRef列表。通常可以与~torch.distributed.optim.DistributedOptimizer结合使用。
  • get_module_rref 返回一个指向远程模块的~torch.distributed.rpc.RRef(RRef[nn.Module])类。
代码语言:javascript复制
def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
    """
    Returns a list of :class:`~torch.distributed.rpc.RRef` pointing to the
    remote module's parameters. This can typically be used in conjuction
    with :class:`~torch.distributed.optim.DistributedOptimizer`.

    Args:
        recurse (bool): if True, then returns parameters of the remote
            module and all submodules of the remote module. Otherwise,
            returns only parameters that are direct members of the
            remote module.

    Returns:
        A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``)
        to remote module's parameters.
    """
    return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))

def get_module_rref(self) -> rpc.RRef[nn.Module]:
    """
    Returns an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``)
    pointing to the remote module.
    """
    return self.module_rref

于是逻辑图转换如下,在上图基础之上多了一个remote_emb_module,其在ps之上创建了一个RemoteModule

代码语言:javascript复制
                                torch.multiprocessing.spawn
                                            
                                           |
                                           |
                ---------------------------------------------------------------- ---------------------------------- 
               |                           |                                    |                                  |
               |                           |                                    |                                  |
               v                           v                                    v                                  v
 -------------- -------------   ----------- --------------   ------------------- -----------------    ------------- -------- 
|trainer 0          rank = 0 | |trainer 1     rank = 1    | | master                     rank = 2 |  |ps          rank = 3  |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |     rpc.init_rpc                    |  |     rpc.init_rpc     |
|                            | |                          | |                                     |  |                      |
|    dist.init_process_group | |  dist.init_process_group | |   remote_emb_module  ----------------------> RemoteModule     |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |                                     |  |                      |
|    rpc.init_rpc            | |  rpc.init_rpc            | |   fut = rpc.rpc_async(_run_trainer) |  |                      |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |                                     |  |                      |
|                            | |                          | |                                     |  |                      |
 ----------------------------   --------------------------   -------------------------------------    ---------------------- 

手机如下:

0x04 HybridModel

在讨论 Trainer 的细节之前,让我们先介绍一下 Trainer使用的HybridModel。该模型由稀疏部分和稠密部分组成。

  • 稠密部分是一个nn.Linear,使用DistributedDataParallel在所有trainer中复制,即 在 DDP 内包装了一个 nn.Linear层。
  • 稀疏部分是一个远程模块 (remote_emb_module) ,它持有一个在参数服务器上的nn.EmbeddingBag。即,此远程模块可以获取参数服务器上嵌入表的远程引用。

该模型的前向方法非常简单。它使用 RemoteModule 在参数服务器上执行嵌入查找forward ,并将其输出传播到 FC 层,这里的 FC 使用了DDP

代码语言:javascript复制
class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part.
    1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
    2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
    This remote model can get a Remote Reference to the embedding table on the parameter server.
    """

    def __init__(self, remote_emb_module, device):
        super(HybridModel, self).__init__()
        self.remote_emb_module = remote_emb_module
        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        emb_lookup = self.remote_emb_module.forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))

逻辑拓展如下,两个trainer 之上也建立了remote_emb_module,指向了ps之上的RemoteModule

代码语言:javascript复制
                                         torch.multiprocessing.spawn
                                                     
                                                    |
                                                    |
             ----------------------------------------------------------------------------------- ---------------------------------- 
            |                                       |                                           |                                  |
            |                                       |                                           |                                  |
            v                                       v                                           v                                  v
 ----------- -------------   ----------------------- -------------------   --------------------- ---------------      ------------- -------- 
|trainer 0       rank = 0 | | trainer 1                        rank = 1 | | master                     rank = 2 |    |ps          rank = 3  |
|                         | |                                           | |                                     |    |                      |
|                         | |                                           | |   rpc.init_rpc                      |    |     rpc.init_rpc     |
| dist.init_process_group | | dist.init_process_group                   | |                                     |    |                      |
|                         | |                                           | |   remote_emb_module  ------------------------> RemoteModule     |
| rpc.init_rpc            | | rpc.init_rpc                              | |                                     |    |         ^     ^      |
|                         | |                                           | |                                     |    |         |     |      |
|                         | |                                           | |   fut = rpc.rpc_async(_run_trainer) |    |         |     |      |
|                         | |                                           | |                                     |    |         |     |      |
|  ---------------------  | |             ---------------------------   | |                                     |    |         |     |      |
| | HybridModel         | | |            |HybridModel                |  | |                                     |    |         |     |      |
| |                     | | |            |                           |  |  -------------------------------------      ---------------------- 
| |                     | | |            |                           |  |                                                      |     |
| |   fc = DDP(Linear)  | | |            |      fc = DDP(Linear())   |  |                                                      |     |
| |                     | | |            |                           |  |                                                      |     |
| |   remote_emb_module | | |            |      remote_emb_module -------------------------------------------------------------      |
| |                     | | |            |                           |  |                                                            |
|  ---------------------  | |             ---------------------------   |                                                            |
|               |         | |                                           |                                                            |
 -------------------------   -------------------------------------------                                                             |
                |                                                                                                                    |
                 -------------------------------------------------------------------------------------------------------------------- 

手机如下:

0x05 训练

5.1 初始化

之前初始化时候,我们漏过了trainer的初始化,这里我们分析一下。

我们先看看 Trainer 上的设置。

  • 首先,trainer使用远程模块(remote module)和自己的rank 来创建上面提到的 HybridModel,远程模块持有参数服务器上的嵌入表。
  • 其次,我们需要得到一个RRef 列表,该列表指向我们想要使用DistributedOptimizer优化的所有参数。
    • 要从参数服务器嵌入表之中拿到这些参数,我们可以调用 RemoteModule 的remote_parameters,它会遍历嵌入表的所有参数并返回一个 RRef 列表。trainer通过 RPC 在参数服务器上调用此方法来得到所需参数的 RRef 列表。
    • 由于 DistributedOptimizer 始终持有一个需要优化参数的 RRef 列表,因此我们需要为 FC 层的局部参数创建 RRef。这是通过遍历model.fc.parameters()来完成的,其将为每个参数创建一个 RRef 并将其附加到从remote_parameters()返回的列表中。
    • 请注意,我们不能使用model.parameters(),因为它会递归调用model.remote_emb_module.parameters(),而RemoteModule不支持这种操作。
  • 最后,我们使用所有 RRef 创建我们的 DistributedOptimizer 并定义一个 CrossEntropyLoss 函数。
代码语言:javascript复制
def _run_trainer(remote_emb_module, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
    DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear) and distributed autograd ensures gradients updates are
    propagated to the parameter server.
    """

    # Setup the model.
    model = HybridModel(remote_emb_module, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters for embedding table.
    model_parameter_rrefs = model.remote_emb_module.remote_parameters()

    # model.fc.parameters() only includes local parameters.
    # NOTE: Cannot call model.parameters() here,
    # because this will call remote_emb_module.parameters(),
    # which supports remote_parameters() but not parameters().
    for param in model.fc.parameters(): 
        model_parameter_rrefs.append(RRef(param)) # 这里添加了需要分布式优化的 DDP 的参数

    # Setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs, # dense参数和sparse参数一起分布式优化
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

我们逻辑拓展如下,这里省略了 trainer 0 指向 参数服务器的箭头,与上图相比,增加了 DistributedOptimizer。

代码语言:javascript复制
                                            torch.multiprocessing.spawn
                                                        
                                                       |
                                                       |
                ----------------------------------------------------------------------------------- ---------------------------------- 
               |                                       |                                           |                                  |
               |                                       |                                           |                                  |
               v                                       v                                           v                                  v
 -------------- -------------   ----------------------- -------------------   --------------------- ---------------    --------------- ------------- 
|trainer 0          rank = 0 | | trainer 1                        rank = 1 | | master                     rank = 2 |  |  ps                rank = 3 |
|                            | |                                           | |                                     |  |                             |
|                            | |                                           | |                                     |  |      rpc.init_rpc           |
| dist.init_process_group    | | dist.init_process_group                   | |   rpc.init_rpc                      |  |                             |
|                            | |                                           | |                                     |  |     ----------------------  |
| rpc.init_rpc               | | rpc.init_rpc                              | |                            1        |  |    | RemoteModule         | |
|                            | |                                           | |   remote_emb_module  ---------------------> |                      | |
|  ------------------------  | |  ---------------------------------------  | |                                     |  |    |                      | |
| | _run_trainer           | | | | _run_trainer                          | | |                                     |  |    |  remote_parameters() | |
| |                        | | | |                                       | | |   fut = rpc.rpc_async(_run_trainer) |  |    |                      | |
| |                        | | | |   output = model(indices, offsets)    | | |                                     |  |    |                      | |
| |                        | | | |   dist_autograd.backward              | | |                                     |  |     ------ -------- ------  |
| |                        | | | |   opt.step                            | | |                                     |  |           ^        ^        |
| |                        | | | |                                       | | |                                     |  |           |        |        |
| |  -------------------   | | | |                                       | |  -------------------------------------    ----------------------------- 
| | | HybridModel       |  | | | |   -----------------------------       | |                                                      |        |
| | |                   |  | | | |  | HybridModel                 |      | |                                                      |        |
| | | fc = DDP(Linear)  |  | | | |  |                             |      | |                                                      |        |
| | | remote_emb_module |  | | | |  |  fc = DDP(Linear().cuda()   |      | |                                                      |        |
| | |                   |  | | | |  |  remote_emb_module ------------------------------------------------------------------------->        |
| |  -------------------   | | | |  |                             |      | |                             2                                 |
| |                        | | | |   -----------------------------       | |                                                               |
| |  --------------------  | | | |   -----------------------------       | |                                                               |
| | |DistributedOptimizer| | | | |  |DistributedOptimizer         |      | |                                                               |
| |  --------------------  | | | |  |                              ------------------------------------------------------------------------>
| |                        | | | |   -----------------------------       | |                              3
|  ------------------------  | |  ---------------------------------------  |
 ----------------------------   ------------------------------------------- 

手机如下:

5.2 训练循环

现在我们介绍在每个trainer上运行的主训练循环。这里 get_next_batch只是一个辅助函数,用于生成随机输入和训练目标。我们为多个epoch和每个batch运行该训练循环:

  1. 为Distributed Autograd.设置Distributed Autograd Context 。
  2. 运行模型的前向传播并拿到其输出。
  3. 使用损失函数根据我们的输出和target来计算损失。
  4. 使用 Distributed Autograd 对损失执行分布式反向传播。
  5. 最后,运行分布式优化器step 来优化所有参数。
代码语言:javascript复制
    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start  = random.randint(1, 10)
                batch_size  = 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # Train for 100 epochs
    for epoch in range(100):
        # create distributed autograd context
        for indices, offsets, target in get_next_batch(rank):
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Tun distributed optimizer
                opt.step(context_id)

                # Not necessary to zero grads as each iteration creates a different
                # distributed autograd context which hosts different grads
        print("Training done for epoch {}".format(epoch))

因为篇幅所限,我们只是把上面的trainer再细化如下图:

  1. 初始化时候,调用 dist.init_process_group 来初始化 DistributedDataParallel,调用 rpc.init_rpc 来初始化 RPC。
  2. HybridModel 之中,fc 是DistributedDataParallel方式,remote_emb_module 是参数服务器上的 RemoteModule。
  3. DistributedOptimizer 之中,对于 HybridModel 的 fc 和 remote_emb_module 都会进行分布式优化。
  4. _run_trainer 之中,使用 model(indices, offsets) 进行前向传播,其中会调用到 HybridModel.forward。
  5. HybridModel.forward 之中则对embedding 和 fc 进行操作。
    1. embedding 是利用RPC 和 参数服务器。
    2. fc 是利用 DistributedDataParallel。
    3. 将嵌入表放在参数服务器上,并使用DistributedDataParallel 在多个trainer之间复制 FC 层。

这些序号与下图中数字对应。

代码语言:javascript复制
 --------------------------------------------------------------------- 
| trainer 1                                                 rank = 1  |
|                 -----------------------------------                 |
|                |    dist.init_process_group      1 |                |
|                |                                   |                |
|                |    rpc.init_rpc                   |                |
|                |                                   |                |
|                 -----------------------------------                 |
|  -----------------------------------------------------------------  |
| | _run_trainer                                                    | |
| |                                                                 | |
| |     output = model(indices, offsets)                            | |
| |     dist_autograd.backward                                      | |
| |     opt.step                    |                               | |
| |   -----------------------------------------------------------   | |
| |  | HybridModel                  |                          2 |  | |
| |  |                              |                            |  | |
| |  |    fc = DDP(Linear().cuda()  |                            |  | |
| |  |                              |4                           |  | |
| |  |    remote_emb_module         |                            |  | |
| |  |                              |                            |  | |
| |  |                              v                            |  | |
| |  |    -------------------------- --------------------------  |  | |
| |  |   |forward                                              | |  | |
| |  |   |  emb_lookup = remote_emb_module.forward()           | |  | |
| |  |   |                                                     | |  | |
| |  |   |                  |  5                               | |  | |
| |  |   |                  |                                  | |  | |
| |  |   |                  v                                  | |  | |
| |  |   |  fc(emb_lookup.cuda(device)                         | |  | |
| |  |   |                                                     | |  | |
| |  |    -----------------------------------------------------  |  | |
| |   -----------------------------------------------------------   | |
| |   -----------------------------------------------------------   | |
| |  | DistributedOptimizer                                    3 |  | |
| |  |                                                           |  | |
| |  |         HybridModel.remote_emb_module.remote_parameters() |  | |
| |  |                                                           |  | |
| |  |         HybridModel.fc.parameters()                       |  | |
| |  |                                                           |  | |
| |   -----------------------------------------------------------   | |
|  -----------------------------------------------------------------  |
 --------------------------------------------------------------------- 

手机如下:

注,可以在此处找到整个示例的源代码。

0x06 比对

我们已经看了三篇PyTorch官方样例,里面对参数服务器的实现各有不同。对于本文来说,又加入了一个master作为协调者来统一各个worker。

总的来说,在PyTorch 之中,因为有了 RPC 机制,所以PyTorch 的参数服务器实现比 ps-lite, paracel 更佳灵活机动:

  • 首先参数服务器目前可以放在 GPU 之中。
  • 其次,可以在参数服务器只放置参数,也可以运行优化代码,甚至可以在参数服务之上启动控制trainer。
  • 具体优化器根据实际需要,可以是普通优化器,也可以是DistributedOptimizer。
  • 训练代码从用户编写角度看则完全是运行在本地。

0xFF 参考

COMBINING DISTRIBUTED DATAPARALLEL WITH DISTRIBUTED RPC FRAMEWORK

0 人点赞