[源码解析] PyTorch 流水线并行实现 (6)--并行计算

2021-10-13 10:16:14 浏览数 (1)

[源码解析] PyTorch 流水线并行实现 (6)--并行计算

目录

  • [源码解析] PyTorch 流水线并行实现 (6)--并行计算
    • 0x00 摘要
    • 0x01 总体架构
      • 1.1 使用
      • 1.2 前向传播
      • 1.3 Pipeline 类
        • 1.3.1 构建依赖
        • 1.3.2 Queue
        • 1.3.3 计算
    • 0x02 并行拷贝和计算
      • 2.1 GPU并行操作
      • 2.2 PyTorch
      • 2.3 Stream 封装
        • 2.3.1 PyTorch 样例
        • 2.3.2 生成/获取
        • 2.3.3 记录
        • 2.3.4 等待
      • 2.4 拷贝API
      • 2.5 等待API
      • 2.6 使用
        • 2.6.1 总体概念
        • 2.6.2 构建拷贝流
        • 2.6.3 并行操作
        • 2.6.4 预先拷贝
        • 2.6.5 计算
    • 0x03 重计算
      • 3.1 解析
      • 3.2 封装API
      • 3.3 实现
        • 3.3.1 Checkpoint
        • 3.3.2 Recompute
      • 3.4 总体调用
    • 0xFF 参考

0x00 摘要

前几篇文章我们介绍了 PyTorch 流水线并行的基本知识,自动平衡机制和切分数据,本文我们结合论文内容来看看如何实现流水线。

流水线并行其他文章链接如下:

[源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现

[源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

[源码解析] 深度学习流水线并行 GPipe(3) ----重计算

[源码解析] 深度学习流水线并行之PipeDream(1)--- Profile阶段

[源码解析] 深度学习流水线并行 PipeDream(2)--- 计算分区

[源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型

[源码解析] 深度学习流水线并行 PipeDream(4)--- 运行时引擎

[源码解析] 深度学习流水线并行 PipeDream(5)--- 通信模块

[源码解析] 深度学习流水线并行 PipeDream(6)--- 1F1B策略

[源码解析] PyTorch 流水线并行实现 (1)--基础知识

[源码解析] PyTorch 流水线并行实现 (2)--如何划分模型

[源码解析] PyTorch 流水线并行实现 (3)--切分数据和运行时系统

[源码解析] PyTorch 流水线并行实现 (4)--前向计算

[源码解析] PyTorch 流水线并行实现 (5)--计算依赖

本文图片来自论文和github源码。

0x01 总体架构

我们首先从整体角度来梳理一下 torchgpipe。

1.1 使用

我们使用源码中的测试例子来进行分析。示例中有一个由三个层组成的Sequential模型,被GPipe封装之后,进行前向和后向传播。

代码语言:javascript复制
class Layer1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, input):
        yield stash('1to3', input)
        output = self.conv(input)
        return output

class Layer2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, input):
        output = self.conv(input)
        return output

class Layer3(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)

    def forward(self, input):
        skip_1to3 = yield pop('1to3')
        output = self.conv(input)   skip_1to3
        return output

model = nn.Sequential(Layer1(), Layer2(), Layer3()) # 构建了一个Sequential
model = GPipe(model, balance, chunks=3, checkpoint=checkpoint) #在 Sequential 基础上构建 GPipe

in_device = model.devices[0]
out_device = model.devices[-1]

input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
output = model(input) # 这里将调用到 GPipe.forward
loss = output.mean()
loss.backward() # 这里会进行反向传播

1.2 前向传播

GPipe 的前向传播之中做了如下操作:

  • 利用 scatter 函数把输入分割,就是把 mini-batch 分割为 micro-batches。
  • 利用 _ensure_copy_streams 方法针对每个设备生成新的 CUDA stream。
  • 生成一个 Pipeline,并且运行。
  • 运行结束之后,利用 gather 方法把micro-batches 合并成一个 mini-batch。

因此我们可以看到,对于每次迭代的 forward 操作,都会生成一个 Pipeline 类进行操作,返回给调用者。

代码语言:javascript复制
def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
    """:class:`GPipe` is a fairly transparent module wrapper. It doesn't
    modify the input and output signature of the underlying module. But
    there's type restriction. Input and output have to be a
    :class:`~torch.Tensor` or a tuple of tensors. This restriction is
    applied at partition boundaries too.

    Args:
        input (torch.Tensor or tensors): input mini-batch

    Returns:
        tensor or tensors: output mini-batch

    Raises:
        TypeError: input is not a tensor or tensors.

    """
    microbatch.check(input)

    if not self.devices:
        # Empty sequential module is not illegal.
        return input

    # Divide a mini-batch into micro-batches.
    batches = microbatch.scatter(input, self.chunks)

    # Separate CUDA streams for copy.
    copy_streams = self._ensure_copy_streams()

    # The micro-batch index where the checkpointing stops.
    if self.training:
        checkpoint_stop = {
            'always': self.chunks,
            'except_last': self.chunks-1,
            'never': 0,
        }[self.checkpoint]
    else:
        checkpoint_stop = 0

    # Run pipeline parallelism.
    pipeline = Pipeline(batches,
                        self.partitions,
                        self.devices,
                        copy_streams,
                        self._skip_layout,
                        checkpoint_stop)
    pipeline.run()

    # Merge the micro-batches into one mini-batch.
    output = microbatch.gather(batches)
    return output

_ensure_copy_streams 方法就是针对每个设备生成新的 CUDA stream

代码语言:javascript复制
def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
    """Ensures that :class:`GPipe` caches CUDA streams for copy.

    It's worth to cache CUDA streams although PyTorch already manages a
    pool of pre-allocated CUDA streams, because it may reduce GPU memory
    fragementation when the number of micro-batches is small.

    """
    if not self._copy_streams:
        for device in self.devices:
            self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])

    return self._copy_streams

