[源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

2021-12-28 13:40:13 浏览数 (1)

源码解析 PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

目录

  • [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑
    • 0x00 摘要
    • 0x01 总体背景
    • 0x02 基本概念
      • 2.1 Barrier
      • 2.2 排他性(Exclusivity)
      • 2.3 一致性(Consistency)
      • 2.4 容错(Fault-tolerance)
      • 2.5 共享键值存储
      • 2.6 等待worker和rendezvous关闭
      • 2.7 DynamicRendzvousHandler
      • 2.8 问题&设计
    • 0x03 静态结构
      • 3.1 启动参数
      • 3.2 配置
      • 3.3 状态
      • 3.4 节点
      • 3.5 后端
        • 3.5.1 使用
        • 3.5.2 基类
        • 3.5.3 创建
          • 3.5.3.1 TCPStore
          • 3.5.3.2 C10dRendezvousBackend
      • 3.6 StateHolder
        • 3.6.1 _RendezvousStateHolder
        • 3.6.2 _BackendRendezvousStateHolder
        • 3.6.3 如何使用
      • 3.7 总结
    • 0x04 动态逻辑
      • 4.1 入口
      • 4.2 基类 RendezvousHandler
      • 4.3 注册
        • 4.3.1 RendezvousHandlerRegistry
        • 4.3.2 全局registry
      • 4.4 创建
        • 4.4.1 静态 RendezvousHandler
          • 4.4.1.1 _create_static_handler
          • 4.4.1.2 StaticTCPRendezvous 子类
        • 4.4.2 动态 RendezvousHandler
          • 4.4.2.1 _create_c10d_handler
          • 4.4.2.2 from_backend
          • 4.4.2.3 DynamicRendezvousHandler
          • 4.4.2.4 next_rendezvous
          • 4.4.2.5 _get_world
      • 4.5 容错
        • 4.5.1 ETCD
        • 4.5.2 DynamicRendezvousHandler
      • 4.6 小结
    • 0x05 总结
    • 0xFF 参考

0x00 摘要

在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,介绍了官方的几个例子,我们接下来会介绍PyTorch的弹性训练,本文是第四篇,看看Rendezvous 的结构和总体逻辑。

弹性训练系列文章如下:

[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

[源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

[源码解析] PyTorch 分布式之弹性训练(3)---代理

0x01 总体背景

TE 是围绕在 Rendezvous 基础之上的多个elastic agent构成,这是一种功能分离,让我们对比一下看看。

  • Agent 偏重具体节点上的逻辑。
    • Agent 负责具体业务逻辑相关操作,比如启动进程执行用户程序,监控用户程序运行情况,如果有异常就通知 Rendezvous。
    • Agent 是一个 worker manager,负责启动/管理 workers 进程,组成一个 worker group,监控 workers 运行状态,捕获失效 workers,如果有故障/新加入worker,则重启 worker group。
    • Agent负责维护 WORLD_SIZE 以及 RANK 信息。用户不需要再手动提供,Agent会自动处理这些。
    • Agent 是具体节点上的后台进程,是独立个体。Agent自己无法实现整体上的弹性训练,所以需要一个机制来完成 worker 之间的相互发现,变更同步等等(WORLD_SIZE 和 RANK 这些信息其实也需要多个节点同步才能确定),这就是下面的 Rendezvous 概念。
  • Rendezvous 负责集群逻辑,保证节点之间对于""有哪些节点参与训练"达成强一致共识。
    • 每一个 Agent 内部包括一个 Rendezvous handler,这些 handler 总体上构成了一个 Rendezvous 集群,从而构成了一个 Agent 集群。
    • Rendezvous 完成之后,会创建一个共享键值存储(shared key-value store),这个store实现了一个torch.distributed.Store API。此存储仅由已完成Rendezvous的成员共享,它旨在让Torch Distributed Elastic在初始化作业过程之中交换控制和数据信息。
    • Rendezvous 负责在每个agent之上维护当前 group 所有相关信息。每个 agent 之上有一个 rendezvous,它们会互相通信,总体维护一套信息,这些信息存储在上面提到的Store 之中。
    • Rendezvous 负责集群逻辑相关,比如新加入节点,移除节点,分配rank等等。

0x02 基本概念

在 Torch Distributed Elastic 上下文之中,人们使用 rendezvous 这个术语来特指一个特定功能:一个结合了对等发现(peer discovery)的分布式同步(distributed synchronization)原语。

其可以理解为一个分布式治理过程:Rendezvous 被Torch Distributed Elastic用来收集一个训练job的参与者(节点),这样,参与者们可以商议得到参与者列表和每个参与者的角色,也可以对训练何时开始/恢复做出一致的集体决定。即,通过 rendezvous,系统对参与者达成共识,给每一个参与者分配 rank,local rank,通知 world size等等,当需要弹性伸缩或者出现故障时候,就会重新进行 rendezvous 操作。

为了实现弹性训练,需要有一个节点/进程之间彼此发现的机制。在TorchElastic中,rendezvous就是这个发现机制或者说同步组件,其被用来作为对等发现的分布式同步(治理)机制,用于同步、收集各个worker的信息,包括节点列表、各节点worker角色等,然后各个Agent才能共同决定训练的开始、结束、恢复等。

图片来自 PyTorch 源码。

或者使用 TE 源码之中的图片,能更清楚的看出来这是三个Node。

Rendezvous会提供以下细分功能。

2.1 Barrier

执行会合的节点将全部阻塞到 rendezvous 完成,即至少有min个节点(针对同一作业)已加入到Barrier,这也意味着对于固定大小的节点数目,barrier是不必要的。

在达到"min"数量后,rendezvous 不会立刻宣布完成,而是会等待额外的一小段时间,这用来保证rendezvous不会"过快"完成,因为如果立刻完成,就会错过那些加入时只慢了一点点的节点。当然如果在Barrier处聚集了max个节点,则rendezvous立即完成。

另外,还有一个总超时时间配置 :如果在超时时间之内 min个节点一直没有达到,则会导致 rendezvous 失败,这是一个简单的故障安全(fail-safe)解决方案,用来帮助释放部分分配的作业资源,防止资源浪费。

2.2 排他性(Exclusivity)

一个简单的分布式屏障是不够的,因为我们还需要确保在任何给定的时间(对于给定的作业)只存在一组节点。换言之,对于同一个job,新节点(即后来加入的节点)不能组成一个新的并行的独立worker group。

Torch Distributed Elastic 会确保如果一组节点已经完成rendezvous(可能已经在训练),那么其他试图加入的"迟到"节点只会被认为是等待状态,且必须等到现有rendezvous被结束。

2.3 一致性(Consistency)

rendezvous完成后,其所有成员将对工作成员资格以及每个人在其中的角色(role)达成共识。此角色(role)使用一个介于 0 ~ world size 之间的整型来表示,被称之为rank。

请注意,rank是不稳定的,比如,同一个的节点在下一次(重新)rendezvous中可能被分配了不同的rank。

2.4 容错(Fault-tolerance)

Torch Distributed Elastic rendezvous 在 rendezvous 过程中有容错机制:

  • 在开始join rendezvous 和 rendezvous 完成之间,如果有进程崩溃(或网络故障等),就会自动引发一个re-rendezvous,剩余健康节点会自动重组。
  • 节点也可能在rendezvous 完成后失败(或被其他节点观察到失败),这个场景由Torch Distributed Elastic train_loop 负责,也会触发一个re-rendezvous,训练过程不会中断。

2.5 共享键值存储

Rendezvous 完成后,将创建一个共享键值存储(key-value store)并返回给node。此存储实现了一个torch.distributed.store API(请参见https://pytorch.org/docs/stable/distributed.html)。

此存储仅由已完成rendezvous的成员共享,被Torch Distributed Elastic用作交换初始化作业控制和数据平面所必需的信息。

2.6 等待worker和rendezvous关闭

Torch Distributed Elastic rendezvous handler提供了额外功能:

  • 查询在barrier之后有多少worker加入(迟到了),他们将在下一次rendezvous 中参与进来。
  • 设定 rendezvous 为关闭状态,以通知所有节点不参与下一次rendezvous 。

2.7 DynamicRendzvousHandler

Torch Distributed Elastic 提供了DynamicRendzvousHandler类,该类实现了上述的 rendezvous mechanism。

这个类需要我们在构建时候指定后端(RendezvousBackend)。用户可以自己实现后端,或者使用如下PyTorch附带实现之一:

  • C10dRendezvousBackend,其使用 C10d 存储(默认是 TCPStore) 作为 rendezvous backend,其优势是不需要依赖第三方,比如etcd,来构建一个rendezvous 。
  • EtcdRendezvousBackend,其使用EtcdRendezvousHandler,EtcdRendezvousBackend 等类来基于 etcd 完成,缺点是需要搭建 Etcd。

比如:

代码语言:javascript复制
     store = TCPStore("localhost")
     backend = C10dRendezvousBackend(store, "my_run_id")
     rdzv_handler = DynamicRendezvousHandler.from_backend(
         run_id="my_run_id",
         store=store,
         backend=backend,
         min_nodes=2,
         max_nodes=4
     )

2.8 问题&设计

知道了所需要实现的功能,我们就可以思考 Rendezvous 应该具备哪些内部模块,才能满足这些需求。

  • 需要有一个节点概念,这样才能把系统表达出来。
  • 需要有一个状态概念,就是节点的状态。
  • 需要有一个总体静态类,用来把节点,状态以及其他信息统一维护起来。
  • 需要有一个共享共享键值存储,可以集中保存上述信息,也可以用来彼此交换信息,达成共识。
  • 需要有一个动态server,或者handler,其提供一套API以供外界访问。

我们就按照这个思路分析,首先看看静态结构,然后看看动态逻辑。

0x03 静态结构

我们接下来看看相关支撑系统。这里要注意的是,elastic 内部有一套 Rendezvous,和 distributed 原有的 Rendezvous 那套不一样,别搞混了。distributed 原有的 Rendezvous 就是一套简单的 KV 存储。elastic Rendezvous 则要复杂得多。

我们仔细分析一下 Rendezvous 的支撑系统。

3.1 启动参数

RendezvousParameters 是构建RendezvousHandler所需参数。

  • backend :后端名称。
  • endpoint :端点,格式是 :。
  • run_id : rendezvous 的 id。
  • min_nodes :rendezvous 的最小节点数目。
  • max_nodes :rendezvous 的最大节点数目。
  • kwargs :后端的附加参数。
代码语言:javascript复制
class RendezvousParameters:
    """Holds the parameters to construct a :py:class:`RendezvousHandler`.

    Args:
        backend:
            The name of the backend to use to handle the rendezvous.
        endpoint:
            The endpoint of the rendezvous, usually in form <hostname>[:<port>].
        run_id:
            The id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        **kwargs:
            Additional parameters for the specified backend.
    """

    def __init__(
        self,
        backend: str,
        endpoint: str,
        run_id: str,
        min_nodes: int,
        max_nodes: int,
        **kwargs,
    ):
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        if min_nodes < 1:
            raise ValueError(
                f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
            )
        if max_nodes < min_nodes:
            raise ValueError(
                f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
                f"equal to the minimum number of rendezvous nodes ({min_nodes})."
            )

        self.backend = backend
        self.endpoint = endpoint
        self.run_id = run_id
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        self.config = kwargs

3.2 配置

RendezvousSettings 类用来存储rendezvous的配置。可以理解为静态元信息。

  • run_id : rendezvous 的 id。
  • min_nodes :rendezvous 的最小节点数目。
  • max_nodes :rendezvous 的最大节点数目。
  • timeout :超时时间。
  • keep_alive_interval :节点在发送心跳之间等待的时间量。
  • keep_alive_max_attempt : 心跳的最大重试次数。
代码语言:javascript复制
@dataclass(repr=False, eq=False, frozen=True)
class RendezvousSettings:
    """Holds the settings of the rendezvous.

    Attributes:
        run_id:
            The run id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        timeout:
            The timeout configuration of the rendezvous.
        keep_alive_interval:
            The amount of time a node waits before sending a heartbeat to keep
            it alive in the rendezvous.
        keep_alive_max_attempt:
            The maximum number of failed heartbeat attempts after which a node
            is considered dead.
    """

    run_id: str
    min_nodes: int
    max_nodes: int
    timeout: RendezvousTimeout
    keep_alive_interval: timedelta
    keep_alive_max_attempt: int

3.3 状态

_RendezvousState 是rendezvous的状态。是动态信息,每一个 node 都会维护一个本地 state。

  • round:Rendezvous的当前轮次
  • complete:一个布尔值,指示rendezvous当前一轮是否完成了。
  • deadline:截止时间,如果如果当前轮次一直在等待节点加入,如果这个参数设置了,就是等待的截至时间。
  • closed:一个布尔值,指示rendezvous是否结束了。
  • participants:字典结构,存放参与者和它们对应ranks。
  • wait_list:set结构,存放等待参与下一轮rendezvous操作的一组节点
  • last_heartbeats:字典,包含每个节点上次心跳时间
代码语言:javascript复制
class _RendezvousState:
    """Holds the state of a rendezvous.

    Attributes:
        round:
            The current round of the rendezvous.
        complete:
            A boolean value indicating whether the current round of the
            rendezvous is complete.
        deadline:
            The time at which the current round of the rendezvous will be
            considered complete if it is still waiting for nodes to join.
        closed:
            A boolean value indicating whether the rendezvous is closed.
        participants:
            A dictionary of the participants and their corresponding ranks.
        wait_list:
            A set of nodes that are waiting to participate in the next round of
            the rendezvous.
        last_heartbeats:
            A dictionary containing each node's last heartbeat time.
    """

    round: int
    complete: bool
    deadline: Optional[datetime]
    closed: bool
    participants: Dict[_NodeDesc, int]
    wait_list: Set[_NodeDesc]
    last_heartbeats: Dict[_NodeDesc, datetime]

    def __init__(self) -> None:
        self.round = 0
        self.complete = False
        self.deadline = None
        self.closed = False
        self.participants = {}
        self.wait_list = set()
        self.last_heartbeats = {}

3.4 节点

_NodeDesc 是rendezvous的一个节点。

代码语言:javascript复制
@dataclass(eq=True, order=True, frozen=True)
class _NodeDesc:
    """Describes a node in the rendezvous.

    Attributes:
        fqdn:
            The FQDN of the node.
        pid:
            The id of the process in which the rendezvous handler runs.
        local_id:
            A process-wide unique id.
    """

    fqdn: str
    pid: int
    local_id: int

    def __repr__(self) -> str:
        return f"{self.fqdn}_{self.pid}_{self.local_id}"

3.5 后端

在 PyTorch 之中,backend 概念指的是当前进程要使用的通信后端,一般来说,支持的通信后端有 gloompinccl 。建议用 nccl

在弹性训练这里,DynamicRendezvousHandler 需要我们在构建时候指定后端(RendezvousBackend)。用户可以自己实现后端,或者使用如下PyTorch附带实现之一:

  • C10dRendezvousBackend,其使用 C10d 存储(默认是 TCPStore) 作为 rendezvous backend,其优势是不需要依赖第三方,比如etcd,来构建一个rendezvous 。
  • EtcdRendezvousBackend,其使用EtcdRendezvousHandler,EtcdRendezvousBackend 等类来基于 etcd 完成。

因为 EtcdRendezvousBackend 必须依赖 ETCD,需要安装一个 ETCD集群,所以推荐使用 c10d 后端,其易用性更好。我们接下来就主要介绍 c10d 后端。

C10d 后端主要基于一个 TCPStore,通过 TCP 进行同步。我们在之前文章中介绍过 TCPStore,TCPStore 是基于 TCP 的分布式键值存储实现(类似于 Redis)。是一个典型的 client-server 架构,服务器存储/保存数据,而存储客户端可以通过 TCP 连接到服务器存储并执行诸如set()插入键值对、get()检索键值对等操作。

所以,对于 c10d 后端来说,在其中一个代理之上会运行 TCPStore Master,其负责监听端口,提供API,Rendezvous 的各种同步操作,都是由各个代理连接到这个中心化的 TCPStore Master,在其上完成。

具体可以如下图所示,来源于知乎 BobLiu。

3.5.1 使用

下图展示了如何配置后端

代码语言:javascript复制
     store = TCPStore("localhost")
     backend = C10dRendezvousBackend(store, "my_run_id") # 配置了后端

     rdzv_handler = DynamicRendezvousHandler.from_backend(
         run_id="my_run_id",
         store=store,
         backend=backend,
         min_nodes=2,
         max_nodes=4
     )
3.5.2 基类

我们首先看看后端的基类 RendezvousBackend。这是一个虚类,主要功能就是设置和获取State。

代码语言:javascript复制
class RendezvousBackend(ABC):
    """Represents a backend that holds the rendezvous state."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Gets the name of the backend."""

    @abstractmethod
    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """Gets the rendezvous state.

        Returns:
            A tuple of the encoded rendezvous state and its fencing token or
            ``None`` if no state is found in the backend.
        """

    @abstractmethod
    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """Sets the rendezvous state.

        The new rendezvous state is set conditionally:
          - If the specified ``token`` matches the fencing token stored in the
            backend, the state will be updated. The new state will be returned
            to the caller along with its fencing token.
          - If the specified ``token`` does not match the fencing token stored
            in the backend, the state won't be updated; instead the existing
            state along with its fencing token will be returned to the caller.
          - If the specified ``token`` is ``None``, the new state will be set
            only if there is no existing state in the backend. Either the new
            state or the existing state along with its fencing token will be
            returned to the caller.

        Args:
            state:
                The encoded rendezvous state.
            token:
                An optional fencing token that was retrieved by a previous call
                to :py:meth:`get_state` or ``set_state()``.

        Returns:
            A tuple of the serialized rendezvous state, its fencing token, and
            a boolean value indicating whether our set attempt succeeded.

        Raises:
            RendezvousConnectionError:
                The connection to the backend has failed.
            RendezvousStateError:
                The rendezvous state is corrupt.
        """
3.5.3 创建

以下代码是如何创建后端。其先是生成了 tcp store,然后调用 C10dRendezvousBackend。

代码语言:javascript复制
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
    """Creates a new :py:class:`C10dRendezvousBackend` from the specified
    parameters.

     -------------- ----------------------------------------------------------- 
    | Parameter    | Description                                               |
     ============== =========================================================== 
    | store_type   | The type of the C10d store. As of today the only          |
    |              | supported type is "tcp" which corresponds to              |
    |              | :py:class:`torch.distributed.TCPStore`. Defaults to "tcp".|
     -------------- ----------------------------------------------------------- 
    | read_timeout | The read timeout, in seconds, for store operations.       |
    |              | Defaults to 60 seconds.                                   |
     -------------- ----------------------------------------------------------- 
    | is_host      | A boolean value indicating whether this backend instance  |
    |              | will host the C10d store. If not specified it will be     |
    |              | inferred heuristically by matching the hostname or the IP |
    |              | address of this machine against the specified rendezvous  |
    |              | endpoint. Defaults to ``None``.                           |
    |              |                                                           |
    |              | Note that this configuration option only applies to       |
    |              | :py:class:`torch.distributed.TCPStore`. In normal         |
    |              | circumstances you can safely skip it; the only time when  |
    |              | it is needed is if its value cannot be correctly          |
    |              | determined (e.g. the rendezvous endpoint has a CNAME as   |
    |              | the hostname or does not match the FQDN of the machine).  |
     -------------- ----------------------------------------------------------- 
    """
    # As of today we only support TCPStore. Other store types do not have the
    # required functionality (e.g. compare_set) yet.
    store_type = params.get("store_type", "tcp").strip().lower()
    if store_type != "tcp":
        raise ValueError("The store type must be 'tcp'. Other store types are not supported yet.")

    store = _create_tcp_store(params)

    return C10dRendezvousBackend(store, params.run_id), store
3.5.3.1 TCPStore

_create_tcp_store 建立了一个 TCPStore。

代码语言:javascript复制
def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=29400)

    cfg_is_host = params.get_as_bool("is_host") # 获取配置看看
    # If the user has explicitly specified whether our process should host the
    # the store, respect it.
    if cfg_is_host is not None: # 如果配置了,就使用
        is_host = cfg_is_host
    # Otherwise try to determine whether we are the host based on our hostname
    # and IP address.
    else: # 否则动态看看本机是不是host
        is_host = _matches_machine_hostname(host) 

    # The timeout
    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
    if read_timeout <= 0:
        raise ValueError("The read timeout must be a positive integer.")

    # In specific cases we attempt to instantiate the store twice. For details
    # see the explanation in the except clause below.
    for is_server in [is_host, False]:
        try:
            store = TCPStore(  # type: ignore[call-arg]
                host, port, is_master=is_server, timeout=timedelta(seconds=read_timeout)
            )

            if is_server:
                log.info(
                    f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
                )

            break
        except (ValueError, RuntimeError) as exc:
            # If we heuristically inferred the value of is_host as True and our
            # first attempt to instantiate the TCP store has failed, try it one
            # more time with is_host set to False. As an edge case there can be
            # more than one process that is part of the same rendezvous on this
            # machine and only one of them will eventually host the store.

            if not is_server or cfg_is_host is not None:
                raise RendezvousConnectionError(
                    "The connection to the C10d store has failed. See inner exception for details."
                ) from exc

    return store
3.5.3.2 C10dRendezvousBackend

可以看到,C10dRendezvousBackend 其核心就是一个 Store,用来存储相关信息,以下代码进行了精简,是通过 set_state 和 get_state 来对 store 进行读写。

代码语言:javascript复制
class C10dRendezvousBackend(RendezvousBackend):
    """Represents a C10d-backed rendezvous backend.

    Args:
        store:
            The :py:class:`torch.distributed.Store` instance to use to
            communicate with the C10d store.
        run_id:
            The run id of the rendezvous.
    """

    # See the explanation in the __init__ method.
    _NULL_SENTINEL = "Y2FuaW1hZGFt"

    _store: Store
    _key: str

    def __init__(self, store: Store, run_id: str) -> None:
        if not run_id:
            raise ValueError("The run id must be a non-empty string.")

        self._store = store
        self._key = "torch.rendezvous."   run_id

        # The read operation of a store blocks the caller until the specified
        # key becomes available. This behavior makes it tricky to use a store
        # as a regular key-value dictionary.
        #
        # As a workaround we initially set a sentinel value as the rendezvous
        # state. Whenever this value gets returned we treat it as a None.
        self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)

    @property
    def name(self) -> str:
        """See base class."""
        return "c10d"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class."""
        # 从store读取数据
        base64_state: bytes = self._call_store("get", self._key)
        return self._decode_state(base64_state)

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class."""
        base64_state_str: str = b64encode(state).decode()

        if token:
            # Shortcut if we know for sure that the token is not valid.
            if not isinstance(token, bytes):
                result = self.get_state()
                if result is not None:
                    tmp = *result, False
                    # Python 3.6 does not support tuple unpacking in return
                    # statements.
                    return tmp
                return None

            token = token.decode()
        else:
            token = self._NULL_SENTINEL

        # 往 store 之中插入数据    
        base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)

        state_token_pair = self._decode_state(base64_state)
        if state_token_pair is None:
            return None

        new_state, new_token = state_token_pair

        # C10d Store's compare_set method does not offer an easy way to find out
        # whether our write attempt was successful. As a brute-force solution we
        # perform a bitwise comparison of our local state and the remote state.
        return new_state, new_token, new_state == state

    def _call_store(self, store_op: str, *args, **kwargs) -> Any:
        return getattr(self._store, store_op)(*args, **kwargs)

    def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
        if base64_state == self._NULL_SENTINEL.encode():
            return None
        state = b64decode(base64_state)
        return state, base64_state

3.6 StateHolder

3.6.1 _RendezvousStateHolder

这个类的作用是保存与其他节点同步的rendezvous状态,但是需要一个派生类来完成功能。

代码语言:javascript复制
class _RendezvousStateHolder(ABC):
    """Holds the shared rendezvous state synced with other nodes."""

    @property
    @abstractmethod
    def state(self) -> _RendezvousState:
        """Gets the local state."""

    @abstractmethod
    def sync(self) -> Optional[bool]:
        """Reads or writes the latest state.

        Returns:
            A boolean value indicating whether the local state, in case marked
            as dirty, was successfully synced with other nodes.
        """

    @abstractmethod
    def mark_dirty(self) -> None:
        """Marks the local state as dirty."""
3.6.2 _BackendRendezvousStateHolder

_BackendRendezvousStateHolder 拓展了_RendezvousStateHolder。其 sync 就是调用内部的 后端,对 store 进行读写。

代码语言:javascript复制
class _BackendRendezvousStateHolder(_RendezvousStateHolder):
    """Holds the rendezvous state synced with other nodes via a backend.

    Args:
        backend:
            The rendezvous backend to use.
        settings:
            The rendezvous settings.
        cache_duration:
            The amount of time, in seconds, to cache the last rendezvous state
            before requesting it from the backend again.
    """

    _backend: RendezvousBackend
    _state: _RendezvousState
    _settings: RendezvousSettings
    _cache_duration: int
    _token: Token
    _dirty: bool
    _last_sync_time: float
    _dead_nodes: List[_NodeDesc]

    def __init__(
        self, backend: RendezvousBackend, settings: RendezvousSettings, cache_duration: int = 1
    ) -> None:
        self._backend = backend
        self._state = _RendezvousState()
        self._settings = settings
        self._cache_duration = cache_duration
        self._token = None
        self._dirty = False
        self._last_sync_time = -1
        self._dead_nodes = []

    @property
    def state(self) -> _RendezvousState:
        """See base class."""
        return self._state

    def sync(self) -> Optional[bool]:
        """See base class."""
        state_bits: Optional[bytes] = None
        token = None
        has_set: Optional[bool]

        if self._dirty:
            has_set = False

            state_bits = pickle.dumps(self._state)

            # 这里会对后端进行设置
            set_response = self._backend.set_state(state_bits, self._token)
            if set_response is not None:
                state_bits, token, has_set = set_response
        else:
            has_set = None

            if self._cache_duration > 0:
                # Avoid overloading the backend if we are asked to retrieve the
                # state repeatedly. Try to serve the cached state.
                if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0):
                    return None

            get_response = self._backend.get_state()
            if get_response is not None:
                state_bits, token = get_response

        if state_bits is not None:
            try:
                self._state = pickle.loads(state_bits)
            except pickle.PickleError as exc:
                raise RendezvousStateError(
                    "The rendezvous state is corrupt. See inner exception for details."
                ) from exc
        else:
            self._state = _RendezvousState()

        if has_set and self._dead_nodes and log.isEnabledFor(logging.DEBUG):
            node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)

        self._token = token
        self._dirty = False
        self._last_sync_time = time.monotonic()
        self._sanitize()

        return has_set

    def _sanitize(self) -> None:
        expire_time = datetime.utcnow() - (
            self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
        )

        # Filter out the dead nodes.
        self._dead_nodes = [
            node
            for node, last_heartbeat in self._state.last_heartbeats.items()
            if last_heartbeat < expire_time
        ]

        for dead_node in self._dead_nodes:
            del self._state.last_heartbeats[dead_node]

            try:
                del self._state.participants[dead_node]
            except KeyError:
                pass

            try:
                self._state.wait_list.remove(dead_node)
            except KeyError:
                pass

    def mark_dirty(self) -> None:
        """See base class.

        If the local rendezvous state is dirty, the next sync call will try to
        write the changes back to the backend. However this attempt might fail
        if another node, which had the same state, also made changes and wrote
        them before us.
        """
        self._dirty = True
