文章目录
pytorch 图像分类实例《1》
代码语言:python
# -*- coding:utf-8 -*-
# /usr/bin/python
'''
@Author : Errol
@Describe:
@Env :
@Date : -
'''
import torch
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
# Prepare the CIFAR-10 dataset: fetch the train and test splits into ./data
# and build a shuffled mini-batch loader for the training split.
#
# Preprocessing pipeline, applied to every image:
#   1. transforms.ToTensor() converts each image to a tensor.
#   2. transforms.Normalize(...) normalizes each of the three channels with
#      mean 0.5 and std 0.5 (per torchvision docs: (x - mean) / std).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training split. download=True fetches the data into ./data if it is not
# already present there.
trainset = torchvision.datasets.CIFAR10(
    root = './data',
    train = True,
    transform = transform,
    download = True,
)

# Mini-batches of 4 samples, reshuffled each epoch, loaded by 2 worker
# subprocesses.
# NOTE(review): with num_workers > 0, platforms that start workers via
# "spawn" (e.g. Windows) re-import the main module; the code that iterates
# these loaders presumably needs an `if __name__ == "__main__":` guard —
# confirm against the rest of the script.
trainloader = torch.utils.data.DataLoader(
    dataset = trainset,
    batch_size = 4,
    shuffle = True,
    num_workers = 2,
)

# Held-out test split, using the same preprocessing as the training split.
testset = torchvision.datasets.CIFAR10(
    root = './data',
    train = False,
    transform = transform,
    download = True,
)
testloader = torch.utils.data.DataLoader(testse