1.3 Pipeline 类

在 Pipeline 类的 run 方法之中按照时钟周期来启动计算,这样在前向传播之中,就按照这个序列,像水波纹一样扩散。

代码语言:javascript复制
def run(self) -> None:
    """Runs pipeline parallelism.

    It modifies the given batches in place.

    """
    batches = self.batches
    partitions = self.partitions
    devices = self.devices
    skip_layout = self.skip_layout

    m = len(batches)
    n = len(partitions)

    skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

    with spawn_workers(devices) as (in_queues, out_queues):
        for schedule in clock_cycles(m, n): # 这里使用,给出了执行序列计划,后续按照这个来执行
            self.fence(schedule, skip_trackers) # 拷贝,设定依赖
            self.compute(schedule, skip_trackers, in_queues, out_queues) # 启动各种Task
1.3.1 构建依赖

在 Pipeline 之中,fence 方法(省略部分代码)利用 depend 来构建后向传播的依赖关系。

代码语言:javascript复制
    def fence(self,
              schedule: List[Tuple[int, int]],
              skip_trackers: List[SkipTrackerThroughPotals],
              ) -> None:
        """Copies micro-batches after computation for the previous
        micro-batches.
        """
        batches = self.batches
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0:
                depend(batches[i-1], batches[i]) # 在这里建立了后向传播依赖关系
                
            next_stream = copy_streams[j][i]

            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j-1][i]
                # 从之前的micro-batches进行拷贝
                copy(batches[i], prev_stream, next_stream)                
1.3.2 Queue

Worker 和主线程之间使用了 Python 的 Queue 数据结构进行交互。Queue 类实现了一个基本的先进先出(FIFO)容器,使用 put() 将元素添加到序列尾端,get() 从队列尾部移除元素。

代码语言:javascript复制
A multi-producer, multi-consumer queue.

两个关键函数是:

  • get([block, [timeout]]) 读队列,timeout为等待时间,如果队列满,则阻塞。
  • put(item, [block, [timeout]]) 写队列,timeout为等待时间,如果队列空,则阻塞。
1.3.3 计算

具体训练是通过 compute 函数完成。

代码语言:javascript复制
def compute(self,
            schedule: List[Tuple[int, int]],
            skip_trackers: List[SkipTrackerThroughPotals],
            in_queues: List[InQueue],
            out_queues: List[OutQueue],
            ) -> None:
    """Runs tasks with synchronization to copy streams."""
    batches = self.batches
    partitions = self.partitions
    devices = self.devices
    copy_streams = self.copy_streams
    checkpoint_stop = self.checkpoint_stop

    n = len(partitions)
    streams = [current_stream(d) for d in devices]
    exc_info: Optional[ExcInfo] = None

    # With checkpointing, the autograd graph looks like this diagram:
    # ┌─────┸──────┐
    # │    Copy    │
    # └─────┰──────┘   (fence)
    # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
    #       ┃          (compute)
    # ┌─────┸──────┐
    # │    Wait    │ [1] Synchronize the current stream with the copy stream.
    # └─────┰──────┘
    # ┌─────┸──────┐
    # │ Checkpoint │ [2] Compute a partition within checkpointing.
    # └─────┰──────┘
    # ┌─────┸──────┐
    # │    Wait    │ [3] Synchronize the copy stream with the current stream.
    # └─────┰──────┘
    #       ┠ ─ ─ ─ ┐
    #       ┃ ┌─────┴─────┐
    #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
    #       ┃ └─────┬─────┘
    #       ┠ ─ ─ ─ ┘
    #       ┃
    # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
    # ┌─────┸──────┐   (fence)
    # │    Copy    │
    # └─────┰──────┘
    for i, j in schedule:
        batch = batches[i]
        partition = partitions[j]

        # Synchronize with the copied input. ([1] in the diagram)
        if j != 0: # 等待拷贝结束
            wait(batch, copy_streams[j][i], streams[j])

        # Determine whether checkpointing or not.
        checkpoint = (i < checkpoint_stop)
        if checkpoint:
            def function(input: TensorOrTensors,
                         partition: nn.Sequential = partition,
                         skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                         ) -> TensorOrTensors:
                with use_skip_tracker(skip_tracker):
                    return partition(input)

            chk = Checkpointing(function, batch)
            task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
            del function, chk

        else:
            def compute(batch: Batch = batch,
                        partition: nn.Sequential = partition,
                        skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                        ) -> Batch:
                with use_skip_tracker(skip_tracker):
                    return batch.call(partition)

            task = Task(streams[j], compute=compute, finalize=None)
            del compute

        # Compute tasks in parallel. ([2] in the diagram)
        in_queues[j].put(task) # 并行执行操作

    for i, j in schedule:
        # 等待运行结果
        ok, payload = out_queues[j].get()

        # Hold the first exception.
        if exc_info is not None:
            continue
        elif not ok:
            exc_info = cast(ExcInfo, payload)
            continue

        task, batch = cast(Tuple[Task, Batch], payload)

        # The copy stream synchronizes to copy the output. ([3] in the
        # diagram)
        if j != n-1: # 拷贝输出
            wait(batch, streams[j], copy_streams[j][i])

        # Finalize tasks. If checkpointing is enabled, here the
        # recomputation is scheduled at backpropagation. ([4] in the
        # diagram)
        with use_device(devices[j]):
            task.finalize(batch)

        batches[i] = batch

    # Fail at the first exception.
    if exc_info is not None:
        raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