3.6.3 如何使用

StateHolder 具体如何使用在 _DistributedRendezvousOpExecutor 之中有(以下代码精简):

  • 通过 _state_holder.sync() 同步各种状态,因为最新状态在 rendezvous。
  • 通过 self._state_holder.state 得到最新的状态。
  • 进行业务处理。
  • 通过 _state_holder.mark_dirty() 再次同步,把自己状态同步给其他节点
代码语言:javascript复制
def run(
    self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
) -> None:
    """See base class."""
    action = None

    while action != _Action.FINISH:
        # Reads or writes the latest rendezvous state shared by all nodes in
        # the rendezvous. Note that our local changes might get overridden
        # by another node if that node synced its changes before us.
        
        has_set = self._state_holder.sync()  # 这里要同步各种状态,因为最新状态在 rendezvous。

        self._state = self._state_holder.state # 得到最新的状态
        ctx = _RendezvousContext(self._node, self._state, self._settings)

        # Determine the next action to take based on the current state of
        # the rendezvous.
        action = state_handler(ctx, deadline) 

        # 省略部分代码

        if action == _Action.SYNC:
            # Delay the execution by one second to avoid overloading the
            # backend if we are asked to poll for state changes.
            _delay(seconds=1)
        else:
            if action == _Action.KEEP_ALIVE:
                self._keep_alive()
            elif action == _Action.ADD_TO_PARTICIPANTS:
                self._add_to_participants()
            elif action == _Action.ADD_TO_WAIT_LIST:
                self._add_to_wait_list()
            elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                self._remove_from_participants()
            elif action == _Action.REMOVE_FROM_WAIT_LIST:
                self._remove_from_wait_list()
            elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                self._mark_rendezvous_complete()
            elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                self._mark_rendezvous_closed()

            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty() # 再次同步,把自己状态同步给其他节点

