TinyBERT 蒸馏速度实现加速小记

2022-08-26 16:01:25 浏览数 (1)

每天给你送来NLP技术干货!


编辑:AI算法小喵

写在前面

最近做的一个 project 需要复现 EMNLP 2020 Findings 的 TinyBERT[1],本文是对复现过程对踩到坑,以及对应的解决方案和实现加速的一个记录。

1. Overview of TinyBERT

BERT 效果虽好,但其较大的内存消耗和较长的推理延时会对其上线部署造成一定挑战

内存消耗方面,一系列知识蒸馏的工作,例如 DistilBERT[2]BERT-PKD[3]TinyBERT 被提出来用以降低模型的参数(主要是层数)以及相应地减少时间;

推理加速方面,也有 DeeBERT[4]FastBERT[5]CascadeBERT[6] 等方案提出,它们动态地根据样本难度进行模型的执行从而提升推理效率。其中较具备代表性的是 TinyBERT,其核心框架如下:

分为两个阶段:

  • General Distillation:在通用的语料,例如 BookCorpus, EnglishWiki 上进行知识蒸馏;目标函数包括 Transformer Layer Attention 矩阵以及 Layer Hidden States 的对齐;
  • Task Distillation:在具体的任务数据集上进行蒸馏,进一步分成两个步骤:
    • Task Transformer Disitllation: 在任务数据集上对齐 Student 和已经 fine-tuned Teacher model 的 attention map 和 hidden states;
    • Task Prediction Distillation:在任务数据集上对 student model 和 teacher model 的 output distritbuion 利用 KL loss / MSE loss 进行对齐。

TinyBERT 提供了经过 General Distillation 阶段的 checkpoint,可以认为是一个小的 BERT,包括了 6L786H 版本以及 4L312H 版本。而我们后续的复现就是基于 4L312H v2 版本的。

值得注意的是,TinyBERT 对任务数据集进行了数据增强操作:通过基于 Glove 的 Embedding Distance 的相近词替换以及 BERT MLM 预测替换,会将原本的数据集扩增到 20 倍。而我们遇到的第一个 bug 就是在数据增强阶段。

2. Bug in Data Augmentation

我们可以按照官方给出的代码对数据进行增强操作,但是在 QNLI 上会报错:

造成数据增强到一半程序就崩溃了,为什么呢?

很简单,因为数据增强代码 BERT MLM 换词模块对于超长(> 512)的句子没有特殊处理,造成下标越界,具体可以参考 #Issue50:error occured when apply data_augmentation on QNLI and QQP dataset[7]

在对应的函数中进行边界的判断即可:

3. Acceleration of Data Parallel

当我们费劲愉快地完成数据增强之后,下一步就是要进行 Task Specific 蒸馏里的 Step 1,General Distillation 了。

对于一些小数据集像 MRPC,增广 20 倍之后的数据量依旧是 80k 不到,因此训练速度还是很快的,20 轮单卡大概半天也能跑完。但是对于像 MNLI 这样 GLUE 中最大的数据集(390k),20 倍增广后的数据集(增广就花费了大约 2 天时间),如果用单卡训练个 10 轮那可能得跑上半个月了,到时候怕不是黄花菜都凉咯。

3.1 多卡训练初步尝试

遂打算用多卡训练,一看,官方的实现就通过 nn.DataParal lel 支持了多卡。好嘛,直接 CUDA_VISIBLE_DEVICES="0,1,2,3" 来上 4 块卡。不跑不知道,一跑吓一跳:

  • 加载数据(tokenize, padding )花费 1小时;
  • 好不容易跑起来了,一开 nvidia-smi 发现 GPU 的利用率都在 50% 左右;
  • 再一看预估时间,大约 21h 一轮,10 epoch 那四舍五入就是一个半礼拜。

好家伙,这我还做不做实验了?

3.2 DDP 替换 DP

这时候就去翻看 PyTorch 文档,发现 PyTorch 现在都不再推荐使用 nn.DataParallel,为什么呢?主要原因在于:

  • DataParallel 的实现是单进程的,每次都是有一块主卡读入数据再发给其他卡,这一部分不仅带来了额外的计算开销,而且会造成主卡的 GPU 显存占用会显著高于其他卡,进而造成潜在的 batch size 限制;
  • 此外,这种模式下,其他 GPU 算完之后要传回主卡进行同步,这一步又会受限于 Python 的线程之间的 GIL(global interpreter lock),进一步降低了效率。
  • 此外,还有多机以及模型切片等 DataParallel 不支持,但是另一个 DistributedDataParallel 模块支持的功能。

所以得把原先 TinyBERT DP(DataParallel)改成 DDP(DistributedDataParallel)。把 DP 改成 DDP 可以参考知乎-当代研究生需要掌握的并行训练技巧[8]。核心的代码就是做一下初始化,以及用 DDP 替换掉 DP