我们总结梳理一下大致业务逻辑(并行逻辑):

  1. 系统调用 spawn_workers 来生成若干 workers。
    1. spawn_workers 为每个 device 生成了一个 Thread,这个 Thread 的执行函数是 worker。spawn_workers 内部也会针对每一个device生成一个 in_queue, out_queue。所以可保证每个device之上是串行来执行业务操作。
    2. 这些 queues 被添加到 (in_queues, out_queues) 之中。然后把 (in_queues, out_queues) 返回给 Pipeline 主线程。之后就是使用 (in_queues, out_queues) 作为各个task 之间传递信息的上下文。
  2. Pipeline 主线程得到 (in_queues, out_queues) 之后,使用clock_cycles 算法生成一系列迭代,每个迭代是一个schedule。
  3. 对于每个迭代(schedule),先用fence来进行拷贝stream & 设定依赖,然后使用 compute 来进行训练。这就顺序启动了多个 compute
  4. 在每个 compute 之中,遍历这个 schedule,对于其中 (i, j) 运行一个Task,即找到其device对应的in_queue,把Task插进去。
  5. Worker Thread 阻塞在 in_queue 之上,如果发现有内容,就读取 Task,运行。虽然多个 compute 是顺序执行,但是因为compute 只是一个插入queue操作,可以立即返回。而多个 worker Thread 阻塞在 queue 之上,这之后是可以并行训练的
  6. Worker Thread 把运行结果插入到 out_queue之中。
  7. compute 方法会取出 out_queue 之中的运行结果,进行后续处理。

具体如下图。

代码语言:javascript复制
          -------------------------------------------------------------------         ----------------------------------------- 
         | Pipeline                                                          |  1    | spawn_workers                           |
         |                                     spawn_workers(devices)   -----------> |                                         |
         |                                                                   |       |  -------------------------------------  |
         |               for schedule in clock_cycles(m, n)                  |       | | workers                             | |
         |                                                                   |       | |                                     | |
         |                     | 2                                           |       | |                                     | |
         |                     |                                             |       | |  device 1 : in_queue 1, out_queue 1 | |
         |                      ----------- ---------------                  |       | |                                     | |
         |                     |           |               |                 |       | |  device 2 : in_queue 2, out_queue 2 | |
         |                     v           v               v                 |       | |                                     | |
         |   ------------------ ------          ----------- --------------   |       | |  device 3 : in_queue 3, out_queue 3 | |
         |  | compute                 |        | compute                  |  |       | |                                     | |
         |  |                         |  3     |                          |  |       | |                                     | |
         |  |  in_queues[j].put(task) |        |   in_queues[j].put(task) |  |       |  -------------------------------------  |
         |  |                         | ...... |                          |  |       |                                         |
         |  |  out_queues[j].get()    |        |   out_queues[j].get()    |  |        ----------------------------------------- 
         |  |                         |        |                          |  |
         |   ---------- --- ----------          ---------------- ---- ----   |
         |             |   ^                                    ^    |       |
         |             |   |                                    |    |       |
          ------------------------------------------------------------------- 
                     7 |   | 4                                7 |    | 4
                       |   |                                    |    |
                       v   |                                    |    v
                  ----- --- ------------------------------------ ---- ----- 
                 |                in_queues        out_queues              |
 ------------>   |                                                         |  <-------------------- 
|                 ----- --------------------------------------------- -----                        |
| 6                    |                                             |                           6 |
|                    5 |                                             | 5                           |
|                      |                                             |                             |
|                      |                                             |                             |
|     -------------------------------------            -------------------------------------       |
|    | Thread 1        |        device 1   |          | Thread 2     |             device 3 |      |
|    |                 |                   |          |              |                      |      |
|    |  ---------------------------------  |          |  ---------------------------------  |      |
|    | | Worker        |                 | |          | | Worker     |                    | |      |
|    | |               v                 | |          | |            v                    | |      |
|    | |  task = in_queue.get()          | |          | |   task = in_queue.get()         | |      |
|    | |                                 | |  ......  | |                                 | |      |
|    | |  batch = task.compute()         | |          | |   batch = task.compute()        | |      |
|    | |                                 | |          | |                                 | |      |
 -------- out_queue.put((task, batch)))  | |          | |   out_queue.put((task, batch)) ---------> 
     | |                                 | |          | |                                 | |
     |  ---------------------------------  |          |  ---------------------------------  |
      -------------------------------------            ------------------------------------- 

手机如下:

0x02 并行拷贝和计算

我们接下来分析并行拷贝和计算(Concurrent Copy and Computation: Streams)。

2.1 GPU并行操作

我们首先看看 GPU 提供的并行操作功能。

CUDA流表示一个GPU操作队列,即某个设备绑定的,按照顺序执的核(kernel)序列。我们可以把一个流看作是GPU之上的一个任务。用户向流的队列上添加一系列操作,GPU会按照添加到流中的先后顺序而依次执行这一系列操作。在同一个流之中,所有操作是串行序列化,因此这些操作永远不会并行。因此,要想并行,两个操作必须位于不同的 stream 中。不同流中的核函数可以交错,甚至可能重叠。

几乎所有具有计算能力1.1及更高计算能力的CUDA设备都支持并发复制和执行,即设备重叠(Device Overlap)功能,其特点如下:

  1. 数据拷贝和数值计算可以并行。
  2. 两个方向的拷贝可以并行(GPU到CPU,CPU到GPU)。
  3. 进行数值计算的kernel不能读写正在拷贝的数据。

因为 CPU 内存一般来说是大于 GPU内存,因此不可能把 CPU 内存一次性都拷贝到GPU,需要分块传输。所以设备重叠功能就能够很好提高GPU程序的执行效率,比如:

  1. 将数据拆分成为许多块,每一块交给一个Stream来处理。
  2. 每一个Stream会进行如下操作:
    1. 将属于该Stream的数据从host内存拷贝进device内存,
    2. GPU进行 kernel 运算,将计算结果保存到GPU内存,
    3. 把 Stream计算结果从device 内存拷贝回host内存。
  3. GPU的scheduler决定 stream 如何并行。
  4. CPU 的操作也可以同时并行。

