[源码解析] PyTorch 分布式之弹性训练(6)---监控/容错

2022-05-09 16:16:53 浏览数 (1)

源码解析 PyTorch 分布式之弹性训练(6)---监控/容错

目录

  • [源码解析] PyTorch 分布式之弹性训练(6)---监控/容错
    • 0x00 摘要
    • 0x01 总体逻辑
      • 1.1 Node集群角度
      • 1.2 Agent总体逻辑图
      • 1.3 监控角度
    • 0x02 多进程
      • 2.1 启动workers
        • 2.1.1 start_processes
        • 2.1.2 RunResult
      • 2.1 TE 使用
      • 2.2 PContext
      • 2.3 MultiprocessContext
        • 2.3.1 start
        • 2.3.2 wait
        • 2.3.3 _poll
      • 2.4 ProcessContext
        • 2.4.1 start_processes
        • 2.4.2 ProcessContext
      • 2.5 总结
    • 0x03 监控机制
      • 3.1 监控
      • 3.2 处理
    • 0x04 训练结束
      • 4.1 统一完成
      • 4.2 同步
    • 0x05 错误处理
      • 5.1 错误类型
      • 5.1 错误处理模式
      • 5.2 处理机制
        • 5.2.1 重启
        • 5.2.2 停止
      • 5.4 其他代理重启
    • 0xFF 参考

0x00 摘要

关于PyTorch弹性训练,迄今为止我们已经分别介绍了 Agent 和 rendezous,但是有些部分并没有深入,比如监控,本文就把它们统一起来,对弹性训练做一个整体逻辑上的梳理。

弹性训练系列文章如下:

[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

[源码解析] PyTorch 分布式之弹性训练(3)---代理

[源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

[源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎

0x01 总体逻辑

我们需要从几个角度来看看系统逻辑,大致是从上到下,由整体到局部。

1.1 Node集群角度

我们首先从 Node 集群角度看看,可以认为是从上到下来鸟瞰弹性系统。在这种视角下,每个Node 上运行一个 Agent,Agent之中包含一个 rendezous,负责分布式协商,Agent 同时负责启动workers,监控 workers。

1.2 Agent总体逻辑图

我们然后深入到代理内部,由前文得知,目前总体逻辑如下图。

  • 1)调用 _initialize_workers 来启动 worker 进程,也就是启动了多个进程并行执行用户程序进行训练。
    • 2)调用 _rendezvous,其内部:
      • 调用 next_rendezvous 处理成员关系变化,
      • 调用 _assign_worker_ranks 为 worker 建立 ranks。
    • 3)调用 _start_workers 启动 workers。
  • 4)调用 _monitor_workers 监控这些进程的运行结果。

1.3 监控角度

弹性训练最核心的就是监控/动态处理,所以我们深入到监控模块内部进行分析。从监控的角度看,代理 Agent 主循环 _invoke_run 具体逻辑如下:

  • 调用 _initialize_workers 启动 workers。
    • 调用 _rendezvous,其内部:
      • 调用 next_rendezvous 处理成员关系变化,
      • 调用 _assign_worker_ranks 为 worker 建立 ranks。
    • 调用 _start_workers 启动 workers。
  • 程序进入 while 循环,然后通过 _monitor_workers 定期轮训监控用户程序运行情况,依据情况作出判断。
  • 如果 worker 进程出错或者不健康,进入到 elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}: 这里。
    • 首先调用 _restart_workers 进行重启启动新的rendezvous,并重新启动worker进程。
    • 如果超过最大重启次数,则关闭任务。
  • 如果程序正常运行,进入到 state == WorkerState.HEALTHY 这里。
    • 如果是scale up,则有新的节点在waiting,就重启所有workers。

具体代码如下:

代码语言:javascript复制
    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        self._initialize_workers(self._worker_group) # 启动worker
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            # 定期监控
            time.sleep(monitor_interval)
            # 监控客户程序运行情况
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state # 进程运行情况
            self._worker_group.state = state

            if state == WorkerState.SUCCEEDED:
                # 程序正常结束
                self._exit_barrier() # 有一个成功了就全部结束
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # 程序出错
                if self._remaining_restarts > 0: # 重试
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group) # 进行重启
                else:
                    self._stop_workers(self._worker_group) # 重试次数达到,结束workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
								# 程序正常运行
                # 节点成员关系有变化,比如scale up
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # 如果有新的节点在waiting,就重启所有workers
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

我们再次细化,具体如下草图:

