[源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

2021-08-30 12:24:44 浏览数 (1)

源码解析 深度学习流水线并行GPipe (2) ----- 梯度累积

目录

  • [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积
    • 0x00 摘要
    • 0x01 概述
      • 1.1 前文回顾
    • 0x02 基本概念
      • 2.1 背景知识
      • 2.2 产生原因
      • 2.3 本质
      • 2.4 VS 数据并行
      • 2.5 解决问题
    • 0x03 PyTorch 梯度累积
      • 3.1 自动累积
      • 3.2 代码示例
      • 3.3 DistributedDataParallel 的梯度累积
        • 3.3.1 单卡模型梯度累计
        • 3.3.2 DDP如何加速
        • 3.3.3 no_sync实现
    • 0x04 Tensorflow实现
    • 0x05 Gpipe实现
      • 5.1 优化器
      • 5.2 包装器
      • 5.3 应用
    • 0xFF 参考

0x00 摘要

梯度累积是一种增大训练时 batch size的技术,在本地使用 micro-batch 多次进行正向和反向传播积累梯度后,再进行梯度规约和优化器更新,这是用来均摊通信成本的一种常用策略。本文通过几个框架/库的实现对比,让大家对这个技术有进一步的了解。

本系列其他文章如下:

[源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现

0x01 概述

1.1 前文回顾

前文提到,目前分布式模型训练有几个必要并行技术:

  • 流水并行,尤其是如何自动设定流水;
  • 梯度累加(Gradient Accumulation);
  • 后向重计算;
  • 1F1B 策略(我们将采用PipeDream分析);

在前文中,我们介绍了Gpipe如何实施流水线并行技术。本文我们介绍梯度累加(Gradient Accumulation)。

0x02 基本概念

梯度累积是一种用来均摊通信成本的一种常用策略。它在本地使用 micro-batch 多次进行正向和反向传播积累梯度后,再进行梯度规约和优化器更新,相当于扩大了N倍的batch size。

2.1 背景知识

深度学习模型由许多相互连接的层组成,样本在这些层中进行传播,具体传播包含两个过程:前向(forward)过程与反向(backword)过程。

  • 前向过程是从输入计算得到输出。样本在每一步都通过前向传播进行传播,在通过所有层传播后,网络为样本生成预测,然后计算每个样本的损失值,损失值意味着 “对于这个样本,本网络错了多少?”。
  • 然后就是反向过程。神经网络在此过程中计算这些损失值相对于模型参数的梯度。可以认为着就是一个梯度累积的过程。
  • 最后,这些梯度用于计算各个模型参数的更新。

训练中,每个样本的大小由超参数batch size指定,此参数的大小会对最终的模型效果产生很大的影响。一定条件下,batch size设置的越大,模型就会越稳定。

2.2 产生原因

累加梯度顾名思义就是累加后的梯度值。为什么要累加呢?因为运行内存不够用。

在训练模型时,如果一次性将所有训练数据输入到模型,经常会造成内存不足,这时候就需要把一个大 Batch 拆分成若干小批次数据(专业术语为mini-batch)。分成小批次后,带来一个问题,那就是本来应该是所有数据全部送入后计算梯度再更新参数,现在成了每个小批次都要计算梯度更新参数,为了不这么频繁计算梯度,于是就引入了累加梯度。也就是说:

  • 将整个dataset分成多个batch;
  • 分别将每个batch分成多个小批次,将每个小批次喂给神经网络;
  • 每个小批次虽然计算梯度,但是在每次反向传播后,先不进行优化器的迭代更新。
  • 经过若干个小批次后(即一个batch中的所有小批次),用每个小批次计算的梯度的累积和去进行优化器迭代更新参数、梯度清零的操作。

这样就跟把全部数据一次性送入模型进行训练效果一样了。

2.3 本质

梯度累加本质上就是累加 accumulation_stepsbatch_size/accumulation_steps 的梯度, 再根据累加的梯度来更新网络参数,以达到真实梯度类似batch_size 的效果。在使用时,需要注意适当的扩大学习率。

也就是说:

  • 首先将整个dataset分成多个batch,每个 batch size = 32,且假定 accumulation steps = 8
  • 因为 batch size = 32 ,太大了,单机显卡无法跑,于是我们在前向传播的时候以 batch_size = 32 / 8 = 4 来计算梯度;
  • 这样就再分别将每个batch分成多个batch size 为 4 的小批次,将每个小批次逐一喂给神经网络;
  • 每个小批次虽然计算梯度,但是在每次反向传播(在反向传播的时候,会将mean_loss也除以8)后,先不进行优化器的迭代更新。
  • 经过 accumulation steps 个小批次后(即一个batch中的所有小批次),用每个小批次计算梯度的累积和去进行优化器迭代更新参数。
  • 最后进行梯度清零的操作。
  • 处理下一个batch。

这样就跟把 32 batch size 一次性送入模型进行训练效果一样了。

具体如下,时间轴是由左自右:

代码语言:javascript复制
                                      ------------------- 
                                     |    GLOBAL BATCH    -------------------------- 
                                      -------------------                           |
                                                                                    |
                                                                                    |
  <--------------------------------------------------------------------------------- 
 |
 |
 |     --------------       --------------       --------------       -------------- 
  --> | MINI BATCH 0  ---->  MINI BATCH 1  ---->  MINI BATCH 2  ---->  MINI BATCH 3 |
       ----- --------       ------- ------       ------ -------       ------- ------ 
            |                      |                   |                     |
            |                      |                   |                     |
            |                      |                   |                     |
            v                      v                   v                     v
        ---- -----            ----- -----         ----- -----            ---- ----- 
       |  grad 0  |          |  grad 1   |       |  grad 2   |          |  grad 3  |
        ---- -----            ----- -----         ----- -----            ---- ----- 
            |                      |                   |                     |
            |                      |                   |                     |
            |                      |                   |                     |
            v                      v                   v                     v
      ------ ---------------------- ------------------- --------------------- ------ 
     |                                                                              |
     |                              GLOBAL BATCHGRADIENTS                           |
     |                                                                              |
      ------------------------------------------------------------------------------ 


 ------------------------------------------------------------------------------------>
                                                                        Time

2.4 VS 数据并行

micro-batch 跟数据并行有高度的相似性:

  • 数据并行是空间上的,数据被拆分成多个 tensor,同时喂给多个设备并行计算,然后将梯度累加在一起更新。
  • micro-batch 是时间上的数据并行,数据被拆分成多个 tensor,这些 tensor 按照时序依次进入同一个设备串行计算,然后将梯度累加在一起更新。

当总的 batch size 一致,且数据并行的并行度和 micro-batch 的累加次数相等时,数据并行和 Gradient Accumulation 在数学上完全等价。

Gradient Accumulation 通过多个 micro-batch的梯度累加使得下一个 micro-batch 的前向计算不需要依赖上一个 micro-batch 的反向计算,因此可以畅通无阻的进行下去(当然在一个大 batch 的最后一次 micro-batch 还是会触发这个依赖)。

2.5 解决问题

Gradient Accumulation 解决了很多问题:

  • 在单卡下,Gradient Accumulation 可以将一个大的 batch size 拆分成等价的多个小 micro-batch ,从而达到节省显存的目的。
  • 在数据并行下,Gradient Accumulation 解决了反向梯度同步开销占比过大的问题(随着机器数和设备数的增加,梯度的 AllReduce 同步开销也加大),因为梯度同步变成了一个稀疏操作,因此可以提升数据并行的加速比。
  • 在流水线并行下, Gradient Accumulation 使得不同 stage 之间可以并行执行不同的 micro-batch,通过多个 micro-batch的梯度累加使得下一个 micro-batch 的前向计算不需要依赖上一个 micro-batch 的反向计算,因此从而让各个阶段的计算不阻塞,可以畅通无阻的进行下去(当然在一个大 batch 的最后一次 micro-batch 还是会触发这个依赖), 达到流水线的目的。

0x03 PyTorch 梯度累积

3.1 自动累积

PyTorch默认会对梯度进行累加。即,PyTorch会在每一次backward()后进行梯度计算,但是梯度不会自动归零,如果不进行手动归零的话,梯度会不断累加.

至于为什么PyTorch有这样的特点,https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/9 这里给出了一个解释。我们结合其他的解释大致得出如下:

  • 从PyTorch的设计原理上来说,在每次进行前向计算得到预测值时,会产生一个用于梯度回传的计算图,这张图储存了进行反向传播需要的中间结果,当调用了.backward()后,会从内存中将这张图进行释放。
  • 利用梯度累加,可以在最多保存一张计算图的情况下进行多任务的训练。在多任务中,对前面共享的张量进行了多次计算操作后,调用不同任务的backward(),那些张量的梯度会自动累加。
  • 另外一个理由就是在内存大小不够的情况下叠加多个batch的grad作为一个大batch进行迭代,因为二者得到的梯度是等价的。
  • 由于PyTorch的动态图和autograd机制,导致并没有一个确切的点知道何时停止前向操作,因为你不知道什么时候一个计算会结束以及什么时候又会有一个新的开始。所以自动设置梯度为 0 比较棘手。

3.2 代码示例

下面给出一个传统代码示例:

代码语言:javascript复制
for i,(images,target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs,target)
  
    # 2. backward
    optimizer.zero_grad()   # reset gradient
    loss.backward()
    optimizer.step()

然后给出一个梯度累积示例:

  • 获取loss: 输入图像和标签,通过计算得到预测值,计算损失函数;
  • loss.backward()反向传播,计算当前梯度;
  • 多次循环步骤 1-2, 不清空梯度,使梯度累加在已有梯度上;
  • 梯度累加一定次数后,先optimizer.step()根据累积的梯度更新网络参数,然后optimizer.zero_grad()清空过往梯度,为下一波梯度累加做准备;
代码语言:javascript复制
for i, (images, target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images) # 前向传播
    loss = criterion(outputs, target) # 计算损失

    # 2. backward
    loss.backward() # 反向传播,计算当前梯度
    
     # 3. update parameters of net
    if ((i 1)�cumulation)==0:
        # optimizer the net
        optimizer.step() # 更新网络参数
        optimizer.zero_grad() # reset grdient # 清空过往梯度

3.3 DistributedDataParallel 的梯度累积

DistributedDataParallel(DDP)在module级别实现数据并行性。其使用torch.distributed包communication collectives来同步梯度,参数和缓冲区。并行性在单个进程内部和跨进程均有用。

在这种情况下,虽然gradient accumulation 也一样可以应用,但是为了提高效率,需要做相应的调整。

3.3.1 单卡模型梯度累计

我们首先回忆单卡模型,即普通情况下如何进行梯度累加。

代码语言:javascript复制
# 单卡模式,即普通情况下的梯度累加
for data in enumerate(train_loader # 每次梯度累加循环
    optimizer.zero_grad()
    for _ in range(K):
        prediction = model(data / K)
        loss = loss_fn(prediction, label) / K
        loss.backward()  # 积累梯度,不应用梯度改变,执行K次
    optimizer.step()  # 应用梯度更新,更新网络参数,执行一次

在 loss.backward() 语句处,DDP会进行梯度规约 all_reduce。

因为每次梯度累加循环之中有K个步骤,所以有K次 all_reduce。但实际上,每次梯度累加循环中,optimizer.step()只有一次,这意味着我们这K次 loss.backward() 之中,其实只进行一次 all_reduce 即可,前面 K - 1 次 all_reduce 是没有用的

3.3.2 DDP如何加速

于是我们就思考,是否可以在 loss.backward() 之中有一个开关,使得我们在前面K-1次 loss.backward() 之中只做反向传播,不做梯度同步(累积)。

DDP 已经想到了这个问题,它提供了一个暂时取消梯度同步的context函数 no_sync()。在no_sync()context之下,DDP不会进行梯度同步。但是在no_sync()上下文结束之后的第一次 forward-backward 会进行同步。

最终代码如下:

代码语言:javascript复制
model = DDP(model)

for data in enumerate(train_loader # 每次梯度累加循环
    optimizer.zero_grad()
    
    for _ in range(K-1):# 前K-1个step 不进行梯度同步(累积梯度)。
        with model.no_sync(): # 这里实施“不操作”
            prediction = model(data / K)
            loss = loss_fn(prediction, label) / K
            loss.backward()  # 积累梯度,不应用梯度改变
    
    prediction = model(data / K)
    loss = loss_fn(prediction, label) / K 
    loss.backward()  # 第K个step 进行梯度同步(累积梯度)
    optimizer.step() # 应用梯度更新,更新网络参数  
3.3.3 no_sync实现

no_sync 的代码如下:

代码语言:javascript复制
    @contextmanager
    def no_sync(self):
        r"""
        A context manager to disable gradient synchronizations across DDP
        processes. Within this context, gradients will be accumulated on module
        variables, which will later be synchronized in the first
        forward-backward pass exiting the context.

        Example::

            >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
            >>> with ddp.no_sync():
            >>>   for input in inputs:
            >>>     ddp(input).backward()  # no synchronization, accumulate grads
            >>> ddp(another_input).backward()  # synchronize grads
        """
        old_require_backward_grad_sync = self.require_backward_grad_sync
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            self.require_backward_grad_sync = old_require_backward_grad_sync

具体如何使用?我们在 DistributedDataParallel 的 forward 方法中可以看到,只有在 require_backward_grad_sync 为 True时候,才会调用reducer.prepare_for_forward() 和 reducer.prepare_for_backward,才会把require_forward_param_sync 设置为 True。

代码语言:javascript复制
   def forward(self, *inputs, **kwargs):
    
        with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
            
            self.reducer.save_thread_local_state()
            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                # True时候才会进入
                self.logger.set_runtime_stats_and_log()
                self.num_iterations  = 1
                self.reducer.prepare_for_forward()
            
            # 省略部分代码

            if torch.is_grad_enabled() and self.require_backward_grad_sync:
                # True时候才会进入
                self.require_forward_param_sync = True
                if self.find_unused_parameters and not self.static_graph:
                    # Do not need to populate this for static graph.
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
            else:
                self.require_forward_param_sync = False

			# 省略部分代码

再看看 Reducer的两个方法。

prepare_for_forward 只是做统计工作,可以忽略。

代码语言:javascript复制
void Reducer::prepare_for_forward() {
  std::lock_guard<std::mutex> lock(mutex_);
  num_iterations_  ;
  if (should_collect_runtime_stats()) {
    record_forward_compute_start_time();
  }
}

prepare_for_backward 会做重置和预备工作,与梯度累积相关的是 expect_autograd_hooks_ = true

代码语言:javascript复制
void Reducer::prepare_for_backward(
    const std::vector<torch::autograd::Variable>& outputs) {
  std::lock_guard<std::mutex> lock(mutex_);

  // Reset accounting.
  expect_autograd_hooks_ = true; // 这里是关键
  reset_bucket_counting();

  // Reset unused parameter accounting.
  has_marked_unused_parameters_ = false;
  // Reset per iteration marked ready parameters.
  perIterationReadyParams_.clear();

  // If static graph is not set, search graph to detect unused parameters.
  // When static graph is set, unused_parameters_ will be detected and will
  // not change after 1st iteration.
  // If static_graph_ = false and find_unused_parameters_ is false,
  // we assume that autograd hooks for ALL variables will be called,
  // and we don't have to search the autograd graph for presence of these hooks.
  if (dynamic_graph_find_unused()) {
    unused_parameters_.clear();
    search_unused_parameters(outputs);
  }
}

expect_autograd_hooks_ = true 如何使用?在 Reducer::autograd_hook 之中有,如果不需要进行all-reduce操作,则直接返回。

代码语言:javascript复制
void Reducer::autograd_hook(VariableIndex index) {
    
  std::lock_guard<std::mutex> lock(this->mutex_);

  // Carry over thread local state from main thread. This allows for
  // thread-local flags such as profiler enabled to be configure correctly.
  at::ThreadLocalStateGuard g(thread_local_state_);

  // Ignore if we don't expect to be called.
  // This may be the case if the user wants to accumulate gradients
  // for number of iterations before reducing them.
  if (!expect_autograd_hooks_) { // 如果不需要进行all-reduce操作,则直接返回。
    return;
  }

  // 省略后续代码

有点绕,我们梳理一下:

一个 step 有两个操作:forward 和 backward。

  • forward 操作时候 :require_backward_grad_sync = True 意味着 forward 时候
    • 设置 require_forward_param_sync = True。
    • 会调用reducer.prepare_for_forward() 和 reducer.prepare_for_backward
    • reducer.prepare_for_backward 意味着会设置 expect_autograd_hooks_ = true,expect_autograd_hooks_是关键。
  • backward 操作时候
    • expect_autograd_hooks_ = true 意味着 backward 时候进行 进行all-reduce操作。
    • 否则直接返回,不做 all-reduce操作。

即如下图,

  • 上半部分是 forward 的逻辑,就是 forward()函数,
  • 下半部分是 backward 逻辑,就是 Reducer::autograd_hook() 函数。
  • expect_autograd_hooks_ 是forward 和 backward 之间串联的关键之处。
代码语言:javascript复制
forward
 --------------------------------------------------------------------------------- 
| forward()                                                                       |
|                                                                                 |
|                require_backward_grad_sync == True??  ---------                  |
|                                                               |                 |
|                             |                                 |                 |
|                             | Yes                             |                 |
|                             |                                 | No              |
|                             v                                 |                 |
|                 reducer.prepare_for_forward                   |                 |
|                                                               |                 |
|                             |                                 |                 |
|                             |                                 |                 |
|                             v                                 |                 |
|                 reducer.prepare_for_backward                  |                 |
|                                                               |                 |
|                             |                                 |                 |
|                             |                                 |                 |
|                             v                                 v                 |
|                 expect_autograd_hooks_ = true    expect_autograd_hooks_ = false |
|                                                                                 |
|                             |                                 |                 |
 --------------------------------------------------------------------------------- 
                              |                                 |
 -------------------------------------------------------------------------------- 
 backward                     |                                 |
                              |                                 |
  -------------------------------------------------------------------------------- 
 |                            |                                 |                 |
 | Reducer::autograd_hook()   |                                 |                 |
 |                            |                                 |                 |
 |                            |     ----------------------------                  |
 |                            |    |                                              |
 |                            |    |                                              |
 |                            v    v                                              |
 |                 expect_autograd_hooks_ == True??  ------------                 |
 |                                                               |                |
 |                            | Yes                              |                |
 |                            |                                  |  No            |
 |                            v                                  v                |
 |                      Do All-Reduce                          Return             |
 |                                                                                |
 |                                                                                |
  -------------------------------------------------------------------------------- 

no_sync 操作就 意味着设置 require_backward_grad_sync = False,最终设置了 expect_autograd_hooks_ = False。这样,backward 时候就不会进行 All-Reduce 操作

0x04 Tensorflow实现

在 pytorch 中,梯度只要不清零默认是累加的,于是很容易实现上述问题。但在Tensorflow中,却不那么容易。

我们从 stackoverflow 得到示例代码如下:

代码语言:javascript复制
## 定义优化器
opt = tf.train.AdamOptimizer()

## 得到你模型中的所有可训练变量
tvs = tf.trainable_variables()

# 用于记录每个变量的累积梯度,初始化为0s
accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in tvs]
# 定义清零操作
zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_vars]

## 使用优化器的compute_gradients来计算梯度
gvs = opt.compute_gradients(rmse, tvs)

## 将当前梯度累加在之前定义的变量上
accum_ops = [accum_vars[i].assign_add(gv[0]) for i, gv in enumerate(gvs)]

## 定义训练step,梯度下降,更新参数
train_step = opt.apply_gradients([(accum_vars[i], gv[1]) for i, gv in enumerate(gvs)])

## 训练循环
while ...:
    # 使用 zero_ops 初始化
    sess.run(zero_ops)
    # 使用accum_ops对accum_vars进行'n_minibatches'次梯度累积
    for i in xrange(n_minibatches):
        sess.run(accum_ops, feed_dict=dict(X: Xs[i], y: ys[i]))
    # 使用累积的梯度进行参数更新
    sess.run(train_step)

0x05 Gpipe实现

在 GPipe 的流水并行示例中,每个“时间点” 可以在多个阶段(stage)上同时做不同的micro-batch,图中每个方块中的标号表示了第几个 micro-batch;同一个 micro-batch 还是串行的经过所有的 stage,在这种情况下,每个设备的空闲时间只有 25% 左右。

具体代码如下:

5.1 优化器

在 lingvo/core/optimizer.py 中 GradientAggregationOptimizer 中有具体实现,关键代码为apply_gradients,逻辑为:

  • 如果 _num_micro_batches 为 1,则说明不用梯度累积,直接 apply_gradients;
  • 遍历 grads_and_vars 列表,累积梯度;
  • accum_step 为梯度累积条件:
    • 如果达到了小批次迭代数目,则调用 _ApplyAndReset:
      • 调用 apply_gradients 应用梯度;
      • 调用 zero_op 清零梯度;
    • 否则就调用_Accum,实际上是 no_op不做操作;

具体代码如下:

代码语言:javascript复制
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    if self._num_micro_batches == 1:
      return self._opt.apply_gradients(grads_and_vars, global_step)
    global_step = global_step or py_utils.GetOrCreateGlobalStepVar()
    with tf.init_scope():
      self._create_slots([v for (_, v) in grads_and_vars])

    accums = []
    variables = []

    # 遍历,累积梯度
    for g, v in grads_and_vars:
      accum = self.get_slot(v, 'grad_accum')
      variables.append(v)
      # pytype: disable=attribute-error
      if isinstance(g, tf.IndexedSlices):
        scaled_grad = tf.IndexedSlices(
            g.values / self._num_micro_batches,
            g.indices,
            dense_shape=g.dense_shape)
      else:
        scaled_grad = g / self._num_micro_batches
      accum_tensor = accum.read_value()
      accums.append(accum.assign(accum_tensor   scaled_grad))
      # pytype: enable=attribute-error

    # 应用梯度,清零梯度
    def _ApplyAndReset():
      normalized_accums = accums
      if self._apply_crs_to_grad:
        normalized_accums = [
            tf.tpu.cross_replica_sum(accum.read_value()) for accum in accums
        ]
      apply_op = self._opt.apply_gradients(
          list(zip(normalized_accums, variables)))
      with tf.control_dependencies([apply_op]):
        zero_op = [tf.assign(accum, tf.zeros_like(accum)) for accum in accums]
      return tf.group(zero_op, tf.assign_add(global_step, 1))

    # 累积函数,其实是不做操作
    def _Accum():
      return tf.no_op()

    # 梯度累积条件,如果达到了小批次迭代数目,则应用梯度,清零梯度,否则就不做操作
    accum_step = tf.cond( 
        tf.equal(
            tf.math.floormod(self._counter   1, self._num_micro_batches), 0),
        _ApplyAndReset,  # Apply the accumulated gradients and reset.
        _Accum)  # Accumulate gradients.

    with tf.control_dependencies([tf.group(accums)]):
      return tf.group(accum_step, tf.assign_add(self._counter, 1))

5.2 包装器

ShardedAdam 是给 GradientAggregationOptimizer 和 ShardedAdamOptimizer 做了包装,用户可以直接使用。

代码语言:javascript复制
class ShardedAdam(optimizer.Adam):
  """Adam optimizer wrapper that shards the slot variables."""

  @classmethod
  def Params(cls):
    params = super().Params()
    params.Define('num_micro_batches', 1, 'Number of accumulated batches.')
    return params

  def GetOptimizer(self, lr):
    p = self.params
    opt = ShardedAdamOptimizer(
        learning_rate=lr,
        beta1=p.beta1,
        beta2=p.beta2,
        epsilon=p.epsilon,
        name=p.name)
    if p.num_micro_batches > 1:
      tf.logging.info('Applying gradient aggregation.')
      
      opt = optimizer.GradientAggregationOptimizer( # 应用梯度累积
          opt, p.num_micro_batches, apply_crs_to_grad=True)
      self._cached_opt = opt
    return opt

5.3 应用

DenseLm12kWide41BAdam16x16 中有如何使用 ShardedAdam。

代码语言:javascript复制
@model_registry.RegisterSingleTaskModel
class DenseLm12kWide41BAdam16x16(DenseLm128B16x16):
  """41B params LM model with 2D split and ADAM optimizer on v3-512."""

  # Each layer has 1.6875B parameters.
  SEQUENCE_LENGTH = 2048
  NUM_DEVICES_PER_SPLIT = 512
  BATCH_DIM_PER_DEVICE = 0.5  # Total batch size 256
  DEVICE_MESH_SHAPE = [16, 32]
  DEVICE_MESH = gshard_utils.GetNonPod2dMesh(DEVICE_MESH_SHAPE, [16, 16, 2])
  NUM_TRANSFORMER_LAYERS = 24
  HIDDEN_DIM = 48 * 1024
  MODEL_DIM = 12 * 1024
  NUM_HEADS = 96
  ATTENTION_KEY_VALUE_DIM = 128
  GATED_GELU = False
  POSITIONAL_EMBEDDING = True
  NUM_MICRO_BATCHES = 1

  def Task(self):
    p = super().Task()
    
    # 使用ShardedAdam
    p.train.optimizer = ShardedAdam.Params().Set(
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-6,
        num_micro_batches=self.NUM_MICRO_BATCHES)
    return p

0xFF 参考

[原创][深度][PyTorch] DDP系列第三篇:实战与技巧

0 人点赞