3.7 总结

我们把目前逻辑总结如下,两个 _BackendRendezvousStateHolder 通过 TCPStore 进行信息交互。

代码语言:javascript复制
                                                                        
 -------------------------------                                       |                                         ------------------------------- 
| _BackendRendezvousStateHolder |                                      |                                        | _BackendRendezvousStateHolder |
|                               |      -------------------             |            --------------------        |                               |
|             _settings  -----------> | RendezvousSettings|            |           | RendezvousSettings | <----------  _settings                |
|                               |      -------------------             |            --------------------        |                               |
|                               |      -------------------             |            --------------------        |                               |
|             _state  --------------> | _RendezvousState  |            |           | _RendezvousState   | <----------  _state                   |
|                               |     |                   |            |           |                    |       |                               |
|                               |      -------------------             |            --------------------        |                               |
|                               |                                      |                                        |                               |
|                               |      -----------------------                      ----------------------      |                               |
|             _backend  ------------> | C10dRendezvousBackend |                    | C10dRendezvousBackend| <-------   _backend                 |
|                               |     |                       |     ---------      |                      |     |                               |
|                               |     |             _store  -----> |TCPStore | <---------  _store         |     |                               |
|                               |     |                       |    |         |     |                      |     |                               |
|                               |      -----------------------      ---------       ----------------------      |                               |
|                               |                                                                               |                               |
|                               |         ^                                                    ^                |                               |
|                               |         |                            |                       |                |                               |
|                               |         |                            |                       |                |                               |
|             sync  ----------------------                             |                        ---------------------   sync                    |
|                               |   set_state                          |                         set_state      |                               |
 -------------------------------                                                                                 ------------------------------- 