代码语言:javascript复制
  _initialize_workers  <---------------------------------                  Node 1        Node 2                  _initialize_workers
                                                         |                           |                                    
           |                                             |                           |                                   |
           |                                             |   -----------------       |       -----------------           |
           v                                             |  |RendezvousHandler|    sync     |RendezvousHandler|          v
      _rendezvous  ---------------------------------------->                  | <---- ----> |                  <---  _rendezvous
                                      next_rendezvous    |  |                 |      |      |                 |           
           |                                             |  |                 |      |      |                 |          |
    _assign_worker_ranks                                 |  |                 |  heartbeat  |                 |          |
           |                                             |  |                 | <---- ----> |                 |
           v                                             |   -----------------       |       -----------------           v
     _start_workers                                      |                           |                              _start_workers
                                                         |                           |                                    
           |                                             |                           |                                   |
           |                                             |                           |                                   |
           v                                             |                           |                                   v
      ----- -------------------------------------------------------                  |                           -------- --------- 
     |                                                   |         |                 |                          |                  |
     |state = _monitor_workers                           |         |                 |                          |                  |
     |                                                   |         |                 |                          |                  |
     |   |                                               |         |                 |                          |                  |
     |   | UNHEALTHY,FAILED   1. Process fail            |         |                 |                          |                  |
 --> |    -----------------> _restart_workers  --        |          -->              |                          |                  |
|    |   |                                       |                 |  |              |                          |                  |
|    |   |                                        --> _stop_workers|  |              |                          |  LOOP Every 30S  |
|    |   | HEALTHY            2. Node change     |                 |  |              |                          |                  |
|    |    -----------------> _restart_workers  --                  |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    |   | SUCCEEDED                                               |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|    |   | 3. exit                                                 |  |              |                          |                  |
|    |   |                                                         |  |              |                          |                  |
|     -------------------------------------------------------------   |              |                          |                  |
|        |                                                            |              |                          |                  |
<---------------------------------------------------------------------               |                           -------- --------- 
         |        LOOP  Every 30S                                                    |                                   |
         |                                                                           |                                   |
         v                                                                           |                                   v
       _exit_barrier                                                                                               _exit_barrier

手机如图:

或者可以参见下图,图片来自 https://zhuanlan.zhihu.com/p/408382623。

0x02 多进程

监控机制是监控多个正在运行的训练worker,这就涉及到了多进程的启动和监控,我们需要介绍多进程。这就要从启动worker进程这个入口来看。

2.1 启动workers

_start_workers 调用 start_processes 来启动 worker 进程,默认_start_method 是 "spawn"。也就是启动了多个进程,并行执行用户程序。同时这些进程的运行结果会被监控。start_processes 参数之中,entrypointargs 是用户命令和参数,entrypoint可以是函数或者字符串。

然后,_start_workers 把 start_processes 方法启动多线程的结果保存在 _pcontext 之中,后续就用 _pcontext 来继续控制,比如结束 worker 就是直接调用 _pcontext 的 close方法。

代码语言:javascript复制
    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": str(1),
            }
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # scaling events do not count towards restarts (gets same attempt #)
        # remove existing log dir if this restart is due to a scaling event
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        assert spec.entrypoint is not None
        self._pcontext = start_processes( # 把启动多线程的结果保存在 _pcontext 之中。
            name=spec.role,
            entrypoint=spec.entrypoint, # 训练代码入口
            args=args, # 这里重要的是local rank
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()
2.1.1 start_processes

注意,这里 start_processes 的代码在 torch/distributed/elastic/multiprocessing/api.py 之中,和后面用到的 mp的 start_processes 不同。start_processes 会从args之中提取 local rank,然后依据 local_rank 做操作,比如建立每个进程的log文件。其意义是:把每个worker进程同local_rank 联系起来,一个 local_rank 对应一个 worker进程

代码语言:javascript复制
def start_processes(
    name: str,
    entrypoint: Union[Callable, str],
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    log_dir: str,
    start_method: str = "spawn",
    redirects: Union[Std, Dict[int, Std]] = Std.NONE,
    tee: Union[Std, Dict[int, Std]] = Std.NONE,
) -> PContext:
    """
    Starts ``n`` copies of ``entrypoint`` processes with the provided options.
    ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary).
    The number of copies is determined by the number of entries for ``args`` and
    ``envs`` arguments, which need to have the same key set.

    ``args`` and ``env`` parameters are the arguments and environment variables
    to pass down to the entrypoint mapped by the replica index (local rank).
    All local ranks must be accounted for.
    That is, the keyset should be ``{0,1,...,(nprocs-1)}``.

    Args:
        name: a human readable short name that describes what the processes are
              (used as header when tee'ing stdout/stderr outputs)
        entrypoint: either a ``Callable`` (function) or ``cmd`` (binary)
        args: arguments to each replica
        envs: env vars to each replica
        log_dir: directory used to write log files
        nprocs: number of copies to create (one on each process)
        start_method: multiprocessing start method (spawn, fork, forkserver)
                      ignored for binaries
        redirects: which std streams to redirect to a log file
        tees: which std streams to redirect   print to console

    """

    # listdir raises FileNotFound or NotADirectoryError so no need to check manually
    if os.listdir(log_dir):
        raise RuntimeError(
            f"log_dir: {log_dir} is not empty, please provide an empty log_dir"
        )

    nprocs = len(args)
    _validate_full_rank(args, nprocs, "args")
    _validate_full_rank(envs, nprocs, "envs")

    # create subdirs for each local rank in the logs_dir
    redirs = to_map(redirects, nprocs)
    ts = to_map(tee, nprocs)

    # to tee stdout/stderr we first redirect into a file
    # then tail -f stdout.log/stderr.log so add tee settings to redirects
    for local_rank, tee_std in ts.items():
        redirect_std = redirs[local_rank]
        redirs[local_rank] = redirect_std | tee_std

    stdouts = {local_rank: "" for local_rank in range(nprocs)}
    stderrs = {local_rank: "" for local_rank in range(nprocs)}
    tee_stdouts: Dict[int, str] = {}
    tee_stderrs: Dict[int, str] = {}
    error_files = {}

    # 大量使用了local_rank
    for local_rank in range(nprocs):
        clogdir = os.path.join(log_dir, str(local_rank))
        os.mkdir(clogdir)

        rd = redirs[local_rank]
        if (rd & Std.OUT) == Std.OUT:
            stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
        if (rd & Std.ERR) == Std.ERR:
            stderrs[local_rank] = os.path.join(clogdir, "stderr.log")

        t = ts[local_rank]
        if t & Std.OUT == Std.OUT:
            tee_stdouts[local_rank] = stdouts[local_rank]
        if t & Std.ERR == Std.ERR:
            tee_stderrs[local_rank] = stderrs[local_rank]

        error_file = os.path.join(clogdir, "error.json")
        error_files[local_rank] = error_file
        envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file

    context: PContext
    if isinstance(entrypoint, str):
        context = SubprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
        )
    else:
        context = MultiprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
            start_method=start_method,
        )

    try:
        context.start()
        return context
    except Exception:
        context.close()
        raise
