pytorch_lightning模型训练加速技巧与涨点技巧

2023-02-23 13:10:38 浏览数 (1)

pytorch-lightning 是建立在pytorch之上的高层次模型接口。

pytorch-lightning 之于 pytorch,就如同keras之于 tensorflow.

pytorch-lightning 有以下一些引人注目的功能:

  • 可以不必编写自定义循环,只要指定loss计算方法即可。
  • 可以通过callbacks非常方便地添加CheckPoint参数保存、early_stopping 等功能。
  • 可以非常方便地在单CPU、多CPU、单GPU、多GPU乃至多TPU上训练模型。
  • 可以通过调用torchmetrics库,非常方便地添加Accuracy,AUC,Precision等各种常用评估指标。
  • 可以非常方便地实施多批次梯度累加、半精度混合精度训练、最大batch_size自动搜索等技巧,加快训练过程。
  • 可以非常方便地使用SWA(随机参数平均)、CyclicLR(学习率周期性调度策略)与auto_lr_find(最优学习率发现)等技巧 实现模型涨点。

一般按照如下方式 安装和 引入 pytorch-lightning 库。

代码语言:javascript复制
#安装
pip install pytorch-lightning
代码语言:javascript复制
#引入
import pytorch_lightning as pl 

顾名思义,它可以帮助我们漂亮(pl)地进行深度学习研究。

0 人点赞