整理Trainer的目的就是为了在偷懒的同时减少返工的可能,有一个好的trainer可以给我们省出不少喝茶的时间。
那么一个Trainer应该由哪些功能呢?
我认为主要有如下几个方面:
- 超参可配,将所有超参提取出来,使用配置文件进行配置,训练时只修改配置文件,不修改代码;
- 中间结果可查看,要留出调试接口,避免在调试的时候改动核心代码;
- 模块化,模块清晰明了,且相互不干扰。
- 日志可溯源,实验做多了可能喝口水就忘了刚才提交的任务是什么配置,所以训练日志里面要有尽量详细的信息;
一个基于PyTorch的Trainer由以下部分构成:
- 主流程
即训练、验证及模型推理调试的流程,包括forward, backward, optimizer, LRscheduler, 模型存取以及多机多卡训练等机制。
2. 数据加载
PyTorch中一般用自定义DataLoader来实现,其中包含数据增强、Sampler数据采样、collate_fn数据分批等,如果是lmdb数据,还需要用到worker_init_fn
3. 函数库
包括评价函数,loss函数等等
4. 网络模块库
包括网络中的各种layer、block、module等
5. 模型
使用各种网络模块和函数搭建起的网络,输入数据,输出loss及acc, 保证统一接口。
6. 配置文件
将包括模型在内的所有参数都提出来,写在配置文件里面,一般用cfg或者yaml,也可以利用argparse库以命令行参数的形式实现。
7. 日志模块
包括日志文件写入、屏幕输出等,包含时间和配置信息,也可以嵌入tensorboard做展示用。