2.1.2 RunResult

工作进程的运行结果由RunResult标示。RunResult 是工作线程返回的结果。运行结果遵循"all-or-nothing"策略,其中只有当且仅当此agent管理的所有本地worker成功完成时,运行才会成功。

前面提到,把每个worker进程同local_rank 联系起来了,这想想也对,假如有5个GPU,当然就启动5个工作进程训练,这5个工作进程就对应了local rank 0~4。

但是 RunResult 注释之中注明:如果结果成功(例如is_failed() = False),则return_values字段包含此代理管理的工作进程的输出(返回值),这些工作进程由其GLOBAL ranks映射。即,result.return_values[0]是全局 rank 0的返回值。所以,在 _monitor_workers 之中会有一个从 local rank 到 gloabl rank 的映射,我们后续会讲到。

代码语言:javascript复制
@dataclass
class RunResult:
    """
    Results returned by the worker executions. Run results follow an "all-or-nothing" policy
    where the run is successful if and only if ALL local workers managed by this agent
    complete successfully.

    If the result is successful (e.g. ``is_failed() = False``) then the ``return_values``
    field contains the outputs (return values) of the workers managed by THIS agent mapped
    by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of
    global rank 0.

    .. note:: ``return_values`` are only meaningful for when the worker entrypoint
              is a function. Workers specified as a binary entrypoint do not canonically
              have a return value and the ``return_values`` field is meaningless and
              may be empty.

    If ``is_failed()`` returns ``True`` then the ``failures`` field contains the
    failure information, again, mapped by the GLOBAL rank of the worker that failed.

    The keys in ``return_values`` and ``failures`` are mutually exclusive, that is,
    a worker's final state can only be one of: succeeded, failed. Workers intentionally
    terminated by the agent according to the agent's restart policy, are not represented
    in either ``return_values`` nor ``failures``.
    """

    state: WorkerState
    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return self.state == WorkerState.FAILED

2.1 TE 使用

TE 使用 torch.mp 和 subprocess 包进行多进程处理。在启动多进程时候,把结果保存在 _pcontext 之中,这是一个 PContext 类型的实例。

代码语言:javascript复制
    self._pcontext = start_processes( # 把启动多线程的结果保存在 _pcontext 之中。
        name=spec.role,
        entrypoint=spec.entrypoint,
        args=args,
        envs=envs,
        log_dir=attempt_log_dir,
        start_method=self._start_method,
        redirects=spec.redirects,
        tee=spec.tee,
    )

其中,start_processes, PContext 来自如下:

代码语言:javascript复制
from torch.distributed.elastic.multiprocessing import start_processes, PContext

_monitor_workers 在监控时候,就使用 _pcontext 进行监控。在监控时候会依据线程结果转为WorkerState.FAILED,WorkerState.HEALTHY 或者WorkerState.SUCCEEDED返回给上层。

代码语言:javascript复制
@prof
def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
    role = worker_group.spec.role
    worker_pids = {w.id for w in worker_group.workers}
    assert self._pcontext is not None
    pc_pids = set(self._pcontext.pids().values())
    
    result = self._pcontext.wait(0) # 对运行结果进行监控
    if result:
        if result.is_failed():
            # map local rank failure to global rank
            worker_failures = {}
            for local_rank, failure in result.failures.items():
                worker = worker_group.workers[local_rank]
                worker_failures[worker.global_rank] = failure
            return RunResult(
                state=WorkerState.FAILED, # 进程出错,返回 WorkerState.FAILED
                failures=worker_failures,
            )
        else:
            # copy ret_val_queue into a map with a global ranks
            workers_ret_vals = {}
            for local_rank, ret_val in result.return_values.items():
                worker = worker_group.workers[local_rank]
                workers_ret_vals[worker.global_rank] = ret_val
            return RunResult(
                state=WorkerState.SUCCEEDED,
                return_values=workers_ret_vals,
            )
    else:
        return RunResult(state=WorkerState.HEALTHY)

