下面开始加入test部分
先写入test部分代码
代码语言:javascript复制for x, label in cifar_test:
x, label = x.to(device), label.to(device)
logits = model(x)
pred = logits.armax(dim=1)
# 用argmax选出可能性最大的值的索引
为进行比对
定义正确率
写入对比
代码语言:javascript复制total_correct = torch.eq(pred, label).float().sum().item()
# torch.eq函数用于对比,同时要转为numpy数据
total_num = x.size(0)
再定义正确率并输出
代码语言:javascript复制acc = total_correct / total_num
print('acc:', acc)
可以加入模式切换
Model.train()和model.eval()
最终main.py文件为
代码语言:javascript复制import torch
from torchvision import datasets
# 引入pytorch、datasets工具包
from torchvision import transforms
# 引入数据变换工具包
from torch.utils.data import DataLoader
# 多线程数据读取
from LeNet5 import LeNet5
import torch.nn as nn
import torch.optim as optim
def main():
batchsz=32
# 这个batch_size数值不宜太大也不宜过小
cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
transforms.Resize((32, 32)),
# .Compose相当于一个数据转换的集合
# 进行数据转换,首先将图片统一为32*32
transforms.ToTensor()
# 将数据转化到Tensor中
]), download=True)
# 直接在datasets中导入CIFAR10数据集,放在"cifar"文件夹中
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
# 按照其要求,这里的参数需要有batch_size,
# 在该部分代码前面定义batch_size
# 再使数据加载的随机化
cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
]), download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
x, label = iter(cifar_train).next()
# 通过.iter方法输出一个数据进行查看
# print('s.shape:', x.shape, 'label.shape:', label.shape)
# 输出shape进行查看
device = torch.device('cuda')
model = LeNet5().to(device)
criteon = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print(model)
model.train()
for epoch in range(1000):
for batchidx, (x, label) in enumerate(cifar_train):
# batchidx代表了有多少个batch,
x, label = x.to(device), label.to(device)
logits = model(x)
loss = criteon(logits, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print(epoch, loss.item())
model.eval()
total_correct = 0
total_num = 0
for x, label in cifar_test:
x, label = x.to(device), label.to(device)
logits = model(x)
pred = logits.argmax(dim=1)
# 用argmax选出可能性最大的值的索引
# 进行比对
total_correct = torch.eq(pred, label).float().sum().item()
# torch.eq函数用于对比,同时要转为numpy数据
total_num = x.size(0)
acc = total_correct / total_num
print('acc:', acc)
输出为
可以看出正确率在逐渐上升