Is there a PyTorch model-training code template for deep learning that supports Multi-GPU DDP mode?

2023-02-23 11:57:45

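PyTorch generally expects users to write their own training loop, so it is fair to say that a thousand PyTorch users have a thousand different training-code styles.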

From a practical standpoint, a good training loop should have the following properties.

Concise and easy to understand [modular, easy to modify, short enough]

Supports common features [progress bar, evaluation metrics, early stopping]
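The listing further below exposes early stopping through the `patience`, `monitor` and `mode` parameters of `KerasModel.fit`. As a minimal pure-Python sketch of the usual logic under those parameter names (the class `EarlyStopper` and its `step` method are hypothetical, and this does not claim to reproduce torchkeras's own implementation), it tracks the best value of the monitored key in each epoch's log dict and signals a stop once that value has not improved for `patience` consecutive epochs.

```python
class EarlyStopper:
    """Hypothetical helper: stop when `monitor` has not improved for `patience` epochs."""

    def __init__(self, monitor="val_loss", mode="min", patience=5):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.monitor, self.mode, self.patience = monitor, mode, patience
        self.best = None        # best monitored value seen so far
        self.best_epoch = None  # epoch at which that value was seen
        self.wait = 0           # epochs since the last improvement

    def _improved(self, value):
        if self.best is None:
            return True
        return value < self.best if self.mode == "min" else value > self.best

    def step(self, epoch, epoch_log):
        """Record one epoch's log dict; return True if training should stop."""
        value = epoch_log[self.monitor]
        if self._improved(value):
            self.best, self.best_epoch, self.wait = value, epoch, 0
            # A real loop would typically save a checkpoint here.
        else:
            self.wait += 1
        return self.wait >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopper(monitor="val_loss", mode="min", patience=2)
    for epoch, val_loss in enumerate([0.9, 0.7, 0.75, 0.8, 0.6], 1):
        if stopper.step(epoch, {"val_loss": val_loss}):
            print(f"stop at epoch {epoch}, best epoch {stopper.best_epoch}")
            break
    # prints: stop at epoch 4, best epoch 2
```

Keeping the rule in a small object that only reads a log dict keeps it independent of the training loop itself, which fits the modular style described above.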

After repeated deliberation and testing, I designed a Keras-style PyTorch training loop that fully meets the criteria above.

This solution has been well received by readers on Zhihu and has so far collected more than 600 upvotes.

Full answer on Zhihu: "In deep learning, is there a template for writing a train function?"

https://www.zhihu.com/question/523869554/answer/2633479163

This PyTorch training template is also the core code of torchkeras, an open-source PyTorch model-training library I created.

https://github.com/lyhue1991/torchkeras

Ta-da! torchkeras has a new feature.

Recently, by building on HuggingFace's accelerate library, torchkeras has gained support for model training in multi-GPU DDP mode and on TPU devices.

Here is a quick demonstration: it is powerful and runs very smoothly.

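Reply with the keyword 训练模版 ("training template") to the WeChat official account 算法美食屋 to get the Bilibili video demo and the notebook source code for this article.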

# Install the latest accelerate directly from its GitHub repository
!pip install git+https://github.com/huggingface/accelerate

1. torchkeras source code walkthrough

The core code of torchkeras lives in the following file.

https://github.com/lyhue1991/torchkeras/blob/master/torchkeras/kerasmodel.py

import sys,datetime
from tqdm import tqdm 
from copy import deepcopy
import numpy as np
import pandas as pd
import torch
from accelerate import Accelerator

def colorful(obj,color="red", display_type="plain"):
    color_dict = {"black":"30", "red":"31", "green":"32", "yellow":"33",
                    "blue":"34", "purple":"35","cyan":"36",  "white":"37"}
    display_type_dict = {"plain":"0","highlight":"1","underline":"4",
                "shine":"5","inverse":"7","invisible":"8"}
    s = str(obj)
    color_code = color_dict.get(color,"")
    display  = display_type_dict.get(display_type,"")
    # ANSI escape sequence: ESC[<display>;<color>m ... ESC[0m resets the style
    out = '\033[{};{}m'.format(display,color_code)+s+'\033[0m'
    return out 

class StepRunner:
    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
    
    def __call__(self, batch):
        features,labels = batch 
        
        #loss
        preds = self.net(features)
        loss = self.loss_fn(preds,labels)

        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
        all_preds = self.accelerator.gather(preds)
        all_labels = self.accelerator.gather(labels)
        # gather returns one loss per process; average them rather than summing
        all_loss = self.accelerator.gather(loss).mean()
            
        #metrics
        step_metrics = {self.stage+"_"+name:metric_fn(all_preds, all_labels).item() 
                        for name,metric_fn in self.metrics_dict.items()}
        
        return all_loss.item(),step_metrics

class EpochRunner:
    def __init__(self,steprunner):
        self.steprunner = steprunner
        self.stage = steprunner.stage
        self.steprunner.net.train() if self.stage=="train" else self.steprunner.net.eval()
        self.accelerator = self.steprunner.accelerator
        
    def __call__(self,dataloader):
        total_loss,step = 0,0
        loop = tqdm(enumerate(dataloader), 
                    total =len(dataloader),
                    file=sys.stdout,
                    disable=not self.accelerator.is_local_main_process,
                    ncols = 100
                   )
        
        for i, batch in loop: 
            if self.stage=="train":
                loss, step_metrics = self.steprunner(batch)
            else:
                with torch.no_grad():
                    loss, step_metrics = self.steprunner(batch)
                    
            step_log = dict({self.stage+"_loss":loss},**step_metrics)
            total_loss += loss
            step += 1
            
            if i!=len(dataloader)-1:
                loop.set_postfix(**step_log)
            else:
                epoch_loss = total_loss/step
                epoch_metrics = {self.stage+"_"+name:metric_fn.compute().item() 
                                 for name,metric_fn in self.steprunner.metrics_dict.items()}
                epoch_log = dict({self.stage+"_loss":epoch_loss},**epoch_metrics)
                loop.set_postfix(**epoch_log)
                for name,metric_fn in self.steprunner.metrics_dict.items():
                    metric_fn.reset()
        return epoch_log
    
class KerasModel(torch.nn.Module):
    def __init__(self,net,loss_fn,metrics_dict=None,optimizer=None,lr_scheduler = None):
        super().__init__()
        self.net,self.loss_fn = net, loss_fn
        self.metrics_dict = torch.nn.ModuleDict(metrics_dict) 
        self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(
            self.net.parameters(), lr=1e-3)
        self.lr_scheduler = lr_scheduler

    def forward(self, x):
        return self.net.forward(x)

    def fit(self, train_data, val_data=None, epochs=10,ckpt_path='checkpoint.pt',
            patience=5, monitor="val_loss", mode="min", mixed_precision='no'):
        
        accelerator = Accelerator(mixed_precision=mixed_precision)
        device = str(accelerator.device)
        device_type = '
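
In EpochRunner, each step's metric value comes from calling `metric_fn` on the gathered batch, while the epoch value comes from `metric_fn.compute()`, after which `metric_fn.reset()` clears the state. This resembles the torchmetrics call/compute/reset convention (an assumption: the listing itself only requires that each entry in `metrics_dict` be callable and expose `compute()` and `reset()`). The pure-Python class below is a hypothetical stand-in, not torchmetrics, showing why the epoch value is computed from accumulated counts instead of by averaging step values when batch sizes differ.

```python
class AccuracyAccumulator:
    """Hypothetical stand-in for a torchmetrics-style metric (call / compute / reset)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the accumulated state, as EpochRunner does after each epoch.
        self.correct, self.total = 0, 0

    def __call__(self, preds, labels):
        # Update the running counts and return the batch-level accuracy,
        # which plays the role of a step metric.
        batch_correct = sum(int(p == t) for p, t in zip(preds, labels))
        self.correct += batch_correct
        self.total += len(labels)
        return batch_correct / len(labels)

    def compute(self):
        # Epoch-level accuracy over every sample seen since the last reset.
        return self.correct / self.total


if __name__ == "__main__":
    acc = AccuracyAccumulator()
    step_values = [acc([1, 1, 1, 1], [1, 1, 1, 1]), acc([0], [1])]
    print(step_values)                           # [1.0, 0.0]
    print(sum(step_values) / len(step_values))   # 0.5 (mean of step values)
    print(acc.compute())                         # 0.8 (4 correct out of 5)
```

By contrast, `epoch_loss` in EpochRunner is `total_loss/step`, an unweighted mean of step losses, so in this analogy it behaves like the 0.5 figure rather than the 0.8 one when the last batch is smaller than the others.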


	