可见,PContext是关键,所以我们就看看这个类。

2.2 PContext

PContext 就是一个抽象类,实际上就是些基本配置。

代码语言:javascript复制
class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes
    that are launched via different mechanisms. The name ``PContext``
    is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively) this is b/c
                 tee is implemented as a redirect   tail -f <stdout/stderr.log>
    """
    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
    ):
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for
        nprocs = len(args)
        _validate_full_rank(stdouts, nprocs, "stdouts")
        _validate_full_rank(stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = stdouts
        self.stderrs = stderrs
        self.error_files = error_files
        self.nprocs = nprocs

        self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout)
        self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr)    

但是其有两个派生类很关键:MultiprocessContext 和 SubprocessContext。前文提到,start_processes 参数之中,entrypointargs 是用户命令和参数,entrypoint可以是函数或者字符串。如果entrypoint是函数,则使用MultiprocessContext。如果是字符串类型,使用SubprocessContext。

代码语言:javascript复制
def start_processes(
    name: str,
    entrypoint: Union[Callable, str],
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    log_dir: str,
    start_method: str = "spawn",
    redirects: Union[Std, Dict[int, Std]] = Std.NONE,
    tee: Union[Std, Dict[int, Std]] = Std.NONE,
) -> PContext:
  
    context: PContext
    if isinstance(entrypoint, str): # 如果是字符串
        context = SubprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
        )
    else:
        context = MultiprocessContext( # 函数则来到这里
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
            start_method=start_method,
        )

    try:
        context.start() # 调用到这里
        return context
    except Exception:
        context.close()
        raise  

具体来说,两个派生类的基础不同。

  • MultiprocessContext 使用torch.multiprocessing.start_processes来启动进程。
  • SubprocessContext 使用subprocess.Popen来启动进程。

我们接下来仅使用 MultiprocessContext 来分析。

2.3 MultiprocessContext

MultiprocessContext 定义如下,其中最有意义的是 _pc 这个成员变量,其实际是 ProcessContext 这个变量。

代码语言:javascript复制
import torch.multiprocessing as mp

class MultiprocessContext(PContext):
    """
    ``PContext`` holding worker processes invoked as a function.
    """

    def __init__(
        self,
        name: str,
        entrypoint: Callable,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
        start_method: str,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            stdouts,
            stderrs,
            tee_stdouts,
            tee_stderrs,
            error_files,
        )

        self.start_method = start_method
        # each ret_val queue will always contain a single element.
        self._ret_vals = {
            local_rank: mp.get_context(self.start_method).SimpleQueue()
            for local_rank in range(self.nprocs)
        }

        # see comments in ``join()`` for what this is
        self._return_values: Dict[int, Any] = {}
        self._pc: Optional[mp.ProcessContext] = None # 这里是关键
        self._worker_finished_event = mp.get_context(self.start_method).Event()
2.3.1 start

MultiprocessContext start 是调用mp.start_processes,然后保存结果。

代码语言:javascript复制
import torch.multiprocessing as mp

		def _start(self):
        if self._pc:
            raise ValueError(
                "The process context already initialized."
                " Most likely the start method got called twice."
            )
        self._pc = mp.start_processes( # 这里返回了 mp.ProcessContext
            fn=_wrap,
            args=(
                self.entrypoint,
                self.args,
                self.envs,
                self.stdouts,
                self.stderrs,
                self._ret_vals,
                self._worker_finished_event,
            ),
            nprocs=self.nprocs,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )
2.3.2 wait

wait 方法是在其基类 class PContext(abc.ABC): 之中。就是循环调用 _poll 函数来定期检测

代码语言:javascript复制
    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Waits for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).
        """
        if timeout == 0:
            return self._poll()
        if timeout < 0:
            timeout = sys.maxsize

        expiry = time.time()   timeout
        while time.time() < expiry: # 定期操作
            pr = self._poll() # 用poll来检测
            if pr:
                return pr
            time.sleep(period)

        return None
2.3.3 _poll

_poll 函数是具体做检测的,调用了 torch.mp.ProcessContext.join 来做检测。torch.mp.ProcessContext 在部分/所有工作进程失败时引发异常。如果超时,则会检查工作进程状态并立即返回。因为我们使用 synchronize.Event 等待所有进程完成,所以 Join 将永远不会返回成功。

PyTorch 使用 multiprocessing.Queue 将工作进程返回值带回父进程,最后返回的结果内部就包括每个进程的运行结果。

