一般pytorch需要用户自定义训练循环,可以说有1000个pytorch用户就有1000种训练代码风格。
从实用角度讲,一个优秀的训练循环应当具备以下特点。
代码简洁易懂 【模块化、易修改、short-enough】
支持常用功能 【进度条、评估指标、early-stopping】
经过反复斟酌测试,我精心设计了仿照keras风格的pytorch训练循环,完全满足以上条件。
该方案在知乎受到许多读者喜爱,目前为止获得了超过600个赞。
知乎完整回答链接:《深度学习里面,请问有写train函数的模板吗?》
https://www.zhihu.com/question/523869554/answer/2633479163
以上pytorch模型训练模版也是我开源的一个pytorch模型训练工具 torchkeras库的核心代码。
https://github.com/lyhue1991/torchkeras
铛铛铛铛,torchkeras加入新功能啦。
最近,通过引入HuggingFace的accelerate库的功能,torchkeras进一步支持了 多GPU的DDP模式和TPU设备上的模型训练。
这里给大家演示一下,非常强大和丝滑。
公众号算法美食屋后台回复关键词:训练模版,获取本文B站视频演示和notebook源代码。
代码语言:javascript复制#从git安装最新的accelerate仓库
!pip install git https://github.com/huggingface/accelerate
一,torchkeras源码解析
torchkeras的核心代码在 下面这个文件中。
https://github.com/lyhue1991/torchkeras/blob/master/torchkeras/kerasmodel.py
代码语言:javascript复制import sys,datetime
from tqdm import tqdm
from copy import deepcopy
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator
def colorful(obj,color="red", display_type="plain"):
color_dict = {"black":"30", "red":"31", "green":"32", "yellow":"33",
"blue":"34", "purple":"35","cyan":"36", "white":"37"}
display_type_dict = {"plain":"0","highlight":"1","underline":"4",
"shine":"5","inverse":"7","invisible":"8"}
s = str(obj)
color_code = color_dict.get(color,"")
display = display_type_dict.get(display_type,"")
out = '