一个简单的Trainer项目(一)

2022-04-29 14:42:16 浏览数 (1)

整理Trainer的目的就是为了在偷懒的同时减少返工的可能,有一个好的trainer可以给我们省出不少喝茶的时间。

那么一个Trainer应该由哪些功能呢?

我认为主要有如下几个方面:

  1. 超参可配,将所有超参提取出来,使用配置文件进行配置,训练时只修改配置文件,不修改代码;
  2. 中间结果可查看,要留出调试接口,避免在调试的时候改动核心代码;
  3. 模块化,模块清晰明了,且相互不干扰。
  4. 日志可溯源,实验做多了可能喝口水就忘了刚才提交的任务是什么配置,所以训练日志里面要有尽量详细的信息;

一个基于PyTorch的Trainer由以下部分构成:

  1. 主流程

即训练、验证及模型推理调试的流程,包括forward, backward, optimizer, LRscheduler, 模型存取以及多机多卡训练等机制。

2. 数据加载

PyTorch中一般用自定义DataLoader来实现,其中包含数据增强、Sampler数据采样、collate_fn数据分批等,如果是lmdb数据,还需要用到worker_init_fn

3. 函数库

包括评价函数,loss函数等等

4. 网络模块库

包括网络中的各种layer、block、module等

5. 模型

使用各种网络模块和函数搭建起的网络,输入数据,输出loss及acc, 保证统一接口。

6. 配置文件

将包括模型在内的所有参数都提出来,写在配置文件里面,一般用cfg或者yaml,也可以利用argparse库以命令行参数的形式实现。

7. 日志模块

包括日志文件写入、屏幕输出等,包含时间和配置信息,也可以嵌入tensorboard做展示用。

0 人点赞