代码语言:javascript复制
def _poll(self) -> Optional[RunProcsResult]:

    try:
        # torch.mp.ProcessContext Throws an Exception if some/all of
        # worker processes failed
        # timeout < 0 checks worker status and return immediately
        # Join will never return success since we use synchronize.Event to wait
        # for all processes to finish.
        self._pc.join(-1)

        # IMPORTANT: we use multiprocessing.Queue to carry worker return values
        # back to the parent, the worker process will wait before terminating
        # until all the buffered items are fed by the feeder thread to the underlying
        # pipe. Hence to prevent deadlocks on large return values,
        # we opportunistically try queue.get on each join call
        # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
        
        for local_rank in range(0, self.nprocs): # 遍历自己下面的进程
            return_queue = self._ret_vals[local_rank]
            if not return_queue.empty():
                # save the return values temporarily into a member var
                self._return_values[local_rank] = return_queue.get() # 得到进程运行结果

        if self._is_done():
            # we should ALWAYS have ALL the return values when all the processes are done
            self._worker_finished_event.set()
            # Wait untill all processes are finished. At this point workers finished executing user function
            self._pc.join()
            self.close()
            return RunProcsResult(
                return_values=self._return_values, # 返回进程结果
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )
        else:
            return None
          
    except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
        failed_local_rank = e.error_index

        # entrypoint for MultiprocessContext will always be a Callable
        fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
        failed_proc = self._pc.processes[failed_local_rank]
        error_filepath = self.error_files[failed_local_rank]

        self.close()
        return RunProcsResult( # 返回进程结果
            failures={
                failed_local_rank: ProcessFailure(
                    local_rank=failed_local_rank,
                    pid=e.pid,
                    exitcode=failed_proc.exitcode,
                    error_file=error_filepath,
                )
            },
            stdouts=self.stdouts,
            stderrs=self.stderrs,
        )

2.4 ProcessContext

由前面可知,MultiprocessContext 的关键变量是:_pc: Optionalmp.ProcessContext,这个成员变量是通过 start_processes 来构建,所以我们需要看看torch.mp.ProcessContext。

2.4.1 start_processes

start_processes 在 torch/multiprocessing/spawn.py 之中,返回 ProcessContext。注意,从此之后,训练进程就会跑自己的训练代码,仿佛没有agent一样,因为agent已经把torch.distributed.launch 的工作做完了。

代码语言:javascript复制
def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
    mp = multiprocessing.get_context(start_method)
    error_queues = []
    processes = []
    for i in range(nprocs):
        error_queue = mp.SimpleQueue()
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, error_queue), # 训练进程开始跑训练代码
            daemon=daemon,
        )
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    context = ProcessContext(processes, error_queues)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass
2.4.2 ProcessContext

torch.mp.ProcessContext 才是最终发挥作用的类。其实,torch.mp.ProcessContext 的内部实现和如何启动我们并不在意,因为通过 start_processes 方法,torch.mp.ProcessContext 事实上已经启动了,我们把它当作一个功能性黑盒子即可,我们真正关心的是如何使用 torch.mp.ProcessContext 来进行监控。

从其注释中我们可以知道,torch.mp.ProcessContext在部分/所有工作进程失败时引发异常。如果超时,则会检查工作进程状态并立即返回。因为我们使用synchronize.Event等待所有进程完成,所以Join将永远不会返回成功。

代码语言:javascript复制
# torch.mp.ProcessContext Throws an Exception if some/all of
# worker processes failed
# timeout < 0 checks worker status and return immediately
# Join will never return success since we use synchronize.Event to wait
# for all processes to finish.

2.5 总结

目前关系如下:

  • 在生成时候,LocalElasticAgent 生成了 MultiprocessContext,MultiprocessContext 又生成了 ProcessContext。
  • LocalElasticAgent._pcontext 保存了 MultiprocessContextMultiprocessContext._pc 保存了 ProcessContext
  • 监控时候,LocalElasticAgent._monitor_workers 调用了 MultiprocessContext.wait,MultiprocessContext 又调用了 ProcessContext.join,ProcessContext.join 具体监控进程的运行状态,这样完成了监控的整体逻辑。
  • 子进程有变化或者超时之后,ProcessContext.join 返回了进程结果,MultiprocessContext.wait 把进程结果转发回去,_monitor_workers 把进程结果转换为 WorkerState.SUCCEEDED 或者 WorkerState.FAILED。

具体如图:

代码语言:javascript复制
 --------------------------------------------------------------------------------------     ------------------------------------     ---------------- 
| LocalElasticAgent                                                                    |   | MultiprocessContext                |   | ProcessContext |
|                                                                                      |   |                                    |   |                |
|                                                                                      |   |                                    |   |                |
|   ----------------------------------------        MultiprocessContext _pcontext      |   |       ProcessContext _pc           |   |                |
|  | _invoke_run                            |                                          |   |                                    |   |                |
|  |                                        |                                          |   |                                    |   |                |
|  |   _initialize_workers   -------------------->  _pcontext = start_processes   -------------->  start():                     |   |                |
|  |                                        |                                          |   |         _pc = mp.start_processes  ----------->          |
|  |                                        |                                          |   |                                    |   |                |
|  |   while True:                          |       --------------------------------   |   |                                    |   |                |
|  |       _monitor_workers(_worker_group) ------> | _monitor_workers               |  |   |                                    |   |                |
|  |                                        |      |                                |  |   |                                    |   |                |
|  |                                        |      |             _pcontext.wait  --------------->  wait  ---> poll:             |   |                |
|  |                                        |      |                                |  |   |                    _pc.join   --------------->          |
|   ----------------------------------------        --------------------------------   |   |                                    |   |                |
|                                                                                      |   |                                    |   |                |
 --------------------------------------------------------------------------------------     ------------------------------------     ---------------- 

