Pytorch-神经网络中测试部分的编写

2019-11-17 22:05:57 浏览数 (1)

在进行pytorch训练后,需要进行测试部分的编写。

首先看一个train和test的波动实例

首先上图可视化结果来看,蓝线是train的正确率,随着运行次数的增加随之升高。而下图中的蓝线为train的loss过程,也随之降低。由图来看貌似训练过程良好,但实际被骗啦。

这是里面的over fitting在作怪,随着train的进行,里面的sample被其所记忆,导致构建的网络很肤浅,无法适应一些复杂的环境。

若想缓解这种情况,在train的同时做test。

由黄线test结果可看到,其总体趋势与train相一致,但呈现出的波动较大。但可明显注意到在上图的后半期test的正确率不再变化,且下图中的loss也很大。

总之,train过程并不是越多越好,而是取决于所采用的架构、函数、足够的数据才能取得较好的效果。

那么test部分该如何编写呢

本代码要实现一个验证的功能

原本要进行的cross entropy loss操作的结果,我们将Logits提出进行softmax操作,再进行argmax得到label,与cross entropy loss的结果进行验证查看正确与否。

代码语言:javascript复制
import torch
import torch.nn.functional as F

logits = torch.rand(4, 10)
# 先定义一个logits,物理意义为有4张图片,每张图片有10维的数据
pred = F.softmax(logits, dim=1)
# 这里在10维度的输出值上进行softmax,
pred_label = pred.argmax(dim=1)
print(pred_label)
logits_label = logits.argmax(dim=1)
print(logits_label)

输出为

代码语言:javascript复制
tensor([1, 4, 7, 0])
tensor([1, 4, 7, 0])

假设已知真实的label值,下面进行验证

代码语言:javascript复制
true_label = torch.tensor([4, 6, 7, 9])
# 假定真实的label值为4, 6, 7, 9
correct = torch.eq(pred_label, true_label)
# 使用.eq函数来计算其正确率
print(correct.sum().float().item()/4)

输出为

代码语言:javascript复制
0.25

正确率为25%,表明有一个预测正确。

那么何时使用test呢?

(1)train多个batch后进行一次test。或(2)每一个循环后进行一次test。

当具体到神经网络中时,变为

代码语言:javascript复制
test_loss = 0
correct = 0
# 先设定两个初始值均为0
for data, target in test_loader:
    data = data.view(-1, 28*28)
    data, target = data.to(device), target.to(device)
    logits = net(data)
    test_loss  = criteon(logits, target).item()
    
    pred = logits.argmax(dim=1)
    correct  = pred.eq(target).float().sum().item()
    
test_loss /= len(test_loss.dataset)

0 人点赞