PyTorch 实战(模型训练、模型加载、模型测试)

2022-05-12 08:54:27 浏览数 (1)

  • 本次将一个使用Pytorch的一个实战项目,记录流程:自定义数据集->数据加载->搭建神经网络->迁移学习->保存模型->加载模型->测试模型
  • 自定义数据集 参考我的上一篇博客:自定义数据集处理
  • 数据加载 默认小伙伴有对深度学习框架有一定的了解,这里就不做过多的说明了。 好吧,还是简单的说一下吧: 我们在做好了自定义数据集之后,其实数据的加载和MNSIT 、CIFAR-10 、CIFAR-100等数据集的都是相似的,过程如下所示:
代码语言:txt复制
* 导入必要的包import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader加载数据
可以发现和MNIST  、CIFAR的加载基本上是一样的train_db = Pokemon('pokeman', 224, mode='train')
val_db = Pokemon('pokeman', 224, mode='val')
test_db = Pokemon('pokeman', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
代码语言:txt复制
def evalute(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct  = torch.eq(pred, y).sum().float().item()
    return correct / total
def main():
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            model.train()
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step  = 1
        if epoch % 1 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                viz.line([val_acc], [global_step], win='val_acc', update='append')
    print('best acc:', best_acc, 'best epoch:', best_epoch)
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')
    test_acc = evalute(model, test_loader)
  • 迁移学习 提升模型的准确率:
代码语言:txt复制
    # model = ResNet18(5).to(device)
    trained_model=resnet18(pretrained=True)  # 此时是一个非常好的model
    model = nn.Sequential(*list(trained_model.children())[:-1],  # 此时使用的是前17层的网络 0-17  *:随机打散
                          Flatten(),
                          nn.Linear(512,5)
                          ).to(device)
    # x=torch.randn(2,3,224,224)
    # print(model(x).shape)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
  • 保存、加载模型 pytorch保存模型的方式有两种: 第一种:将整个网络都都保存下来 第二种:仅保存和加载模型参数(推荐使用这样的方法)
代码语言:txt复制
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
代码语言:txt复制
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

可以看到这是我保存的模型:

代码语言:txt复制
其中best.mdl是第二中方法保存的
代码语言:txt复制
model.pkl则是第一种方法保存的
在这里插入图片描述在这里插入图片描述
  • 测试模型 这里是训练时的情况 在这里插入图片描述在这里插入图片描述 看这个数据准确率还是不错的,但是还是需要实际的测试这个模型,看它到底学到东西了没有,接下来简单的测试一下:
代码语言:txt复制
import torch
from PIL import Image
from torchvision import transforms
device = torch.device('cuda')
transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],
                                 std=[0.229,0.224,0.225])
                            ])
def prediect(img_path):
    net=torch.load('model.pkl')
    net=net.to(device)
    torch.no_grad()
    img=Image.open(img_path)
    img=transform(img).unsqueeze(0)
    img_ = img.to(device)
    outputs = net(img_)
    _, predicted = torch.max(outputs, 1)
    # print(predicted)
    print('this picture maybe :',classes[predicted[0]])
if __name__ == '__main__':
    prediect('./test/name.jpg')

实际的测试结果

在这里插入图片描述在这里插入图片描述
在这里插入图片描述在这里插入图片描述

效果还是可以的,完整的代码:

https://github.com/huzixuan1/Loader_DateSet

数据集下载链接:

https://pan.baidu.com/s/12-NQiF4fXEOKrXXdbdDPCg

由于笔者能力水平有限,在表述上可能有些不准确;有问题可以联系在这里插入图片描述在这里插入图片描述

0 人点赞