手机如下:

0x03 监控机制

从前面 _monitor_workers 代码中可以看到, _monitor_workers 会把子进程运行结果转换为 WorkerState 的具体状态。当代理拿到 _monitor_workers 的监控结果之后,会根据情况进行处理。

代码语言:javascript复制
            # 监控客户程序运行情况
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state # 进程运行情况
            self._worker_group.state = state

            if state == WorkerState.SUCCEEDED:
                # 程序正常结束
                self._exit_barrier() # 有一个成功了就全部结束
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # 程序出错
                if self._remaining_restarts > 0: # 重试
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group) # 进行重启
                else:
                    self._stop_workers(self._worker_group) # 重试次数达到,结束workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
								# 程序正常运行
                # 节点成员关系有变化,比如scale up
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                # 如果有新的节点在waiting,就重启所有workers
                if num_nodes_waiting > 0:
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

3.1 监控

这里会调用 _pcontext.wait(0) 来获取目前 worker 子进程们的状态,然后依据返回结果,转换不同的 WorkerState 返回给调用者。这里就提到了前面讲的,RunResult 应该和 global rank 映射,所以_monitor_workers就有一个从 local rank 到 gloabl rank 的映射。

为何要使用 Global rank 作为进程状态的标示?因为在Node之间需要沟通,这时候需要用Global rank。

代码语言:javascript复制
    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers} # 拿到本agent所有worker的pid
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            return RunResult(state=WorkerState.UNKNOWN)

        result = self._pcontext.wait(0) # 对运行结构进行监控
        if result:
            if result.is_failed(): # 如果进程失败
                # map local rank failure to global rank
                worker_failures = {}
                #  返回的结果内部就包括每个进程的运行结果
                for local_rank, failure in result.failures.items(): # local_rank是进程index
                    worker = worker_group.workers[local_rank] # 拿到对应的worker
                    worker_failures[worker.global_rank] = failure # 拿到其 global_rank,进而设置worker状态
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures, # 返回运行结果
                )
            else:
                # copy ret_val_queue into a map with a global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank] # 
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals, # 返回运行结果
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)

3.2 处理

根据返回状态不同,会有不同处理:

  • 如果 WorkerState.SUCCEEDED,则说明训练结束,正常返回。
  • 如果 WorkerState.HEALTHY,则说明训练正常运行,这时候会检查是否有新节点加入,我们后文会详解。
  • 如果 WorkerState.UNHEALTHY, WorkerState.FAILED,说明训练出现问题,这里有两种情况。
    • 一种是程序出错,TE 会进行重试。
    • 一种是节点退出,我们在下文分析,但是其处理流程与程序出错一致。

接下来我们就分析一下如何处理训练结束 和 程序出错。

0x04 训练结束

代码语言:javascript复制
        if state == WorkerState.SUCCEEDED:
            # 程序正常结束
            self._exit_barrier() # 有一个成功了就全部结束
            return run_result

以上是训练正常结束时候的处理,特殊就在于_exit_barrier的使用。

4.1 统一完成

Torchelastic目前支持DDP风格的应用程序。也就是说TE希望所有 workers 大约同时完成。实际上,几乎不可能保证DDP的所有工人都能保证同时结束,所以因此TE提供了一个finalization barrier,这个barrier的作用是对worker finalization 实施等待超时(5分钟)。也就是说,如果有一个worker 训练完成,TE(torchelastic)希望用户所有worker以5分钟的误差完成。

代码语言:javascript复制
def _exit_barrier(self):
    """
    Wait for ``exit_barrier_timeout`` seconds for all agents to finish
    executing their local workers (either successfully or not). This
    acts as a safety guard against user scripts that terminate at different
    times. This barrier keeps the agent process alive until all workers finish.
    """

    start = time.time()
    try:
        store_util.barrier(
            self._store,
            self._worker_group.group_rank,
            self._worker_group.group_world_size,
            key_prefix=_TERMINAL_STATE_SYNC_ID,
            barrier_timeout=self._exit_barrier_timeout,
        )
    except Exception:
        log.exception(
            f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
        )

exit_barrier_timeout 的默认值就是300秒,即5分钟。

代码语言:javascript复制
exit_barrier_timeout: float = 300,

4.2 同步

在 torch/distributed/elastic/utils/store.py 之中,barrier 会调用 synchronize 进行同步。

代码语言:javascript复制
def barrier(
    store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
) -> None:
    """
    A global lock between agents.

    Note: Since the data is not removed from the store, the barrier can be used
        once per unique ``key_prefix``.
    """
    data = f"{rank}".encode(encoding="UTF-8")
    synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)

synchronize 则是通过store进行同步。