2.2 PyTorch

除非另有指定,PyTorch将每个绑定到设备的核函数发布到默认流。因为前向传播位于 default stream 中,所以要想并行处理 "下一个 batch 数据的预读取(拷贝CPU到GPU)" 和 "当前 batch 的前向传播",就必须做到:

  • cpu 上的 batch 数据 必须是 pinned。锁页可以使得硬件设备直接访问CPU内存,这样就减少了某些复制操作,锁定的页面不可以被交换到硬盘之上。在GPU上分配的内存默认都是锁页内存。
  • 预读取操作必须在另一个 stream 上进行。

Torchgpipe将每个拷贝核注册到非默认流中,同时将计算核保留在默认流中。这允许设备j处理

F_{i,j}

的同时也会发送

x^j_{i-1}

到设备

j 1

和/或 从设备

j-1

接受

x_i^{j-1}

此外,每个device对每个微批次使用不同的流。由于不同的微批次之间没有真正的依赖关系,因此流的这种使用是安全的,这允许尽可能快地进行拷贝。请参见下图。

图上表示的是设备 j 的时间线,是否使用非默认流进行复制

  • (a)部分的意思是:仅使用默认流,复制核可能会阻塞计算核(反之亦然),直到复制完全完成。
  • (b)部分的意思是:使用复制流,计算可以与从其他设备发送或接收数据同时进行。

2.3 Stream 封装

因为是对stream进行操作,所以 torchgpipe 对底层流操作进行了一些基础封装,流相关主要代码位于:torchgpipe/stream.py。

2.3.1 PyTorch 样例

因为 torchgpipe 用到了 wait_stream 和 record_stream,而网上相关资料较少,如果深入 CUDA 或者 PyTorch 相关部分又容易耗费太多时间,所以我们通过 torch/nn/parallel/distributed.py 中的代码来看看如何使用,可以看到。

  • wait_stream 起到等待作用:一个流等待另一个流完成。
  • record_stream 起到确保作用:保证张量内存在操作完成之前不会被重用。结合其他资料,我们可以理解为确保某一个流上记录的操作完成,才能进行下一步。

具体代码如下:

代码语言:javascript复制
# Perform CPU -> GPU copies in a background stream. This code is
# motivated from similar logic in torch/nn/parallel/_functions.py
stream = _get_stream(target_gpu)
with torch.cuda.stream(stream):
    output = obj.to(target_gpu) # 拷贝
# synchronize with the copy stream
with torch.cuda.device(target_gpu):
    current_stream = torch.cuda.current_stream()
    # Sync the current stream with the copy stream
    current_stream.wait_stream(stream) # 等待
    # Ensure tensor memory is not reused until work on main stream is complete
    output.record_stream(current_stream) # 确保
return (output,)
2.3.2 生成/获取

关于生成和获取的函数为:

  • new_stream 会生成一个新的stream。
  • current_stream 返回当前流。
  • default_stream 返回了缺省流。
代码语言:javascript复制
def new_stream(device: torch.device) -> AbstractStream:
    """Creates a new stream for either CPU or CUDA device."""
    if device.type != 'cuda':
        return CPUStream
    return torch.cuda.Stream(device)

