Popular Zhihu question: in deep learning, is there a template for writing the train function?
The following is the answer from Zhihu user 吃货本货.
Teacher, I know this one!
PyTorch generally leaves the training loop to the user, so it is fair to say that 1000 PyTorch users have 1000 training-code styles. From a practical standpoint, a good training loop should have the following qualities:
- concise, readable code (modular, easy to modify, short enough)
- the common conveniences built in (progress bar, evaluation metrics, early stopping)
After plenty of deliberation and testing, I carefully crafted a Keras-style training loop for PyTorch. Behold:
```python
import os, sys, time
import numpy as np
import pandas as pd
import datetime
from tqdm import tqdm

import torch
from torch import nn
from copy import deepcopy

def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n" + "==========" * 8 + "%s" % nowtime)
    print(str(info) + "\n")

class StepRunner:
    def __init__(self, net, loss_fn,
                 stage="train", metrics_dict=None,
                 optimizer=None):
        self.net, self.loss_fn, self.metrics_dict, self.stage = net, loss_fn, metrics_dict, stage
        self.optimizer = optimizer

    def step(self, features, labels):
        # loss
        preds = self.net(features)
        loss = self.loss_fn(preds, labels)

        # backward()
        if self.optimizer is not None and self.stage == "train":
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        # metrics
        step_metrics = {self.stage + "_" + name: metric_fn(preds, labels).item()
                        for name, metric_fn in self.metrics_dict.items()}
        return loss.item(), step_metrics

    def train_step(self, features, labels):
        self.net.train()  # train mode: dropout layers are active
        return self.step(features, labels)

    @torch.no_grad()
    def eval_step(self, features, labels):
        self.net.eval()  # eval mode: dropout layers are inactive
        return self.step(features, labels)

    def __call__(self, features, labels):
        if self.stage == "train":
            return self.train_step(features, labels)
        else:
            return self.eval_step(features, labels)

class EpochRunner:
    def __init__(self, steprunner):
        self.steprunner = steprunner
        self.stage = steprunner.stage

    def __call__(self, dataloader):
        total_loss, step = 0, 0
        loop = tqdm(enumerate(dataloader), total=len(dataloader))
        for i, batch in loop:
            loss, step_metrics = self.steprunner(*batch)
            step_log = dict({self.stage + "_loss": loss}, **step_metrics)
            total_loss += loss
            step += 1
            if i != len(dataloader) - 1:
                loop.set_postfix(**step_log)
            else:
                # last batch: report the epoch-level loss and metrics, then reset the metrics
                epoch_loss = total_loss / step
                epoch_metrics = {self.stage + "_" + name: metric_fn.compute().item()
                                 for name, metric_fn in self.steprunner.metrics_dict.items()}
                epoch_log = dict({self.stage + "_loss": epoch_loss}, **epoch_metrics)
                loop.set_postfix(**epoch_log)
                for name, metric_fn in self.steprunner.metrics_dict.items():
                    metric_fn.reset()
        return epoch_log

def train_model(net, optimizer, loss_fn, metrics_dict,
                train_data, val_data=None,
                epochs=10, ckpt_path='checkpoint.pt',
                patience=5, monitor="val_loss", mode="min"):
    history = {}

    for epoch in range(1, epochs + 1):
        printlog("Epoch {0} / {1}".format(epoch, epochs))

        # 1, train -------------------------------------------------
        train_step_runner = StepRunner(net=net, stage="train",
                                       loss_fn=loss_fn, metrics_dict=deepcopy(metrics_dict),
                                       optimizer=optimizer)
        train_epoch_runner = EpochRunner(train_step_runner)
        train_metrics = train_epoch_runner(train_data)

        for name, metric in train_metrics.items():
            history[name] = history.get(name, []) + [metric]

        # 2, validate -------------------------------------------------
        if val_data:
            val_step_runner = StepRunner(net=net, stage="val",
                                         loss_fn=loss_fn, metrics_dict=deepcopy(metrics_dict))
            val_epoch_runner = EpochRunner(val_step_runner)
            with torch.no_grad():
                val_metrics = val_epoch_runner(val_data)
            val_metrics["epoch"] = epoch
            for name, metric in val_metrics.items():
                history[name] = history.get(name, []) + [metric]

        # 3, early-stopping -------------------------------------------------
        arr_scores = history[monitor]
        best_score_idx = np.argmax(arr_scores) if mode == "max" else np.argmin(arr_scores)
        if best_score_idx == len(arr_scores) - 1:
            torch.save(net.state_dict(), ckpt_path)
            print("<<<<<< reach best {0} : {1} >>>>>>".format(
                monitor, arr_scores[best_score_idx]), file=sys.stderr)
        if len(arr_scores) - best_score_idx > patience:
            print("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
                monitor, patience), file=sys.stderr)
            break

    # restore the best checkpoint before returning
    net.load_state_dict(torch.load(ckpt_path))
    return pd.DataFrame(history)
```
Usage looks like this:
```python
from torchmetrics import Accuracy

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
metrics_dict = {"acc": Accuracy()}

dfhistory = train_model(net,
                        optimizer,
                        loss_fn,
                        metrics_dict,
                        train_data=dl_train,
                        val_data=dl_val,
                        epochs=10,
                        patience=5,
                        monitor="val_acc",
                        mode="max")
```
And here is how it performs:
```
================================================================================2022-07-10 20:06:16
Epoch 1 / 10
100%|██████████| 200/200 [00:17<00:00, 11.74it/s, train_acc=0.735, train_loss=0.53]
100%|██████████| 40/40 [00:01<00:00, 20.07it/s, val_acc=0.827, val_loss=0.383]
<<<<<< reach best val_acc : 0.8274999856948853 >>>>>>
================================================================================2022-07-10 20:06:35
Epoch 2 / 10
100%|██████████| 200/200 [00:16<00:00, 11.96it/s, train_acc=0.832, train_loss=0.391]
100%|██████████| 40/40 [00:02<00:00, 18.13it/s, val_acc=0.854, val_loss=0.317]
<<<<<< reach best val_acc : 0.8544999957084656 >>>>>>
================================================================================2022-07-10 20:06:54
Epoch 3 / 10
100%|██████████| 200/200 [00:17<00:00, 11.71it/s, train_acc=0.87, train_loss=0.313]
100%|██████████| 40/40 [00:02<00:00, 19.96it/s, val_acc=0.902, val_loss=0.239]
<<<<<< reach best val_acc : 0.9024999737739563 >>>>>>
================================================================================2022-07-10 20:07:13
Epoch 4 / 10
100%|██████████| 200/200 [00:16<00:00, 11.88it/s, train_acc=0.889, train_loss=0.265]
100%|██████████| 40/40 [00:02<00:00, 18.46it/s, val_acc=0.91, val_loss=0.216]
<<<<<< reach best val_acc : 0.9100000262260437 >>>>>>
================================================================================2022-07-10 20:07:32
Epoch 5 / 10
100%|██████████| 200/200 [00:17<00:00, 11.71it/s, train_acc=0.902, train_loss=0.239]
100%|██████████| 40/40 [00:02<00:00, 19.68it/s, val_acc=0.891, val_loss=0.279]
================================================================================2022-07-10 20:07:51
Epoch 6 / 10
100%|██████████| 200/200 [00:17<00:00, 11.75it/s, train_acc=0.915, train_loss=0.212]
100%|██████████| 40/40 [00:02<00:00, 19.52it/s, val_acc=0.908, val_loss=0.222]
================================================================================2022-07-10 20:08:10
Epoch 7 / 10
100%|██████████| 200/200 [00:16<00:00, 11.79it/s, train_acc=0.921, train_loss=0.196]
100%|██████████| 40/40 [00:02<00:00, 19.26it/s, val_acc=0.929, val_loss=0.187]
<<<<<< reach best val_acc : 0.9294999837875366 >>>>>>
================================================================================2022-07-10 20:08:29
Epoch 8 / 10
100%|██████████| 200/200 [00:17<00:00, 11.59it/s, train_acc=0.931, train_loss=0.175]
100%|██████████| 40/40 [00:02<00:00, 19.91it/s, val_acc=0.938, val_loss=0.187]
<<<<<< reach best val_acc : 0.9375 >>>>>>
================================================================================2022-07-10 20:08:49
Epoch 9 / 10
100%|██████████| 200/200 [00:17<00:00, 11.68it/s, train_acc=0.929, train_loss=0.178]
100%|██████████| 40/40 [00:02<00:00, 19.90it/s, val_acc=0.937, val_loss=0.181]
================================================================================2022-07-10 20:09:08
Epoch 10 / 10
100%|██████████| 200/200 [00:16<00:00, 11.84it/s, train_acc=0.937, train_loss=0.16]
100%|██████████| 40/40 [00:02<00:00, 19.91it/s, val_acc=0.937, val_loss=0.167]
```
This training loop delivers on every one of the qualities listed above:
- 1, Modular: built bottom-up from StepRunner, EpochRunner, and train_model, three clearly separated levels.
- 2, Easy to modify: if the inputs or labels take a different form (e.g. the inputs are packed into a dict, or there are several of them), only StepRunner needs changing; everything downstream stays untouched, which makes it very flexible (see the sketch after this list).
- 3, Short enough: the whole training code is under 150 lines.
- 4, Progress bars: provided via tqdm.
- 5, Evaluation metrics: via the torchmetrics library.
- 6, Early stopping: just pass monitor, mode, and patience to train_model.
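To make point 2 concrete, here is a sketch of how such a change might look. The subclass name and the keyword-input convention are mine, not part of the template; in practice you would simply edit `StepRunner.step` in place, and `EpochRunner` and `train_model` would be reused unchanged.

```python
# Hypothetical variant for batches shaped like ({"input_ids": ..., "mask": ...}, labels),
# where the model is called with keyword arguments.
class DictStepRunner(StepRunner):
    def step(self, features, labels):
        preds = self.net(**features)  # the only line that differs: unpack the dict
        loss = self.loss_fn(preds, labels)
        if self.optimizer is not None and self.stage == "train":
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        step_metrics = {self.stage + "_" + name: metric_fn(preds, labels).item()
                        for name, metric_fn in self.metrics_dict.items()}
        return loss.item(), step_metrics
```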
This is also the main training loop I use in eat_pytorch_in_20_days. That repo has earned 3.3k stars ⭐️ so far, and most readers report that it works well.
Click "Read the original" at the end of this post to see the original Zhihu answer; if you found it useful, give 吃货本货 an upvote as encouragement. Thanks, everyone!