代码语言:javascript复制
def get_all(store, prefix: str, size: int):
    r"""
    Given a store and a prefix, the method goes through the array of keys
    of the following format: ``{prefix}{idx}``, where idx is in a range
    from 0 to size, and tries to retrieve the data.

    Usage

    ::

     values = get_all(store, 'torchelastic/data', 3)
     value1 = values[0] # retrieves the data for key torchelastic/data0
     value2 = values[1] # retrieves the data for key torchelastic/data1
     value3 = values[2] # retrieves the data for key torchelastic/data2

    """
    data_arr = []
    for idx in range(size):
        data = store.get(f"{prefix}{idx}")
        data_arr.append(data)
    return data_arr

def synchronize(
    store,
    data: bytes,
    rank: int,
    world_size: int,
    key_prefix: str,
    barrier_timeout: float = 300,
) -> List[bytes]:
    """
    Synchronizes ``world_size`` agents between each other using the underlying c10d store.
    The ``data`` will be available on each of the agents.

    Note: The data on the path is not deleted, as a result there can be stale data if
        you use the same key_prefix twice.
    """
    store.set_timeout(timedelta(seconds=barrier_timeout))
    store.set(f"{key_prefix}{rank}", data)
    agent_data = get_all(store, key_prefix, world_size)
    return agent_data

0x05 错误处理

5.1 错误类型

分布式PyTorch作业中的每个主机都运行一个TorchElastic 代理和多个worker(作为TorchElastic代理的子进程)。由于worker是用户提供的(PyTorch script/job),TorchElastic可以通过代理将错误传播到trainer之上,直至调度程序(scheduler),最终把这些作业的状态通知最终用户并应用一些重试策略。

TE 把错误归为如下几类。

代码语言:javascript复制
 ---------------- ---------------- -------------------------------------------------------------- 
| Category       | Sub-Category   |  Description                                                 |
 ================ ================ ============================================================== 
| User Error     | Input Error    | invalid inputs to TorchElastic APIs (e.g. min > max nodes)   |
|                 ---------------- -------------------------------------------------------------- 
|                | Worker Failure | any failures on the worker child process                     |
 ---------------- ---------------- -------------------------------------------------------------- 
| Platform Error |      n/a       | failures caused by the agent                                 |
 ---------------- ---------------- -------------------------------------------------------------- 
| Infra Error    |      n/a       | failures outside the domain of the agent and workers         |
|                |                | (e.g. host failures)                                         |
 ---------------- ---------------- -------------------------------------------------------------- 

5.1 错误处理模式

对应的错误处理模式如下,我们按照从小到大的故障级别来看:

  • User Error:具体又分为如下处理方式:
    • User Error :比如错误输入,这样直接程序捕获即可。
    • Worker Failure:
      • Worker Failures是特殊的,因为异常/失败源于与代理不同的进程,因此错误需要在进程间传播(例如,代理不能简单地 try-catch 一个工作进程上引发的异常)。
        • TorchElastic代理使用 torch.distributed.elastic.multiprocessing.start_processes启动worker,它内置了一个简单的基于文件的进程间错误传播。
        • 任何用record修饰的函数或二进制入口点都会将未捕获的异常(带有跟踪信息)写入环境变量 TORCHELASTIC_ERROR_FILE指定的文件。父进程(例如代理)在其启动的每个子进程之上设置此环境变量,然后聚合所有子进程的错误文件,并传播具有最小时间戳的错误文件(例如第一个错误)。
      • 文档中有如下论述:对于有“n”个 workers 的训练job,如果“k<=n”名 worker 失败,那么所有 worker 都会停止并重新启动,直到达到 “max_restarts” 次数。上面这句话的意思其实就是:如果有一个worker失败了,而且还没有达到了最大重启次,TE 将启动新的rendezvous,并且重启所有workers,因为是新的 rendezvous,所以其他 TE 代理也会重启其 workers。
      • 一个worker的失败将导致整个集群失败:如果单个worker不断失败,则会导致TE agent 的 max_restarts 变量变为零。这将导致agent完成其工作并关闭rendezvous。如果在不同的代理上有任何其他worker,它们也将被终止。
  • Platform Error(就是代理故障)
    • 非Worker故障(Worker Failure)之外的所有错误都会从代理进程中正常引发,隐式或显式地使代理进程崩溃。因此可以应用标准语言(python)提供的异常处理策略。
    • 代理失败也可以导致本地工作组失败。如何处理取决于job manager,比如使整个作业(gang语义)失败或尝试替换节点。两种行为均由代理支持。
  • Infra Error(就是节点故障 ):与代理故障同样方式来处理。

我们接下来就具体看看如何处理"Worker Failure"。

5.2 处理机制

错误处理具体机制如下,如果重试尚未达到最大次数,则试图重启workers。如果已经达到了最大次数,则停止 workers。

代码语言:javascript复制
        elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
            # 程序出错
            if self._remaining_restarts > 0: # 重试
                self._remaining_restarts -= 1
                self._restart_workers(self._worker_group) # 进行重启
            else:
                self._stop_workers(self._worker_group) # 重试次数达到,结束workers
                self._worker_group.state = WorkerState.FAILED
                self._exit_barrier()
                return run_result