手机如下:

0x04 动态逻辑

4.1 入口

我们先看看如何使用 Rendezvous。

launch_agent 启动了一个 LocalElasticAgent,调用了其 run 方法。在调用 run 之前,会生成 rdzv_handler,然后设置到 WorkerSpec 之中。

代码语言:javascript复制
import torch.distributed.elastic.rendezvous.registry as rdzv_registry

@record
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        **config.rdzv_configs,
    )

    # 构建了 rdzv_handler
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)

    try:
        spec = WorkerSpec(
            role=config.role,
            local_world_size=config.nproc_per_node,
            entrypoint=entrypoint,
            args=tuple(args),
            rdzv_handler=rdzv_handler, # 这里设置了 rdzv_handler
            max_restarts=config.max_restarts,
            monitor_interval=config.monitor_interval,
            redirects=config.redirects,
            tee=config.tee,
            master_addr=master_addr,
            master_port=master_port,
        )

        agent = LocalElasticAgent( # 构建
            spec=spec, start_method=config.start_method, log_dir=config.log_dir
        )

        result = agent.run() # 启动代理
    except ChildFailedError:

run 函数中,最终会调用到 self._rendezvous(worker_group),_rendezvous 方法会 调用 next_rendezvous() 来处理成员关系变化。

代码语言:javascript复制
    @prof
    def _rendezvous(self, worker_group: WorkerGroup) -> None:
        r"""
        Runs rendezvous for the workers specified by worker spec.
        Assigns workers a new global rank and world size.
        Updates the rendezvous store for the worker group.
        """

        spec = worker_group.spec
        store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
        
				# 省略后续代码