def current_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.current_stream` for either CPU or CUDA device."""
    if device.type != 'cuda':
        return CPUStream
    return torch.cuda.current_stream(device)

def default_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.default_stream` for either CPU or CUDA device."""
    if device.type != 'cuda':
        return CPUStream
    return torch.cuda.default_stream(device)
2.3.3 记录

以下方法用来封装了CUDA record_stream。

代码语言:javascript复制
def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
    """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream."""
    if is_cuda(stream):
        # NOTE(sublee): record_stream() on a shifted view tensor throws
        # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
        # protect the tensor against unexpected reallocation, here we use a
        # temporal tensor associated with the same storage without shifting as
        # a workaround.
        #
        # Issue: https://github.com/pytorch/pytorch/issues/27366
        #
        tensor = tensor.new_empty([0]).set_(tensor.storage())

        tensor.record_stream(as_cuda(stream))
2.3.4 等待

以下方法封装了CUDA wait_stream 。

  • 如果两个流都是CUDA流,则就是一个流等待另外一个流完成。
  • 否则采用 synchronize() 来保证 CPU 等待 CUDA 完成。

因为这里流操作是异步的,所以当函数返回时候无法确定操作是否已经完成,所以将CPU和主机进行同步,或者CUDA流之间进行同步,以确保GPU完成流操作。

代码语言:javascript复制
def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
    """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It
    makes the source stream wait until the target stream completes work queued.
    """
    if is_cuda(target):
        if is_cuda(source):
            # A CUDA stream waits another CUDA stream.
            as_cuda(source).wait_stream(as_cuda(target))
        else:
            # CPU waits a CUDA stream.
            as_cuda(target).synchronize()

    # If the target is CPU, synchronization is not required.

这里wait_stream和synchronize最终都会完成等待操作,比如synchronize最终调用到了 cudaDeviceSynchronize,该方法将停止CPU端线程的执行,直到GPU端完成此前CUDA上的任务(包括kernel函数、数据拷贝等)。

既然已经把 Stream 操作进行了基础封装,torchgpipe 接下来就使用这些封装函数实现了拷贝操作和等待操作,我们接下来看看。

2.4 拷贝API

拷贝流的 API 如下,其实就是调用了 Copy 这个类的forward方法。

代码语言:javascript复制
def copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Copy.apply(prev_stream, next_stream, *batch)

Copy 拓展了torch.autograd.Function,主要就是应用record_stream来协助完成拷贝业务。

代码语言:javascript复制
class Copy(torch.autograd.Function):
    """Copies tensors on specific streams."""
    @staticmethod
    def forward(ctx: Context,  # type: ignore
                prev_stream: AbstractStream,
                next_stream: AbstractStream,
                *input: Tensor,
                ) -> Tensors:
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        output = []
        output_stream = current_stream(get_device(next_stream)) # 得到下一个流

        with use_stream(prev_stream), use_stream(next_stream):
            for x in input:
                y = x.to(get_device(next_stream)) # 把 input 拷贝到 next_stream
                output.append(y)

                # 'prev_stream' is not where 'x' has been allocated.
                record_stream(x, prev_stream) # 记录流,确保拷贝完成之前不会使用 x
                # 'y' has been allocated on 'next_stream'.
                # It might be used on the current stream captured as 'output_stream'.
                record_stream(y, output_stream) # 记录流,确保拷贝完成之前不会使用 y

        return tuple(output) # 返回输出

    @staticmethod
    def backward(ctx: Context,
                 *grad_output: Tensor,
                 ) -> Tuple[Optional[Tensor], ...]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        grad_input: Deque[Tensor] = deque(maxlen=len(grad_output))
        input_stream = current_stream(get_device(prev_stream))

        with use_stream(prev_stream), use_stream(next_stream):
            for x in reversed(grad_output):
                y = x.to(get_device(prev_stream))
                grad_input.appendleft(y)

                # 'next_stream' is not where 'x' has been allocated.
                record_stream(x, next_stream)
                # 'y' has been allocated on 'prev_stream'.
                # It might be used on the current stream captured as 'input_stream'.
                record_stream(y, input_stream)

        grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
        return grad_streams   tuple(grad_input)

2.5 等待API

wait 则是调用了 Wait 类的forward方法。

代码语言:javascript复制
def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Wait.apply(prev_stream, next_stream, *batch)

Wait 也拓展了torch.autograd.Function,就是应用wait_stream完成业务,一个流等待另外一个流完成。

代码语言:javascript复制
class Wait(torch.autograd.Function):
    """Synchronizes a stream to another stream.

    Place it just before you want to start an operation on the next stream,
    provided that all operations on the previous stream are done.

    """
    @staticmethod
    def forward(ctx: Context,  # type: ignore
                prev_stream: AbstractStream,
                next_stream: AbstractStream,
                *input: Tensor,
                ) -> Tensors:
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        wait_stream(next_stream, prev_stream)

        return tuple(x.detach() for x in input)

    @staticmethod
    def backward(ctx: Context,
                 *grad_input: Tensor,
                 ) -> Tuple[Optional[Tensor], ...]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        wait_stream(prev_stream, next_stream)

        grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
        return grad_streams   grad_input

2.6 使用

2.6.1 总体概念

我们先给出一个注释中的流程图,大家有一个整体概念。

代码语言:javascript复制
        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
2.6.2 构建拷贝流

在 GPipe 类中,生成了拷贝专用流。

代码语言:javascript复制
    def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore

        ......

        # Separate CUDA streams for copy.
        copy_streams = self._ensure_copy_streams() # 这里会生成拷贝转专用流

        # The micro-batch index where the checkpointing stops.

        # Run pipeline parallelism.
        pipeline = Pipeline(batches,
                            self.partitions,
                            self.devices,
                            copy_streams,
                            self._skip_layout,
                            checkpoint_stop)
        pipeline.run()

        ...... 

_ensure_copy_streams 代码如下,就是针对每一个设备的每一个macro-batch,都生成了一个专用流:

代码语言:javascript复制
    def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
        """Ensures that :class:`GPipe` caches CUDA streams for copy.

        It's worth to cache CUDA streams although PyTorch already manages a
        pool of pre-allocated CUDA streams, because it may reduce GPU memory
        fragementation when the number of micro-batches is small.

        """
        if not self._copy_streams:
            for device in self.devices:
                self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])

        return self._copy_streams

假设有3个devices,模型被分成3个子网络,小批次被分割成 4个微批次。则具体如下:

就是说 _copy_streams[i][j] 之中,i 表示 device 的序列,j 表示 batch 序列。这个顺序比较重要,马上会提到。

代码语言:javascript复制
                   ---------------------------------- 
                  | _copy_streams                    |
                  |                                  |
                  |      ----------------------      |
                  |     |                      |     |
                  |     |  [1,1] [1,2] [1,3] -------------------------------- 
                  |     |                      |     |                       |
                  |     |  [2,1] [2,2] [2,3] ------------------------------------------ 
                  |     |                      |     |                       |         |
 ------------------------- [3,1] [3,2] [3,3]   |     |                       |         |
|                 |     |                      |     |                       |         |
|                 |      ----------------------      |                       |         |
|                 |                                  |                       |         |
|                  ----------------------------------                        |         |
|                                                                            |         |
|                                                                            v         |
|    ------------------------------------------------------------------------ ------   |
|   | Stream of device 1, Stream of device 1, Stream of device 1, Stream of device 1|  |
|    -------------------------------------------------------------------------------   |
|                                                                                      |
|    -------------------------------------------------------------------------------   |
|   | Stream of device 2, Stream of device 2, Stream of device 2, Stream of device 2 <- 
|    ------------------------------------------------------------------------------- 
|
|    ------------------------------------------------------------------------------- 
 -->  Stream of device 3, Stream of device 3, Stream of device 3, Stream of device 3|
     ------------------------------------------------------------------------------- 
2.6.3 并行操作

我们以 实例看看如何并行操作,具体要看下面 stream 的使用。

Pipeline 类的 run 方法中,有如下代码保证并行操作:

代码语言:javascript复制
def run(self) -> None:
    with spawn_workers(devices) as (in_queues, out_queues):
        for schedule in clock_cycles(m, n):
            self.fence(schedule, skip_trackers)
            self.compute(schedule, skip_trackers, in_queues, out_queues)

每次计算之前,都会用 fence 方法来把数据从前一个设备拷贝到后一个设备。

2.6.4 预先拷贝

fence 方法做了预先拷贝操作,其中会做如下操作:

  • 设定依赖关系,这个我们在前文中分析过。
  • 得到下一个设备的拷贝流。
  • 得到上一个设备的拷贝流。
  • 拷贝前面流到后续流。
代码语言:javascript复制
    def fence(self,
              schedule: List[Tuple[int, int]],
              skip_trackers: List[SkipTrackerThroughPotals],
              ) -> None:
        """Copies micro-batches after computation for the previous
        micro-batches.
        """
        batches = self.batches
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0:
                depend(batches[i-1], batches[i]) # 设定依赖关系

            next_stream = copy_streams[j][i] # 得到下一个设备的拷贝流,注意,这里和for的i,j相反

            for prev_j, ns, name in skip_layout.copy_policy(j): # 因为篇幅原因,我们不分析这部分
                prev_stream = copy_streams[prev_j][i] # 拷贝前面流到后续流
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0: # 
                prev_stream = copy_streams[j-1][i] # 得到上一个设备的拷贝流
                copy(batches[i], prev_stream, next_stream) # 拷贝前面流到后续流

我们按照之前文章的例子来看看,下面是一个schedule 生成序列。

代码语言:javascript复制
m=4 # m: number of micro-batches
n=3 # n: number of partitions
for k in range(m   n - 1):
    print( [(k - j   1 , j  1 ) for j in range(max(1   k - m, 0), min(1   k, n))] )

打印是:
[(1, 1)]                  # 第 1 轮训练计划 & 数据
[(2, 1), (1, 2)]          # 第 2 轮训练计划 & 数据
[(3, 1), (2, 2), (1, 3)]  # 第 3 轮训练计划 & 数据
[(4, 1), (3, 2), (2, 3)]  # 第 4 轮训练计划 & 数据
[(4, 2), (3, 3)]          # 第 5 轮训练计划 & 数据
[(4, 3)]                  # 第 6 轮训练计划 & 数据

前 6 个周期对应了如下时间流,第一个时钟周期 (1,1) 进入系统,第二个周期 (2,1) 进入系统 .....

代码语言:javascript复制
                                                                              
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
 cuda:0    |  (1,1)   |   (2,1)  |  (3,1)   |   (4,1)  |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
 cuda:1    |          |   (1,2)  |  (2,2)   |   (3,2)  |  (4,2)   |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
 cuda:2    |          |          |  (1,3)   |   (2,3)  |  (3,3)   |  (4,3)   |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           |          |          |          |          |          |          |
           | clock 1  |  clock 2 |  clock 3 |  clock 4 |  clock 5 |  clock 6 |
                                                                              

 ------------------------------------------------------------------------------>  Time

我们以如下计划看看,重点是第 3 个时钟周期完成的任务。

第 2 个时钟周期完成了如下操作。

代码语言:javascript复制
[(2, 1), (1, 2)]         # 第 2 轮训练计划 & 数据

第 3 个时钟周期的计划如下:

代码语言:javascript复制
[(3, 1), (2, 2), (1, 3)] # 第 3 轮训练计划 & 数据

就是对 schedule 的每个 i, j,都分别拷贝 copy_streams[j-1][i]copy_streams[j][i]

注意 我们之前的提到的,_copy_streams[i][j] 之中,i 表示 device 的序列,j 表示 batch 序列,和schedule 的 i,j 恰好相反。

所以对于我们例子,在第 3 个时钟周期内的拷贝操作是 (这里 i 和 j 在循环和后续数组提取时候是相反,这个恰好和schedule对应,于是负负得正,最终 i, j 可以对应上):

  • 对于 (3, 1),这个是新数据进入了 device 1,不需要拷贝。
  • 对于 (2, 2),拷贝是 (2,1) 到 (2,2)。
  • 对于 (1, 3),拷贝是 (1,2) 到 (1,3)。

具体如下图所示,这几个拷贝可以并行操作,因为拷贝流不是运行计算的缺省流,所以也可以和计算并行

代码语言:javascript复制
                                                                                           
         |             |            |             |            |            |             |
 cuda:0  |    (1,1)    |   (2,1)    |   (3,1)     |   (4,1)    |            |             |
         |             |            |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |      ------------        |            |            |             |
         |             |            |     |       |            |            |             |
         |             |            |     |       |            |            |             |
         |             |            |     |       |            |            |             |
         |             |            |     v       |            |            |             |
         |             |            |             |            |            |             |
 cuda:1  |             |   (1,2)    |   (2,2)     |   (3,2)    |  (4,2)     |             |
         |             |            |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |     |      |             |            |            |             |
         |             |      -----------         |            |            |             |
         |             |            |    |        |            |            |             |
         |             |            |    |        |            |            |             |
         |             |            |    |        |            |            |             |
         |             |            |    v        |            |            |             |
 cuda:2  |             |            |   (1,3)     |   (2,3)    |  (3,3)     |     (4,3)   |
         |             |            |             |            |            |             |
         |             |            |             |            |            |             |
         |             |            |             |            |            |             |
         |   clock 1   |  clock 2   |   clock 3   |  clock 4   |  clock 5   |     clock 6 |
                                                                                           

 ----------------------------------------------------------------------------------->  Time
2.6.5 计算

compute 完成了如下步骤:

  • 使用 wait(batch, copy_streams[j][i], streams[j]) "拷贝流"同步到"计算流",确保拷贝操作完成。
  • 其次进行计算。
  • 使用 wait(batch, streams[j], copy_streams[j][i]) 把计算结果从"计算流"同步到"拷贝流",确保计算操作完成。

具体如下:

代码语言:javascript复制
    def compute(self,
                schedule: List[Tuple[int, int]],
                skip_trackers: List[SkipTrackerThroughPotals],
                in_queues: List[InQueue],
                out_queues: List[OutQueue],
                ) -> None:
        """Runs tasks with synchronization to copy streams."""
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                wait(batch, copy_streams[j][i], streams[j])

            # Determine whether checkpointing or not.
            checkpoint = (i < checkpoint_stop)
            if checkpoint:
                def function(input: TensorOrTensors,
                             partition: nn.Sequential = partition,
                             skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                             ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker):
                        return partition(input)

                chk = Checkpointing(function, batch)
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk

            else:
                def compute(batch: Batch = batch,
                            partition: nn.Sequential = partition,
                            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                            ) -> Batch:
                    with use_skip_tracker(skip_tracker):
                        return batch.call(partition)

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            in_queues[j].put(task)

        # 这里进行了同步操作    
        for i, j in schedule:
            ok, payload = out_queues[j].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n-1:
                wait(batch, streams[j], copy_streams[j][i]) # 这里有同步

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

针对论文就是:

0x03 重计算

我们接下来看看重计算,在论文中是 Autograd Functions with Shared Memory 这节。

因为之前在 GPipe 之中我们介绍过类似部分,所以这里只是为了行文完整性而加入,故此分析较略。

3.1 解析

到目前为止,在本节中,我们没有讨论在使用梯度检查点时,如何安排重新计算任务

F^{'}_{i,j}

。当使用 checkpointing,那么它必须在反向传播任务

B_{i,j}

之前 和 完成

B_{i 1,j}

之后被调度。这就要求必须在autograd引擎和在计算图中对其进行编码。PyTorch通过实现检查点的内部 autograd 方法来支持此类功能。

PyTorch中的检查点是通过定义一个autograd函数来实现的,该函数像普通函数一样计算,即进行前向传播,不存储中间激活值,而是存储输入。在向后传递中,此函数通过使用存储的输入重新计算此函数来构造后向传播的局部计算图,并通过在局部图中反向传播来计算梯度。

然而,这把

F^{'}_{i,j}

B_{i,j}

紧密地结合在一起,我们希望在

F^{'}_{i,j}

B_{i,j}

中间插入一些指令,这些指令实现了一个等待操作,等待把

B_{i,j 1}

结果

dx^j_j

从设备

j 1

拷贝到设备

j

。这样可以允许

F^{'}_{i,j}

和复制同时发生。

对于这种细粒度的顺序控制,torchgpipe把checkpointing 操作改为使用两个单独的autograd函数Checkpoint和Recompute来实现。在任务

F^{'}_{i,j}

的执行时间之内,生成一对具有共享内存的Checkpoint和Recompute。该共享内存在向后传播中被使用,用于将通过执行Recompute生成的本地计算图传输到Checkpoint来进行反向传播。

通过安排这些函数,在每次后向传播之中,会做:

F^{'}_{i,j}

  • 一个同步操作,用来接受
dx^j_j

B_{i,j}

这三个操作可以按顺序执行,就能确保可以同时进行重新计算和复制。

我们可以通过源码来看看。

3.2 封装API

torchgpipe/checkpoint.py 之中有一个 checkpoint 方法,这是对外提供了一个简单API。

代码语言:javascript复制
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
    """Makes a checkpoint with a simple interface like
    :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
    :class:`Checkpoint` and :class:`Recompute` without boilerplate.
    """
    batch = Batch(input)

    chk = Checkpointing(function, batch)
    batch = chk.checkpoint()
    chk.recompute(batch)

    return batch.tensor_or_tensors

具体使用参见tests/test_checkpoint.py。其通过log的巧妙打印,可以让我们看出来运行时候,checkpoint在前向后向传播之中的使用。

timeline 最后结果是 ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"],

其中两两一组,分别对应了 forward pass ,Checkpoint(Log[b]),Checkpoint(Log[a])。

代码语言:javascript复制
@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []

    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    _ = a   1   2   3   4   5

    # 这里意味着最后 backward 实际会运行"a:forward", "a:backward"
    a = checkpoint(partial(Log.apply, "a"), a)

    a, phony = fork(a)
    b = join(b, phony)

    # 这里意味着最后 backward 实际会运行"b:forward", "b:backward"
    b = checkpoint(partial(Log.apply, "b"), b)

    c = torch.cat((a, b))

    out = c.sum()

    #                         --> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^----------------------------- 
    #                         --> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
    #    |----------------------|  |-----------------------|  |-----------------------|
    #          forward pass            Checkpoint(Log[b])         Checkpoint(Log[a])

checkpoint API 调用了 Checkpointing,所以我们看看其实现。

其实现是提供了 checkpoint 和 recompute 两个方法。分别调用了两个类。

代码语言:javascript复制
class Checkpointing:
    """Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""

    def __init__(self, function: Function, batch: Batch) -> None:
        self.function = function
        self.batch = batch

        # Shared memory between Checkpoint and Recompute. 1-length deque is
        # used for mutability and length limitation.
        self.recomputed: Deque[Recomputed] = deque(maxlen=1)
        self.rng_states: Deque[RNGStates] = deque(maxlen=1)

    def checkpoint(self) -> Batch:
        """Returns a batch applied by :class:`Checkpoint`."""
        input_atomic = self.batch.atomic
        input = tuple(self.batch)

        # Use a phony which requires grad to ensure that Checkpoint can be
        # tracked by the autograd engine even when none of the input tensors
        # require grad.
        phony = get_phony(self.batch[0].device, requires_grad=True)

        output = Checkpoint.apply(phony, self.recomputed, self.rng_states,
                                  self.function, input_atomic, *input)
        return Batch(output)

    def recompute(self, batch: Batch) -> None:
        """Applies :class:`Recompute` to the batch in place."""
        input_atomic = self.batch.atomic
        input = tuple(self.batch)

        # batch[0] is always requiring grad, because it has been passed
        # checkpoint with a phony requiring grad.
        batch[0], phony = fork(batch[0])
        phony = Recompute.apply(phony, self.recomputed, self.rng_states,
                                self.function, input_atomic, *input)
        batch[0] = join(batch[0], phony)

