文章目录
pytorch 搭建BP网络
代码语言:python
# -*- coding:utf-8 -*-
# /usr/bin/python
'''
@Author : Errol
@Describe: Build a BP (back-propagation) network with PyTorch on the MNIST dataset
@Env :
@Date : -
'''
import os
import torch
import numpy as np
from torchvision.datasets import mnist
from torch import nn
from torch.autograd import Variable
# Data standardization
def data_std(x, *, mean=0.5, std=0.5):
    """Convert image data to a flat, standardized float32 tensor.

    Intended as the ``transform`` for the MNIST datasets in this script.
    Pixel values are scaled from [0, 255] to [0, 1], standardized as
    ``(x - mean) / std`` (with the defaults this maps [0, 1] onto [-1, 1]),
    and finally flattened to one dimension.

    Args:
        x: Image data accepted by ``np.array`` (e.g. a nested sequence of
            pixel values). The division by 255 assumes 8-bit pixel values.
        mean: Value subtracted after scaling to [0, 1]. Keyword-only so the
            single-argument transform call keeps working unchanged.
        std: Divisor applied after subtracting ``mean``. Keyword-only.

    Returns:
        A 1-D ``torch.float32`` tensor with one element per input pixel.
    """
    # np.array copies, so torch.from_numpy gets a writable, owned buffer.
    x = np.array(x, dtype='float32') / 255  # scale to [0, 1]
    x = (x - mean) / std  # standardize; Python floats keep the float32 dtype
    x = x.reshape((-1,))  # flatten to 1-D
    return torch.from_numpy(x)
# Data preparation: both MNIST splits live under ./data (downloaded if not
# already present) and every image is passed through data_std.
train_set = mnist.MNIST(root='./data', train=True, download=True, transform=data_std)
test_set = mnist.MNIST(root='./data', train=False, download=True, transform=data_std)

# Show a summary of each split together with its Python type.
print('train_set', train_set, type(train_set))
print('test_set', test_set, type(test_set))
a,a_label = train_