然后,大功告成,一键启动:

启动成功了吗?模型又开始处理数据….

One hours later,机器突然卡住,程序的 log 也停了,打开 htop 一看:好家伙,256G 的内存都满了,程序都是 D 状态,这是咋回事?

4. Acceleration of Data Loading

我先试了少量数据,降采样到 10k,程序运行没问题, DDP 速度很快;我再尝试了单卡加载,虽然又 load 了一个小时,但是 ok,程序还是能跑起来,那么,问题是如何发生的呢?

单卡的时候我看了一眼加载全量数据完毕之后的内存占用,大约在 60G 左右,考虑到 DDP 是多进程的,因此,每个进程都要独立地加载数据,4 块卡 4个进程,大约就是 250 G 的内存,因此内存爆炸,到后面数据的 io 就卡住了(没法从磁盘 load 到内存),所以造成了程序 D 状态。

看了下组里的机器,最大的也就是 250 G 内存,也就是说,如果我只用 3 块卡,那么是能够跑的,但是万一有别的同学上来开程序吃了一部分内存,那么就很可能爆内存,然后就是大家的程序都同归于尽的局面,不太妙。

一种不太优雅的解决方案就是,把数据切块,然后读完一小块训练完,再读下一块,再训练,再读。咨询了一下组里资深的师兄,还有一种办法就是实现一种把数据存在磁盘上,每次要用的时候才 load 到内存的数据读取方案,这样就能够避免爆内存的问题。行吧,那就干吧,但是总不能从头造轮子吧?

脸折师兄提到 huggingface(yyds) 的 datasets[9] 能够支持这个功能,check 了一下文档,发现他是基于 pyarrow 的实现了一个 memory map 的数据读取[10],以我的 huggingface transformers 的经验,似乎是能够实现这个功能的,所以摩拳擦掌,准备动手。

首先,要把增广的数据 load 进来,datasets 提供的 load_dataset 函数最接近的就是 load_dataset('csv', data_file),然后我们就可以逐个 column 的拿到数据并且进行预处理了。

写了一会,发现总是报读取一部分数据后 columns 数目不对的错误,猜测可能原始 MNLI 数据集就不太能保证每个列都是在的,检查了一下 MnliProcessor 里处理的代码,发现其写死了 line[8]line[9] 作为 sentence_a 和 sentence_b。无奈之下,只能采取最粗暴地方式,用 text mode 读进来,每一行是一个数据,再 split:

写完这个 preprocess_func ,我觉得胜利在望,但还有几个小坑需要解决s:

  • map 完之后,返回的还是一个 DatasetDict,得手动取一下 train set;
  • 对于原先存在的列,map 函数并不会去除掉,所以如果不用的列,需要手动 .remove_columns()
  • 在配合 DDP 使用的时候,因为 DistributedSample 取数据的维度是在第一维取的,所以取到的数据可能是个 seq_len 长的列表,里面的 tensor 是 [bsz] 形状的,需要在交给 model 之前 stack 一下:

至此,只要把之前代码的 train_data 都换成现在的版本即可。

此外,为了进一步加速,我还把混合精度也整合了进来,现在 Pytorch 以及自带对混合精度的支持,代码量也很少,但是有个坑就是loss 的计算必须被 auto() 包裹住,同时,所有模型的输出都要参与到 loss 的计算,这对于只做 prediction 或者是 hidden state 对齐的 loss 很不友好,所以只能手动再额外计算一项为系数为 0 的 loss 项(这样他参与到训练但是不会影响梯度)。

总结

最后,改版过的代码在我的 GitHub fork[11] 版本中,我不要脸地起名为 fast_td 。实际上,改版后的有点有一下几个:

  • 数据加载方面:第一次加载/处理 780w 大约耗时 50m,但是不会多卡都消耗内存,实际占用不到 2G;同时,得益于 datasets 的支持,后续加载不会重复处理数据而是直接读取之前的 cache;
  • 模型训练方面:得益于 DDP 和 混合精度,在 MNLI 上训增强数据 10 轮,3 块卡花费的时间大约在 20h 左右,提速了 10 倍。

这次修改代码大概花了 2 天时间来实现和 debug,不过感觉收益还是挺大的,此处需要感谢任大佬 & 脸折师兄的建议,以及 andy 提供的知乎文章,撒花~

写在最后

小喵发现本文的作者也有自己的公众号【 三石杂货铺】,不过上面文章比较少。大家可以从文章来源去关注作者的博客地址,查看作者更多的博文,如研一这一年[12]EMNLP21 和 Rebuttal攻略[13],相信正在读研究生的粉丝朋友会有所收获。

还是一样,如果本文对你有帮助的话,欢迎点赞&在看&分享,这对我继续分享&创作优质文章非常重要。感谢

0 人点赞