3.3 实现

Checkpoint 和下面的 Recompute 就是把普通模式下的 checkpoint 代码分离成两个阶段(forward函数被分成两段,backward 函数也被分成两段),从而可以更好的利用流水线。

对应论文中就是:

我们希望在

F^{'}_{i,j}

B_{i,j}

中间插入一些指令,这些指令实现了一个等待操作,等待把

B_{i,j 1}

结果

dx^j_j

从设备

j 1

拷贝到设备

j

。这样可以允许

F^{'}_{i,j}

和复制同时发生。

对于这种细粒度的顺序控制,torchgpipe把checkpointing 操作改为使用两个单独的autograd函数Checkpoint和Recompute来实现。在任务

F^{'}_{i,j}

的执行时间之内,生成一对具有共享内存的Checkpoint和Recompute。该共享内存在向后传播中被使用,用于将通过执行Recompute生成的本地计算图传输到Checkpoint来进行反向传播。

3.3.1 Checkpoint
代码语言:javascript复制
class Checkpoint(torch.autograd.Function):
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> TensorOrTensors:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        # 存RNG状态
        save_rng_states(input[0].device, ctx.rng_states)

        ctx.function = function
        ctx.input_atomic = input_atomic
        # 为BP做准备,其实目前没有实现
        ctx.save_for_backward(*input)

        # 进行前向计算
        with torch.no_grad(), enable_checkpointing():
            output = function(input[0] if input_atomic else input)

        return output

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:  # pragma: no cover
        # 从保存的重计算变量中弹出所需变量
        output, input_leaf = ctx.recomputed.pop() 

        if isinstance(output, tuple):
            tensors = output
        else:
            tensors = (output,)
            
        if any(y.requires_grad for y in tensors):
            tensors = tuple([x for x in tensors if x.requires_grad])
            # 进行自动微分
            torch.autograd.backward(tensors, grad_output)

        grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
        grad_input.extend(x.grad for x in input_leaf)
        return tuple(grad_input)
3.3.2 Recompute

Recompute 就是依据保存的信息,重新计算中间变量。

代码语言:javascript复制
class Recompute(torch.autograd.Function):
  
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> Tensor:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        ctx.function = function
        ctx.input_atomic = input_atomic
        ctx.save_for_backward(*input)

        return phony

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:  
        input = ctx.saved_tensors
        input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)

        # 取出保存的RNG状态,进行前向计算,得到中间变量
        with restore_rng_states(input[0].device, ctx.rng_states):
            with torch.enable_grad(), enable_recomputing():
                output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)

        # 保存变量,为Checkpoint使用
        ctx.recomputed.append((output, input_leaf))

        grad_input: List[None] = [None, None, None, None, None]
        grad_input.extend(None for _ in ctx.saved_tensors)
        return tuple(grad_input)

3.4 总体调用

总体调用代码如下:

代码语言:javascript复制
    def compute(self,
                schedule: List[Tuple[int, int]],
                skip_trackers: List[SkipTrackerThroughPotals],
                in_queues: List[InQueue],
                out_queues: List[OutQueue],
                ) -> None:
        """Runs tasks with synchronization to copy streams."""
        batches = self.batches
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                wait(batch, copy_streams[j][i], streams[j])

            # Determine whether checkpointing or not.
            checkpoint = (i < checkpoint_stop)
            if checkpoint:
                def function(input: TensorOrTensors,
                             partition: nn.Sequential = partition,
                             skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                             ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker):
                        return partition(input)

                chk = Checkpointing(function, batch)
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk

            else:
                def compute(batch: Batch = batch,
                            partition: nn.Sequential = partition,
                            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                            ) -> Batch:
                    with use_skip_tracker(skip_tracker):
                        return batch.call(partition)

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            in_queues[j].put(task)

        for i, j in schedule:
            ok, payload = out_queues[j].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n-1:
                wait(batch, streams[j], copy_streams[j][i])

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

至此,PyTorch 流水线并行分析完毕,我们接下来的计划是把PyTorch 并行训练再系统梳理一下,首先需要分析其梯度相关基础知识,敬请期待。

0xFF 参考

Markdown公式用法大全

markdown中公式编辑教程

https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html#stream-sync-behavior

CUDA学习:基础知识小结

CUDA随笔之Stream的使用

NVIDIA解决方案架构师深度解析大规模参数语言模型Megatron-BERT

Accelerating Wide & Deep Recommender Inference on GPUs

HugeCTR: High-Performance Click-Through Rate Estimation Training

https://discuss.pytorch.org/t/how-to-prefetch-data-when-processing-with-gpu/548

https://github.com/NVIDIA/apex/

https://github.com/justheuristic/prefetch_generator

https://pytorch.org/tutorials/intermediate/model_parallel_turotial.html

https://pytorch.org/docs/stable/autograd.html

https://pytorch.org/docs/notes/cuda.html

https://zhuanlan.zhihu.com/p/61765561

https://pytorch.apachen.org/docs/1.7/64.html

https://zhidx.com/p/217999.html

0 人点赞