在这个流程之中,rdzv_registry.get_rendezvous_handler(rdzv_parameters) 是最初的来源,因此,我们要看看 get_rendezvous_handler。而 get_rendezvous_handler 会返回 RendezvousHandler,所以 RendezvousHandler 和 rendezvous_handler_registry 才是根本

代码语言:javascript复制
from .api import rendezvous_handler_registry as handler_registry

def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    This method is used to obtain a reference to a :py:class`RendezvousHandler`.
    Custom rendezvous handlers can be registered by

    ::

      from torch.distributed.elastid.rendezvous import rendezvous_handler_registry
      from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler

      def create_my_rdzv(params: RendezvousParameters):
        return MyCustomRdzv(params)

      rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)

      my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters)
    """
    return handler_registry.create_handler(params)

我们接下来就分别看看 RendezvousHandler 和 rendezvous_handler_registry。

4.2 基类 RendezvousHandler

RendezvousHandler 用来执行业务逻辑,几个虚函数是:

  • next_rendezvous :rendezvous barrier 的主要入口,新加入的节点会等待在这里,直到当前rendezvous结束,或者超时,或者当前 rendezvous 被标识为closed。
  • is_closed :是否已经结束,如果rendezvous结束,意味着所有试图re-rendezvous都将失败。
  • num_nodes_waiting :返回在rendezvous barrier等待的当前阶段数目,这些节点不属于当前工作组。用户应该周期调用这个方法来检查是否有新节点等待加入工作组,如果有,就调用next_rendezvous() (re-rendezvous。) 进行下一次re-rendezvous。

具体代码如下:

代码语言:javascript复制
class RendezvousHandler(ABC):
    """Main rendezvous interface.

    Note:
        Distributed Torch users normally **do not** need to implement their own
        ``RendezvousHandler``. An implementation based on C10d Store is already
        provided, and is recommended for most users.
    """

    # 获取 rendezvous backend名字
    @abstractmethod
    def get_backend(self) -> str:
        """Returns the name of the rendezvous backend."""

    # rendezvous barrier 的主要入口,新加入的节点会等待在这里,直到当前rendezvous结束,或者超时,或者当前 rendezvous 被标识为closed。
    @abstractmethod
    def next_rendezvous(
        self,
    ) -> Tuple[Store, int, int]:
        """Main entry-point into the rendezvous barrier.

        Blocks until the rendezvous is complete and the current process is
        included in the formed worker group, or a timeout occurs, or the
        rendezvous was marked closed.

        Returns:
            A tuple of :py:class:`torch.distributed.Store`, ``rank``, and
            ``world size``.
        """

    # 是否已经结束,如果rendezvous结束,意味着所有试图re-rendezvous都将失败
    @abstractmethod
    def is_closed(self) -> bool:
        """Checks whether the rendezvous has been closed.

        A closed rendezvous means all future attempts to re-rendezvous within
        same job will fail.
        """

    @abstractmethod
    def set_closed(self):
        """Marks the rendezvous as closed."""

    # 返回在rendezvous barrier等待的当前阶段数目,这些节点不属于当前工作组。用户应该周期调用这个方法来检查是否有新节点等待加入工作组,如果有,就调用`next_rendezvous()` (re-rendezvous。) 进行下一次re-rendezvous。
    @abstractmethod
    def num_nodes_waiting(self) -> int:
        """Returns the number of nodes who arrived late at the rendezvous
        barrier, hence were not included in the current worker group.

        Callers should periodically call this method to check whether new
        nodes are waiting to join the job and if so admit them by calling
        :py:meth:`next_rendezvous()` (re-rendezvous).
        """

    @abstractmethod
    def get_run_id(self) -> str:
        """Returns the run id of the rendezvous.

        The run id is a user-defined id that uniquely identifies an instance of
        a distributed application. It typically maps to a job id and is used to
        allow nodes to join the correct distributed application.
        """

    def shutdown(self) -> bool:
        """Closes all resources that were open for the rendezvous.
        """

4.3 注册

我们接下来看看 rendezvous_handler_registry。

在 torch/distributed/elastic/rendezvous/api.py 之中有如下代码。

代码语言:javascript复制
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()

所以我们来到了 RendezvousHandlerRegistry。

4.3.1 RendezvousHandlerRegistry

RendezvousHandlerRegistry 是一个负责创建 RendezvousHandler 的工厂类。

  • register 就是往内部字典添加对应的构建器。
  • create_handler 就是依据key,取出对应的构建器。
  • rendezvous_handler_registry 是全局Registry。
代码语言:javascript复制
class RendezvousHandlerRegistry:
    """Represents a registry of :py:class:`RendezvousHandler` backends."""

    _registry: Dict[str, RendezvousHandlerCreator]

    def __init__(self) -> None:
        self._registry = {}

    def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
        """Registers a new rendezvous backend.

        Args:
            backend:
                The name of the backend.
            creater:
                The callback to invoke to construct the
                :py:class:`RendezvousHandler`.
        """
        current_creator: Optional[RendezvousHandlerCreator]
        current_creator = self._registry[backend]
        self._registry[backend] = creator

    def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
        """Creates a new :py:class:`RendezvousHandler`."""

        creator = self._registry[params.backend]
        handler = creator(params)
        return handler
4.3.2 全局registry

系统会创建一个全局的 registry,就是前面看到的 rendezvous_handler_registry。

代码语言:javascript复制
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()

在这里注册了若干handler,用来提供创建器。rendezvous 提供了如下实现,分别是 etcdetcd-v2c10dstatic

代码语言:javascript复制
from .api import rendezvous_handler_registry as handler_registry

def _register_default_handlers() -> None:
    handler_registry.register("etcd", _create_etcd_handler)
    handler_registry.register("etcd-v2", _create_etcd_v2_handler)
    handler_registry.register("c10d", _create_c10d_handler)
    handler_registry.register("static", _create_static_handler)

运行时候就是:

代码语言:javascript复制
rendezvous_handler_registry = 
  _registry = {dict: 4} 
   'etcd' = {function} <function _create_etcd_handler at 0x7ff657e12d08>
   'etcd-v2' = {function} <function _create_etcd_v2_handler at 0x7ff657e12d90>
   'c10d' = {function} <function _create_c10d_handler at 0x7ff657e12e18>
   'static' = {function} <function _create_static_handler at 0x7ff657b9d2f0>
   __len__ = {int} 4

其含义就是:_create_etcd_handler 可以创建 etcd 类型的handler,以此类推。

4.4 创建

既然有了创建途径,我们就来看看如何创建。rendezvous 提供了如下实现,分别是 etcdetcd-v2c10dstatic,这里我们以 staticc10d 为例进行说明。

4.4.1 静态 RendezvousHandler

我们使用 _create_static_handler 举例,看看如何创建 static 类型的 handler。

首先从 _create_static_handler 入手。

4.4.1.1 _create_static_handler
代码语言:javascript复制
def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
  
    from . import static_tcp_rendezvous

    return static_tcp_rendezvous.create_rdzv_handler(params)

于是我们来到了torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py。在其中有 create_rdzv_handler 建立了 StaticTCPRendezvous。

代码语言:javascript复制
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    endpoint = params.endpoint.strip()
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
    world_size = params.max_nodes
    rank = cast(int, params.config.get("rank"))
    run_id = params.run_id
    if "timeout" in params.config:
        timeout = int(params.config["timeout"])
    else:
        timeout = _default_timeout_seconds
        
    return StaticTCPRendezvous(
        master_addr, master_port, rank, world_size, run_id, timeout
    )
4.4.1.2 StaticTCPRendezvous 子类

StaticTCPRendezvous 拓展了RendezvousHandler,其定义如下,其最主要逻辑是:在 group_rank = 0 之上建立一个 TCPStore,然后封装成一个PrefixStore。

代码语言:javascript复制
class StaticTCPRendezvous(RendezvousHandler):
    """
    Static rendezvous that is a wrapper around the TCPStore.
    Creates TCPStore based on the input parameters with the
    listener on the agent with group_rank=0
    """

    def __init__(
        self,
        master_addr: str,
        master_port: int,
        rank: int,
        world_size: int,
        run_id: str,
        timeout: int,
    ):
        self.master_addr = master_addr
        self.master_port = master_port
        self.rank = rank
        self.world_size = world_size
        self.run_id = run_id
        self.timeout = datetime.timedelta(seconds=timeout)
        self._store: Optional[Store] = None

    def get_backend(self) -> str:
        return "static"

    def next_rendezvous(self) -> Tuple[Store, int, int]:
        if not self._store:
            is_master = self.rank == 0
            self._store = TCPStore(
                self.master_addr,
                self.master_port,
                self.world_size,
                is_master,
                self.timeout,
            )
        store = PrefixStore(self.run_id, self._store)
        return store, self.rank, self.world_size

关键函数

代码语言:javascript复制
def next_rendezvous(self) -> Tuple[Store, int, int]:
    log.info("Creating TCPStore as the c10d::Store implementation")
    if not self._store:
        is_master = self.rank == 0
        self._store = TCPStore(
            self.master_addr,
            self.master_port,
            self.world_size,
            is_master,
            self.timeout,
        )
    store = PrefixStore(self.run_id, self._store)
    return store, self.rank, self.world_size
4.4.2 动态 RendezvousHandler

我们接下来看看如何构建 DynamicRendezvousHandler。

4.4.2.1 _create_c10d_handler

这里 _create_c10d_handler 会返回一个 DynamicRendezvousHandler。

代码语言:javascript复制
def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
    from .c10d_rendezvous_backend import create_backend

    backend, store = create_backend(params)
    return create_handler(store, backend, params)

这里返回了 DynamicRendezvousHandler。

代码语言:javascript复制
def create_handler(
    store: Store, backend: RendezvousBackend, params: RendezvousParameters
) -> DynamicRendezvousHandler:
    """Creates a new :py:class:`DynamicRendezvousHandler` from the specified
    parameters.

    Args:
        store:
            The C10d store to return as part of the rendezvous.
        backend:
            The backend to use to hold the rendezvous state.

     ------------------- ------------------------------------------------------ 
    | Parameter         | Description                                          |
     =================== ====================================================== 
    | join_timeout      | The total time, in seconds, within which the         |
    |                   | rendezvous is expected to complete. Defaults to 600  |
    |                   | seconds.                                             |
     ------------------- ------------------------------------------------------ 
    | last_call_timeout | An additional wait amount, in seconds, before        |
    |                   | completing the rendezvous once the minimum number of |
    |                   | nodes has been reached. Defaults to 30 seconds.      |
     ------------------- ------------------------------------------------------ 
    | close_timeout     | The time, in seconds, within which the rendezvous is |
    |                   | expected to close after a call to                    |
    |                   | :py:meth:`RendezvousHandler.set_closed` or           |
    |                   | :py:meth:`RendezvousHandler.shutdown`. Defaults to   |
    |                   | 30 seconds.                                          |
     ------------------- ------------------------------------------------------ 
    """
    timeout = RendezvousTimeout(
        _get_timeout(params, "join"),
        _get_timeout(params, "last_call"),
        _get_timeout(params, "close"),
    )

    return DynamicRendezvousHandler.from_backend(
        params.run_id,
        store,
        backend,
        params.min_nodes,
        params.max_nodes,
        timeout,
    )
4.4.2.2 from_backend

from_backend 是具体生成 DynamicRendezvousHandler 的方法,相当于生成器。

其生成了 RendezvousSettings,_BackendRendezvousStateHolder 和 node,然后建立了 DynamicRendezvousHandler。

代码语言:javascript复制
@classmethod
def from_backend(
    cls,
    run_id: str,
    store: Store,
    backend: RendezvousBackend,
    min_nodes: int,
    max_nodes: int,
    timeout: Optional[RendezvousTimeout] = None,
):
    """Creates a new :py:class:`DynamicRendezvousHandler`.

    Args:
        run_id:
            The run id of the rendezvous.
        store:
            The C10d store to return as part of the rendezvous.
        backend:
            The backend to use to hold the rendezvous state.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        timeout:
            The timeout configuration of the rendezvous.
    """
    # We associate each handler instance with a unique node descriptor.
    node = cls._node_desc_generator.generate()

    settings = RendezvousSettings(
        run_id,
        min_nodes,
        max_nodes,
        timeout or RendezvousTimeout(),
        keep_alive_interval=timedelta(seconds=5),
        keep_alive_max_attempt=3,
    )

    state_holder = _BackendRendezvousStateHolder(backend, settings)

    return cls(node, settings, backend.name, store, state_holder)
4.4.2.3 DynamicRendezvousHandler
代码语言:javascript复制
Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler`
class that implements the rendezvous mechanism described above. It is a backend-
agnostic type that expects a particular :py:class:`.RendezvousBackend` instance
to be specified during construction.

Torch distributed users can either implement their own backend type or use one
of the following implementations that come with PyTorch:

DynamicRendezvousHandler 拓展了RendezvousHandler,其定义如下,其最主要逻辑是:在 group_rank = 0 之上建立一个 TCPStore,然后封装成一个 PrefixStore。

最主要的是如下几个成员变量:

  • _BackendRendezvousStateHolder 负责在 Rendezvous 之间协调信息。
  • _DistributedRendezvousOpExecutor 负责具体执行业务。
  • _store 负责保存信息(分布式)。
代码语言:javascript复制
class DynamicRendezvousHandler(RendezvousHandler):
    """Represents a handler that sets up a rendezvous among a set of nodes."""

    # Static
    _node_desc_generator = _NodeDescGenerator()
    _this_node: _NodeDesc
    _settings: RendezvousSettings
    _backend_name: str
    _store: Store
    _state_holder: _RendezvousStateHolder
    _op_executor: _RendezvousOpExecutor
    _heartbeat_lock: threading.Lock
    _keep_alive_timer: Optional[_PeriodicTimer]

    @classmethod
    def from_backend(
        cls,
        run_id: str,
        store: Store,
        backend: RendezvousBackend,
        min_nodes: int,
        max_nodes: int,
        timeout: Optional[RendezvousTimeout] = None,
    ):
        """Creates a new :py:class:`DynamicRendezvousHandler`.

        Args:
            run_id:
                The run id of the rendezvous.
            store:
                The C10d store to return as part of the rendezvous.
            backend:
                The backend to use to hold the rendezvous state.
            min_nodes:
                The minimum number of nodes to admit to the rendezvous.
            max_nodes:
                The maximum number of nodes to admit to the rendezvous.
            timeout:
                The timeout configuration of the rendezvous.
        """
        # We associate each handler instance with a unique node descriptor.
        node = cls._node_desc_generator.generate()

        settings = RendezvousSettings(
            run_id,
            min_nodes,
            max_nodes,
            timeout or RendezvousTimeout(),
            keep_alive_interval=timedelta(seconds=5),
            keep_alive_max_attempt=3,
        )

        state_holder = _BackendRendezvousStateHolder(backend, settings)

        return cls(node, settings, backend.name, store, state_holder)

    def __init__(
        self,
        node: _NodeDesc,
        settings: RendezvousSettings,
        backend_name: str,
        store: Store,
        state_holder: _RendezvousStateHolder,
    ) -> None:
        self._this_node = node
        self._settings = settings
        self._backend_name = backend_name
        self._store = store
        self._state_holder = state_holder
        self._op_executor = _DistributedRendezvousOpExecutor(
            self._this_node, self._state_holder, self._settings
        )
        self._heartbeat_lock = threading.Lock()
        self._keep_alive_timer = None

我们也可以用如下方式直接生成 DynamicRendezvousHandler。

代码语言:javascript复制
 store = TCPStore("localhost")
 backend = C10dRendezvousBackend(store, "my_run_id")
 rdzv_handler = DynamicRendezvousHandler.from_backend(
     run_id="my_run_id",
     store=store,
     backend=backend,
     min_nodes=2,
     max_nodes=4
 )
4.4.2.4 next_rendezvous

这一函数调用会被阻塞,直到 worker 的数量达到了要求。在 worker 被初始化,或者重启的时候,这一函数都会被调用。当函数返回时,不同的 worker group 会得到一个 rank 作为唯一的标示。其内部逻辑是:

  • 先使用_RendezvousExitOp让该node退出。
  • 然后再使用_RendezvousJoinOp把该node重新加入。
  • 最后启动心跳,返回world size,store等,此时所有参与的Node都在participants之中。
代码语言:javascript复制
def next_rendezvous(self) -> Tuple[Store, int, int]:
    """See base class."""

    self._stop_heartbeats()

    # Delay the execution for a small random amount of time if this is our
    # first run. This will slightly skew the rendezvous attempts across the
    # nodes and reduce the load on the backend.
    if self._state_holder.state.round == 0:
        _delay(seconds=(0, 0.3))

    exit_op = _RendezvousExitOp()
    join_op = _RendezvousJoinOp()

    deadline = self._get_deadline(self._settings.timeout.join)

    self._op_executor.run(exit_op, deadline)
    self._op_executor.run(join_op, deadline)

    self._start_heartbeats()

    rank, world_size = self._get_world()
    store = self._get_store()

    return store, rank, world_size # 返回的是 worker group 的rank
4.4.2.5 _get_world

上面代码之中,使用了 _get_world,这里我们再分析一下。rank, world_size 这两个变量是动态生成的,所以从 state 之中取出。而且,因为 participants 是在所有Node之间同步的,所以每个Node得到的 participants 完全一致。

代码语言:javascript复制
rank, world_size = self._get_world()
    
def _get_world(self) -> Tuple[int, int]:
	state = self._state_holder.state
	return state.participants[self._this_node], len(state.participants)

state.participants 从哪里来?在 rendezvous 结束时候,会设置 rank。因为每个节点上都是按照同样算法排序,所以rank 排序在每个节点上都是一样的。可以保证每个Node得到的rank是与其他Node不同的。

代码语言:javascript复制
def _mark_rendezvous_complete(self) -> None:
    state = self._state
    state.complete = True
    state.deadline = None

    # Assign the ranks.
    for rank, node in enumerate(sorted(state.participants)):
        state.participants[node] = rank

4.5 容错

前面提到:在开始join rendezvous 和 rendezvous 完成之间,如果有进程崩溃(或网络故障等),就会自动引发一个re-rendezvous,剩余健康节点会自动重组。

代码语言:javascript复制
Torch Distributed Elastic rendezvous is designed to tolerate node failures during the rendezvous process. Should a process crash (or lose network connectivity, etc), between joining the rendezvous and it being completed, then a re-rendezvous with remaining healthy nodes will happen automatically.
4.5.1 ETCD

这部分容错机制在EtcdRendezvousHandler 之中体现的特别明显。

next_rendezvous 方法会调用 rendezvous_barrier。

代码语言:javascript复制
def next_rendezvous(self):
    rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier()

    log.info("Creating EtcdStore as the c10d::Store implementation")
    store = self._rdzv_impl.setup_kv_store(rdzv_version)

    return store, rank, world_size

在 rendezvous_barrier 之中,如果底层抛出各种异常,则会捕获,然后调用 init_phase 再次执行一次rendezvous,直到deadline时间到为止。

代码语言:javascript复制
def rendezvous_barrier(self):
    """
    Main entry point for next rendezvous.
    This method is blocking until rendezvous succeeds or a timeout occurs.

    Returns:
         ``(rdzv_version, rank, world_size)``

    Raises:
        RendezvousTimeoutError - timeout waiting for rendezvous
        RendezvousClosedError - rendezvous is or was closed while waiting
        RendezvousError - other persistent errors that
         render the rendezvous non-retryable
    """
    self._rendezvous_deadline = time.time()   self._timeout
    while True:
        if time.time() > self._rendezvous_deadline:
            raise RendezvousTimeoutError()

        log.info("Attempting to join next rendezvous")
        try:
            # Dis-own our lease in the previous rendezvous, if exists
            if self._lease_this_rank_stop is not None:
                self._lease_this_rank_stop.set()

            return self.init_phase()

        except EtcdRendezvousRetryImmediately:
            # The type of failure suggests we can retry without delay
            pass

        except EtcdRendezvousRetryableFailure:
            # In case of retryable failure, wait a small delay
            # to avoid spamming etcd
            time.sleep(1)

        except RendezvousTimeoutError:
            log.info("Rendezvous timeout occured in EtcdRendezvousHandler")
            raise

        except RendezvousClosedError:
            log.info(
                f"Rendezvous for run_id={self._run_id} was observed to be closed"
            )
            raise

        except RendezvousError:
            raise

        except Exception as e:
            # In case of a general exception, wait a small delay
            # to avoid spamming etcd
            # FIXME: there are a few things that fall under this like
            # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly.
            log.info("Rendezvous attempt failed, will retry. Reason: "   str(e))
            time.sleep(1)

init_phase 会发起一轮 rendezvous。

代码语言:javascript复制
def init_phase(self):
    """
    Initially, the rendezvous state is expected to be one of:

    1. empty (non-existent) - in this case we try to create a new one.
    2. joinable - we try to join it.
    3. final - we announce ourselves as waiting, and go into monitoring mode

    Any other state is considered transitional, and will be retried after
    a short delay.

    Returns:
        ``(rdzv_version, rank, world_size)``

    Raises:
        RendezvousClosedError - current rendezvous was/is closed
        EtcdRendezvousRetryableFailure - observed some intermediate
         state, which is best handled by retrying later
    """
    try:
        active_version = self.try_create_rendezvous() # 发起一轮rendezvous
        state = json.loads(active_version.value)
        log.info("New rendezvous state created: "   str(state))
    except etcd.EtcdAlreadyExist:
        active_version, state = self.get_rdzv_state()
        # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound),
        # but this is ok for us - just means we'll restart from beginning.
        log.info("Observed existing rendezvous state: "   str(state))

    if state["status"] == "closed":
        raise RendezvousClosedError()

    if state["status"] == "joinable":
        return self.join_phase(state["version"])

    if state["status"] == "final":
        self.handle_existing_rendezvous(state["version"])
        raise EtcdRendezvousRetryImmediately()

    self.try_wait_for_state_change(etcd_index=active_version.etcd_index   1)
    raise EtcdRendezvousRetryableFailure()
4.5.2 DynamicRendezvousHandler

DynamicRendezvousHandler 之中就体现的不明显,应该是因为 DynamicRendezvousHandler 是在ETCD之后开发,所以很多功能不完善,在演进之中。

本系列是基于PyTorch 1.9 为主进行分析,所以上面 next_rendezvous 代码之中没有错误处理,直接抛到最外面去了。在2021-12月最新代码之中,已经加入了错误处理,后续应该还会继续完善。

代码语言:javascript复制
def next_rendezvous(self) -> Tuple[Store, int, int]:
    """See base class."""
    msg = (
        f"The node '{self._this_node}' attempts to join the next round of the rendezvous "
        f"'{self._settings.run_id}'."
    )
    self._record(message=msg)

    try: # 加入了错误处理
        self._stop_heartbeats()

        # Delay the execution for a small random amount of time if this is our
        # first run. This will slightly skew the rendezvous attempts across the
        # nodes and reduce the load on the backend.
        if self._state_holder.state.round == 0:
            _delay(seconds=(0, 0.3))

        exit_op = _RendezvousExitOp()
        join_op = _RendezvousJoinOp()

        deadline = self._get_deadline(self._settings.timeout.join)

        self._op_executor.run(exit_op, deadline)
        self._op_executor.run(join_op, deadline)

        self._start_heartbeats()

        rank, world_size = self._get_world()
        store = self._get_store()

    except Exception as e: # 加入了错误处理,但是没有发起下一轮rendezvous
        self._record(
            message=f"{type(e).__name__}: {str(e)}",
            node_state=NodeState.FAILED,
        )
        raise

    msg = (
        f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of "
        f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size "
        f"{world_size}."
    )
    self._record(message=msg, rank=rank)

    return store, rank, world_size

4.6 小结

Rendezvous 和 Agent 之间的逻辑联系总结如下,每个启动脚本都有这么一套机制。若干启动脚本的机制之间会互相联系。

代码语言:javascript复制
 -----------------------------        ------------------------------------------------ 
| LocalElasticAgent           |      | WorkerSpec                                     |
|                             |      |                                                |
|  ------------------------   |      |   rdzv_handler = {DynamicRendezvousHandler} ------- 
| |WorkerGroup             |  |      |                                                |   |
| |            spec  --------------> |   entry = worker_fn                            |   |
| |            workers     |  |      |                                                |   |
| |            store       |  |      |   role = {str} 'trainer'                       |   |
| |            group_rank  |  |      |                                                |   |
| |       group_world_size |  |       ------------------------------------------------    |
| |                        |  |                                                           |
|  ------------------------   |                                                           |
|                             |                                                           |
| rdzv_run_id                 |                                                           |
| store                       |             -----------------------------------------     |
|                             |            |DynamicRendezvousHandler                 |    |
 -----------------------------             |                                         |    |
                                           |                                         |    |
                                           |   _settings: RendezvousSettings         | <-- 
                                           |                                         |
                                           |   _store: Store                         |
                                           |                                         |
                                           |   _state_holder: _RendezvousStateHolder |
                                           |                                         |
                                           |   _op_executor: _RendezvousOpExecutor   |
                                           |                                         |
                                            ----------------------------------------- 

或者和前面的静态逻辑结合起来看看。

代码语言:javascript复制
 ------------------------     ----------------------------------------------          ------------------------     --------------------------------------------- 
| LocalElasticAgent      |   | WorkerSpec                                   |     |  | LocalElasticAgent      |   | WorkerSpec                                  |
|                        |   |                                              |     |  |                        |   |                                             |
|  --------------------  |   |   rdzv_handler = {DynamicRendezvousHandler} ----   |  |  --------------------  |   |  rdzv_handler = {DynamicRendezvousHandler}----- 
| | WorkerGroup        | |   |                                              |  |  |  | | WorkerGroup        | |   |                                             |  |
| |        spec  ----------->    entry = worker_fn                          |  |  |  | |        spec  ----------->   entry = worker_fn                          |  |
| |        workers     | |   |                                              |  |  |  | |        workers     | |   |                                             |  |
| |        store       | |   |   role = {str} 'trainer'                     |  |  |  | |        store       | |   |  role = {str} 'trainer'                     |  |
| |        group_rank  | |   |                                              |  |  |  | |        group_rank  | |   |                                             |  |
| |   group_world_size | |    ----------------------------------------------   |  |  | |   group_world_size | |    ---------------------------------------------   |
| |                    | |    --------------------------------------------     |  |  | |                    | |    --------------------------------------------    |
|  --------------------  |   | DynamicRendezvousHandler                   |    |  |  |  --------------------  |   | DynamicRendezvousHandler                   |   |
|  rdzv_run_id           |   |                                            |    |  |  |  rdzv_run_id           |   |                                            |   |
|  store                 |   |                                            |    |  |  |  store                 |   |                                            |   |
 ------------------------    |    _settings: RendezvousSettings           |    |  |   ------------------------    |    _settings: RendezvousSettings           |   |
                             |                                            | <--   |                               |                                             <-- 
                             |    _store: Store                           |       |                               |    _store: Store                           |
                             |                                            |       |                               |                                            |
                      ----------  _state_holder: _RendezvousStateHolder   |       |                               |    _state_holder: _RendezvousStateHolder  ----- 
                     |       |                                            |       |                               |                                            |   |
                     |       |    _op_executor: _RendezvousOpExecutor     |       |                               |    _op_executor: _RendezvousOpExecutor     |   |
                     |       |                                            |       |                               |                                            |   |
                     |        --------------------------------------------        |                                --------------------------------------------    |
                     v                                                            |                                                                                |
            --------- ---------------------                                       |                                       -------------------------------          |
           | _BackendRendezvousStateHolder |                                      |                                      | _BackendRendezvousStateHolder |         |
           |                               |      -------------------             |          --------------------        |                               |         |
           |             _settings  -----------> | RendezvousSettings|            |         | RendezvousSettings | <----------  _settings                | <------- 
           |                               |      -------------------             |          --------------------        |                               |
           |                               |      -------------------             |          --------------------        |                               |
           |             _state  --------------> | _RendezvousState  |            |         | _RendezvousState   | <----------  _state                   |
           |                               |     |                   |            |         |                    |       |                               |
           |                               |      -------------------             |          --------------------        |                               |
           |                               |      -----------------------                    ----------------------      |                               |
           |             _backend  ------------> | C10dRendezvousBackend |                  | C10dRendezvousBackend| <-------   _backend                 |
           |                               |     |                       |     ---------    |                      |     |                               |
           |                               |     |             _store  -----> |TCPStore | <-------  _store         |     |                               |
           |                               |      --- -------------------      --- -----     ----------- ----------      |                               |
           |                               |         ^                            |                     ^                |                               |
           |                               |         |                            |                     |                |                               |
           |                               |         |                            |                     |                |                               |
           |             sync  ----------------------                             |                      ---------------------   sync                    |
           |                               |   set_state                  NODE 1  |  NODE 2               set_state      |                               |
            -------------------------------                                                                               ------------------------------- 

手机如下:

0x05 总结

目前我们分析了Rendezvous的静态结构和动态逻辑,大家对其机制有了一个基本理解,比如有如下概念:

  • 节点概念_NodeDesc,这样可以把系统表达出来。
  • 状态概念。_RendezvousState 是rendezvous的状态。是动态信息,每一个 node 都会维护一个本地 state。
  • 总体静态类 _BackendRendezvousStateHolder,用来把节点,状态,后端以及其他信息统一维护起来。
  • 共享共享键值存储,比如TCPStore,可以集中保存上述信息,也可以用来彼此交换信息,达成共识。
  • 动态server或者handler,RendezvousHandler就提供了一套API以供外界访问。

下一篇我们介绍内部业务逻辑如何实现,即 Rendezvous 引擎。

0xFF 参考

云原生的弹性 AI 训练系列之二:PyTorch 1.9.0 弹性分布式训练的设计与实现

PyTorch Elastic源码阅读

0 人点赞