Pytorch自动混合精度训练模板

2023-08-25 13:16:04 浏览数 (1)

GitHub 仓库地址:https://github.com/yanqiangmiffy/amp-pytorch

Pytorch自动混合精度训练模板

使用 pytorch 的自动混合精度教程。基于 PyTorch 1.6 Official Features (Automatic Mixed Precision) ,使用自定义数据集实现分类代码库

1 实验设置

1.1 环境设置

代码语言:javascript复制
- Pytorch>=1.6.0 支持CUDA
- 本地实验环境为:RTX 3090 24G

1.2 下载数据集 Kaggle 英特尔图像分类 数据集地址:Intel Image Classification

该数据包含大约 25k 张大小为 150x150 的图像,分布在 6 个类别下。 {‘建筑物’ -> 0, ‘森林’ -> 1, ‘冰川’ -> 2, ‘山’ -> 3, ‘海’ -> 4, ‘街道’ -> 5 }

数据集解压直接放在data目录下

代码语言:javascript复制
data:
    seg_pred
    seg_test
    seg_train

2 Baseline训练

实验设置:

  • ImageNet Pretrained ResNet-18 from torchvision.models
  • Batch Size 256 / Epochs 120 / Initial Learning Rate 0.0001
  • Training Augmentation: Resize((256, 256)), RandomHorizontalFlip()
  • Adam Cosine Learning rate scheduling with warmup

运行命令

代码语言:javascript复制
python main.py --checkpoint_name baseline

3 混合精度训练

未使用混合精度训练:

代码语言:javascript复制
for batch_idx, (inputs, labels) in enumerate(data_loader):
  self.optimizer.zero_grad()

  outputs = self.model(inputs)
  loss = self.criterion(outputs, labels)

  loss.backward()
  self.optimizer.step()

使用混合精度训练

代码语言:javascript复制
scaler = torch.cuda.amp.GradScaler()

for batch_idx, (inputs, labels) in enumerate(data_loader):
  self.optimizer.zero_grad()

  with torch.cuda.amp.autocast():
    outputs = self.model(inputs)
    loss = self.criterion(outputs, labels)

  # Scales the loss, and calls backward() 
  # to create scaled gradients 
  self.scaler.scale(loss).backward()

  # Unscales gradients and calls 
  # or skips optimizer.step() 
  self.scaler.step(self.optimizer)

  # Updates the scale for next iteration 
  self.scaler.update()

运行命令

代码语言:javascript复制
python main.py --checkpoint_name baseline_amp --amp;

4 实验性能

B : Baseline (FP32) AMP : Automatic Mixed Precision Training (AMP)

Algorithm

Test Accuracy

GPU Memory

Total Training Time

B - 3090 Ti

94.17

13.0G

(44s*20epochs)~=15mins

AMP - 3090 Ti

94.23

10.6G

(33s*20eochs)~=11mins

5 代码参考

  • automatic-mixed-precision-tutorials-pytorch:https://github.com/hoya012/automatic-mixed-precision-tutorials-pytorch

代码主要来自该仓库,作者实现以及项目代码很完善了,主要将其拆分出来,可以灵活应用

  • Baseline Code: https://github.com/hoya012/carrier-of-tricks-for-classification-pytorch
  • Gradual Warmup Scheduler: https://github.com/ildoonet/pytorch-gradual-warmup-lr
  • PyTorch Automatic Mixed Precision: https://pytorch.org/docs/stable/notes/amp_examples.html

0 人点赞