Is there a template for writing the train function in deep learning?


A trending question on Zhihu: Is there a template for writing the train function in deep learning?

The following is the answer from Zhihu user 吃货本货.

Teacher, I've got this one.

PyTorch generally leaves the training loop up to the user, so it is fair to say that 1,000 PyTorch users will have 1,000 styles of training code. From a practical standpoint, a good training loop should have the following qualities.

  • Clean, readable code [modular, easy to modify, short enough]
  • Support for common conveniences [progress bar, evaluation metrics, early stopping]

After plenty of deliberation and testing, I carefully designed a Keras-style training loop for PyTorch. Have a look.

import os,sys,time
import numpy as np
import pandas as pd
import datetime 
from tqdm import tqdm 

import torch
from torch import nn 
from copy import deepcopy

def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("n" "=========="*8   "%s"%nowtime)
    print(str(info) "n")

class StepRunner:
    def __init__(self, net, loss_fn,
                 stage = "train", metrics_dict = None, 
                 optimizer = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer = optimizer

    def step(self, features, labels):
        # forward pass and loss
        preds = self.net(features)
        loss = self.loss_fn(preds, labels)

        # backward pass and parameter update (training stage only)
        if self.optimizer is not None and self.stage == "train":
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        # batch-level metrics
        step_metrics = {self.stage + "_" + name: metric_fn(preds, labels).item()
                        for name, metric_fn in self.metrics_dict.items()}
        return loss.item(), step_metrics

    def train_step(self,features,labels):
        self.net.train() # train mode: dropout layers are active
        return self.step(features,labels)

    @torch.no_grad()
    def eval_step(self,features,labels):
        self.net.eval() # eval mode: dropout layers are inactive
        return self.step(features,labels)

    def __call__(self,features,labels):
        if self.stage=="train":
            return self.train_step(features,labels) 
        else:
            return self.eval_step(features,labels)

class EpochRunner:
    def __init__(self,steprunner):
        self.steprunner = steprunner
        self.stage = steprunner.stage

    def __call__(self,dataloader):
        total_loss, step = 0, 0
        loop = tqdm(enumerate(dataloader), total=len(dataloader))
        for i, batch in loop:
            loss, step_metrics = self.steprunner(*batch)
            step_log = dict({self.stage + "_loss": loss}, **step_metrics)
            total_loss += loss
            step += 1
            if i != len(dataloader) - 1:
                loop.set_postfix(**step_log)
            else:
                # last batch: report epoch-level loss and accumulated metrics
                epoch_loss = total_loss / step
                epoch_metrics = {self.stage + "_" + name: metric_fn.compute().item()
                                 for name, metric_fn in self.steprunner.metrics_dict.items()}
                epoch_log = dict({self.stage + "_loss": epoch_loss}, **epoch_metrics)
                loop.set_postfix(**epoch_log)

                for name, metric_fn in self.steprunner.metrics_dict.items():
                    metric_fn.reset()
        return epoch_log


def train_model(net, optimizer, loss_fn, metrics_dict, 
                train_data, val_data=None, 
                epochs=10, ckpt_path='checkpoint.pt',
                patience=5, monitor="val_loss", mode="min"):

    history = {}

    for epoch in range(1, epochs + 1):
        printlog("Epoch {0} / {1}".format(epoch, epochs))

        # 1,train -------------------------------------------------  
        train_step_runner = StepRunner(net = net,stage="train",
                loss_fn = loss_fn,metrics_dict=deepcopy(metrics_dict),
                optimizer = optimizer)
        train_epoch_runner = EpochRunner(train_step_runner)
        train_metrics = train_epoch_runner(train_data)

        for name, metric in train_metrics.items():
            history[name] = history.get(name, []) + [metric]

        # 2,validate -------------------------------------------------
        if val_data:
            val_step_runner = StepRunner(net = net,stage="val",
                loss_fn = loss_fn,metrics_dict=deepcopy(metrics_dict))
            val_epoch_runner = EpochRunner(val_step_runner)
            with torch.no_grad():
                val_metrics = val_epoch_runner(val_data)
            val_metrics["epoch"] = epoch
            for name, metric in val_metrics.items():
                history[name] = history.get(name, []) + [metric]

        # 3,early-stopping -------------------------------------------------
        # note: monitor must be a key of history, e.g. "train_loss" when val_data is None
        arr_scores = history[monitor]
        best_score_idx = np.argmax(arr_scores) if mode=="max" else np.argmin(arr_scores)
        if best_score_idx==len(arr_scores)-1:
            torch.save(net.state_dict(),ckpt_path)
            print("<<<<<< reach best {0} : {1} >>>>>>".format(monitor,
                 arr_scores[best_score_idx]),file=sys.stderr)
        if len(arr_scores)-best_score_idx>patience:
            print("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(
                monitor,patience),file=sys.stderr)
            break 
    # after the loop ends (normally or via early stopping), restore the best weights
    net.load_state_dict(torch.load(ckpt_path))

    return pd.DataFrame(history)
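One note before the usage example: it assumes a model net and DataLoaders dl_train / dl_val are already defined. A minimal stand-in could look like the sketch below (a toy binary classifier on random data; everything in this sketch is my own illustration, and the training logs further below come from the author's run on real data, not from this stub).

from torch.utils.data import DataLoader, TensorDataset

# toy data: 20 features, binary labels of shape (n, 1) to match BCEWithLogitsLoss
X = torch.randn(1000, 20)
y = (X.sum(dim=1, keepdim=True) > 0).float()

dl_train = DataLoader(TensorDataset(X[:800], y[:800]), batch_size=32, shuffle=True)
dl_val = DataLoader(TensorDataset(X[800:], y[800:]), batch_size=32)

# a small model whose raw logits feed directly into the loss and metrics
net = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 1))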

Usage is as follows:

from torchmetrics import Accuracy
# note: torchmetrics >= 0.11 requires a task argument, e.g. Accuracy(task="binary")

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
metrics_dict = {"acc": Accuracy()}

dfhistory = train_model(net,
    optimizer,
    loss_fn,
    metrics_dict,
    train_data = dl_train,
    val_data= dl_val,
    epochs=10,
    patience=5,
    monitor="val_acc", 
    mode="max")

And here's how well it works:

================================================================================2022-07-10 20:06:16
Epoch 1 / 10

100%|██████████| 200/200 [00:17<00:00, 11.74it/s, train_acc=0.735, train_loss=0.53]
100%|██████████| 40/40 [00:01<00:00, 20.07it/s, val_acc=0.827, val_loss=0.383]
<<<<<< reach best val_acc : 0.8274999856948853 >>>>>>

================================================================================2022-07-10 20:06:35
Epoch 2 / 10

100%|██████████| 200/200 [00:16<00:00, 11.96it/s, train_acc=0.832, train_loss=0.391]
100%|██████████| 40/40 [00:02<00:00, 18.13it/s, val_acc=0.854, val_loss=0.317]
<<<<<< reach best val_acc : 0.8544999957084656 >>>>>>

================================================================================2022-07-10 20:06:54
Epoch 3 / 10

100%|██████████| 200/200 [00:17<00:00, 11.71it/s, train_acc=0.87, train_loss=0.313]
100%|██████████| 40/40 [00:02<00:00, 19.96it/s, val_acc=0.902, val_loss=0.239]
<<<<<< reach best val_acc : 0.9024999737739563 >>>>>>

================================================================================2022-07-10 20:07:13
Epoch 4 / 10

100%|██████████| 200/200 [00:16<00:00, 11.88it/s, train_acc=0.889, train_loss=0.265]
100%|██████████| 40/40 [00:02<00:00, 18.46it/s, val_acc=0.91, val_loss=0.216]
<<<<<< reach best val_acc : 0.9100000262260437 >>>>>>

================================================================================2022-07-10 20:07:32
Epoch 5 / 10

100%|██████████| 200/200 [00:17<00:00, 11.71it/s, train_acc=0.902, train_loss=0.239]
100%|██████████| 40/40 [00:02<00:00, 19.68it/s, val_acc=0.891, val_loss=0.279]

================================================================================2022-07-10 20:07:51
Epoch 6 / 10

100%|██████████| 200/200 [00:17<00:00, 11.75it/s, train_acc=0.915, train_loss=0.212]
100%|██████████| 40/40 [00:02<00:00, 19.52it/s, val_acc=0.908, val_loss=0.222]

================================================================================2022-07-10 20:08:10
Epoch 7 / 10

100%|██████████| 200/200 [00:16<00:00, 11.79it/s, train_acc=0.921, train_loss=0.196]
100%|██████████| 40/40 [00:02<00:00, 19.26it/s, val_acc=0.929, val_loss=0.187]
<<<<<< reach best val_acc : 0.9294999837875366 >>>>>>

================================================================================2022-07-10 20:08:29
Epoch 8 / 10

100%|██████████| 200/200 [00:17<00:00, 11.59it/s, train_acc=0.931, train_loss=0.175]
100%|██████████| 40/40 [00:02<00:00, 19.91it/s, val_acc=0.938, val_loss=0.187]
<<<<<< reach best val_acc : 0.9375 >>>>>>

================================================================================2022-07-10 20:08:49
Epoch 9 / 10

100%|██████████| 200/200 [00:17<00:00, 11.68it/s, train_acc=0.929, train_loss=0.178]
100%|██████████| 40/40 [00:02<00:00, 19.90it/s, val_acc=0.937, val_loss=0.181]

================================================================================2022-07-10 20:09:08
Epoch 10 / 10

100%|██████████| 200/200 [00:16<00:00, 11.84it/s, train_acc=0.937, train_loss=0.16] 
100%|██████████| 40/40 [00:02<00:00, 19.91it/s, val_acc=0.937, val_loss=0.167]

This training loop delivers all of the qualities listed at the start.

  • 1. Modular: three levels, built bottom-up from StepRunner, EpochRunner, and train_model, with a clear, readable structure.
  • 2. Easy to modify: if the inputs or labels take a different form (e.g., the inputs are packed into a dict, or there are several inputs), only StepRunner needs changing; everything downstream stays the same, which makes it very flexible. See the sketch after this list.
  • 3. Short enough: the whole training code is under 150 lines.
  • 4. Progress bar support: provided via tqdm.
  • 5. Metric support: uses metrics from the torchmetrics library.
  • 6. Early-stopping support: just specify monitor, mode, and patience in train_model.
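To make point 2 concrete, here is a hypothetical StepRunner variant for a model that takes a dict of tensors (the batch layout and field names are invented for this sketch); EpochRunner and train_model stay untouched:

class DictStepRunner(StepRunner):
    # assumed batch layout: (features_dict, labels), e.g.
    # features_dict = {"input_ids": ..., "attention_mask": ...}
    def step(self, features, labels):
        preds = self.net(**features)  # unpack the dict into the model's forward
        loss = self.loss_fn(preds, labels)

        if self.optimizer is not None and self.stage == "train":
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        step_metrics = {self.stage + "_" + name: metric_fn(preds, labels).item()
                        for name, metric_fn in self.metrics_dict.items()}
        return loss.item(), step_metrics

In practice you would either edit StepRunner.step directly, as the answer suggests, or let train_model take the runner class as a parameter so that variants like this one can be swapped in.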

This is also the main training loop I use in eat_pytorch_in_20_days. The repo has collected 3.3k stars ⭐️ so far, and most readers report that it works well.

Click "Read the original" at the end of this post to see the answer on Zhihu; if you found it helpful, please give 吃货本货 an upvote as encouragement. Thanks, everyone.
