Building a BP Network with PyTorch

2021-01-14 11:47:10

Building a BP network with PyTorch

#!/usr/bin/python
# -*- coding:utf-8 -*-
'''
@Author  :  Errol 
@Describe:  
@Env     :  
@Date    :   - 
'''
import os
import torch
import numpy as np
from torchvision.datasets import mnist
from torch import nn
from torch.autograd import Variable  # deprecated since PyTorch 0.4: plain tensors track gradients directly

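# Data normalization: the transform applied to every MNIST image
def data_std(x):
    x = np.array(x, dtype='float32') / 255  # scale pixel values from [0, 255] to [0, 1]
    x = (x - 0.5) / 0.5  # normalize to [-1, 1]
    x = x.reshape((-1,))  # flatten the 28x28 image into a 784-element vector
    x = torch.from_numpy(x)
    return x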

# Data preparation: download MNIST into ./data and apply data_std to every image
train_set = mnist.MNIST('./data', train=True, transform=data_std, download=True)
test_set = mnist.MNIST('./data', train=False, transform=data_std, download=True)
print('train_set', train_set, type(train_set))
print('test_set', test_set, type(test_set))
a,a_label = train_
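The transform `data_std` above first rescales each pixel from [0, 255] to [0, 1], then applies (x - 0.5) / 0.5, which maps that range onto [-1, 1], and finally flattens the 28×28 image into 784 values so it can feed a fully connected layer. The sketch below checks that arithmetic on a hypothetical synthetic image `img` rather than real MNIST data. It also assumes a NumPy-only copy named `data_std_numpy` (a hypothetical name) that leaves out the final `torch.from_numpy` call, so it runs without PyTorch.

```python
import numpy as np

def data_std_numpy(x):
    # Hypothetical NumPy-only copy of data_std (no conversion to a torch tensor)
    x = np.array(x, dtype='float32') / 255   # scale pixels to [0, 1]
    x = (x - 0.5) / 0.5                      # shift and scale to [-1, 1]
    return x.reshape((-1,))                  # flatten 28x28 -> 784

# Hypothetical 28x28 grayscale image: all black except one white pixel
img = np.zeros((28, 28), dtype=np.uint8)
img[0, 0] = 255

v = data_std_numpy(img)
print(v.shape)       # (784,)
print(v[0], v[1])    # 1.0 -1.0
```

A black pixel (0) ends up at -1.0 and a white pixel (255) at 1.0, so every input feature is centered around zero.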
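The listing stops after loading the data, before any network is defined. A BP (back-propagation) network here means a stack of fully connected layers trained by back-propagating the loss gradient. The following is a minimal sketch of how such a network could be built and trained with `nn`. It is not the author's code. The layer widths (784 → 256 → 64 → 10), learning rate, batch size, and epoch count are illustrative assumptions. So is the synthetic `TensorDataset` standing in for `train_set`, which keeps the example free of any download.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in for train_set: 256 random "images" already flattened
# to 784 values in [-1, 1] (the format data_std produces), with random labels 0..9.
features = torch.rand(256, 784) * 2 - 1
labels = torch.randint(0, 10, (256,))
train_data = DataLoader(TensorDataset(features, labels), batch_size=64, shuffle=True)

# Fully connected BP network: 784 inputs -> two hidden layers -> 10 class scores.
# The layer widths are illustrative assumptions.
net = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

criterion = nn.CrossEntropyLoss()  # takes raw scores (logits) and integer class labels
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

def train_one_epoch():
    total = 0.0
    for x, y in train_data:
        out = net(x)              # forward pass
        loss = criterion(out, y)
        optimizer.zero_grad()     # clear gradients left from the previous step
        loss.backward()           # back-propagation: compute gradients of the loss
        optimizer.step()          # gradient-descent update of the weights
        total += loss.item() * x.size(0)
    return total / len(train_data.dataset)  # mean loss over the epoch

losses = [train_one_epoch() for _ in range(30)]
print(losses[0], losses[-1])
```

With the real data, `train_set` can be passed to `DataLoader` directly in place of the synthetic dataset. `data_std` already returns flat 784-element tensors, and the default collate function stacks the integer labels into a tensor. No `Variable` wrapping is needed on PyTorch 0.4 or later.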