5.2.1 重启

_restart_workers 会停掉所有 workers,然后重新一轮 rendezvous 。

代码语言:javascript复制
@prof
def _restart_workers(self, worker_group: WorkerGroup) -> None:
    """
    Restarts (stops, rendezvous, starts) all local workers in the group.
    """

    role = worker_group.spec.role
    self._stop_workers(worker_group)
    worker_group.state = WorkerState.STOPPED
    self._initialize_workers(worker_group)
5.2.2 停止

停止 workers 就是关闭上下文。

代码语言:javascript复制
def _shutdown(self) -> None:
    if self._pcontext:
        self._pcontext.close()
        
@prof
def _stop_workers(self, worker_group: WorkerGroup) -> None:
    self._shutdown()

在 MultiprocessContext 之中,close 方法是关闭所有子进程,然后等待其全部停止。

代码语言:javascript复制
    def _close(self) -> None:
        if self._pc:
            for proc in self._pc.processes:
                proc.terminate()
                proc.join()

5.4 其他代理重启

从源码注释可知:新一轮 rendezvous 会让其他 agent 也重启它们的worker。

代码语言:javascript复制
When worker fails, TE will check the number of restarts available, if there is more than 0 restarts, TE will start a new rendezvous round and restart the worker process. New rendezvous round will other TE agents to terminate their workers.

这是如何做到的?具体如下:

  1. Agent 0(故障Agent)通过 monitoring 发现了故障。
  2. Agent 0 调用 _restart_workers 重启worker。
  3. Agent 0 会调用 next_rendezvous 发起新一轮 rendezvous。
  4. Agent 0 在做任何操作之前,比如 keep alive 操作之前,会调用 sync 来从kvstore获取集群信息,这样可以保证 Agent拿到的是集群最新状态。
  5. Agent 0 会把自己加入到本地的 waiting_list 之中。
  6. Agent 0 同时会调用 mark_dirty,意思是我状态更新了,需要写入KVStore。
  7. Agent 0 会调用sync把自己的waiting_list 被写入 KVStore。
  8. Agent 1(其他正常工作的 agent)会在做任何操作之前,比如 keep alive 操作之前,会调用 sync 操作从KVStore 获取最新信息。
  9. Agent 1 利用这些信息来更新自己的状态,这样本地 waiting_list 就会更新。
  10. Agent 1 的 train loop 在每 30 秒监控之后,因为系统正常,是 Healthy 状态。
  11. Agent 1 所以调用 num_nodes_waiting() 看看 waiting_list 数目。
  12. Agent 1 会获取本地 waiting list 的数目。
  13. 如果 waiting list 不为空,也调用_restart_workers。
  14. 其最终会调用next_rendezvous。

具体如下:

代码语言:javascript复制
 Agent 0                                      Agent 1
 ---------------------------                   -------------------------------------------- 
|    _invoke_run            |                 |                       _invoke_run          |
|                           |                 |                                            |
|          |                |                 |                           |                |
|          | 1              |                 |                           |                |
|          v                |                 |                           |                |
| Worker Process Error      |                 |                           |                |
|                           |                 |                           |                |
|          |                |                 |                           | 10             |
|          | 2              |                 |                           v                |
|          v                |                 |                        HEALTHY             |
|  _restart_workers         |                 |                                            |
|                           |                 |                           | 11             |
|          |                |                 |                           |                |
|          | 3              |                 |                           v                |
|          v                |                 |               -->  num_nodes_waiting() > 0 |
|   next_rendezvous         |                 |              |                             |
|                           |                 |              |            |                |
|          | 4              |                 |              | 12         | 13             |
|          |                     ----------   |              |            v                |
|          v      cluster info  |          |  |              |       _restart_workers      |
|        sync  <------------ -> | KV Store |  |              |                             |
|                           |   |          |  |              |            |                |
|          | 5              |   |          |  |              |            | 14             |
|          v                |   |          |  |              |            v                |
|  Add to local waiting_list|   |          |  |              |        next_rendezvous      |
|                           |   |          |  |              |                             |
|          |                |   |          |  |              |                             |
|          | 6              |   |          |  |              v                             |
|          v                |   |          |  |                                            |
|     mark_dirty            |   |          |  |  Add to local waiting_list                 |
|                           |   |          |  |              ^                             |
|          |                |   |          |  |              |                             |
|          | 7              |   |          |  |            9 | waiting_list                |
|          v         7      |   |          |  |    8                                       |
|        sync  ---------------> |           --------------> sync                           |
|              waiting_list |   |          |  |waiting_list                                |
|                           |    ----------   |                                            |
 ---------------------------                   -------------------------------------------- 

至此,我们监控机制初步介绍完成,因为篇幅所限,我们下一篇继续介绍Scale up/down如何处理。

0xFF 参考

云原生的弹性 AI 训练系列之二:PyTorch 1.9.0 弹性分布式训练的设计与实现

PyTorch Elastic源码阅读

0 人点赞