PyTorch入门笔记-手写数字实战01

2021-03-28 22:37:50 浏览数 (1)

下面来简单回顾上一小节的嵌套非线性模型:

  • H_1 = relu(XW_1 b_1)
  • H_2 = relu(H1W_2 b_2)
  • H_3 = f(H_2W_3 b_3), 模型最后一层的激活函数不会是 relu 激活函数,需要根据你的具体任务来选择合适的激活函数。比如使用二分类的 Sigmoid 或多分类的 SoftMax(当然多个二分类也可以用于处理多分类)。由于这里只是简单的演示整个训练流程,所以为了简单本小节最后一层不添加任何激活函数。

对 MNIST 手写数字识别进行分类大致分为四个步骤,这四个步骤也是训练大多数深度学习模型的基本步骤:

  • 加载数据集(Load data)
  • 构建模型(Build Model)
  • 训练(Train)
  • 测试(Test)

不过在这之前我们需要构建一个 utils.py 文件,其中包含着三个工具方法:

  • plot_curve(loss_list) 方法绘制损失函数曲线;
  • plot_image(x, label, name)方法显示 6 张手写数字图片以及对应的数字标签;
  • one_hot(label, depth = 10)方法将 0~9 的数字编码标签转换为 one-hot 编码的标签。比如将数字编码 5 转换为 one-hot 编码为 [0,0,0,0,1,0,0,0,0,0](由于此时假设为十个类别,因此 one-hot 编码后的向量维度为 10 维)。
代码语言:txt复制
import torch
from matplotlib import pyplot as plt

def plot_curve(loss_list):
    """
    根据存放loss值的列表绘制曲线
    """
    plt.plot(range(len(loss_list)), loss_list, color = 'blue')
    # 添加图例并放置在右上角
    plt.legend(['train_loss'], loc = 'upper right')
    plt.xlabel('step') # 设置横坐标轴名称
    plt.ylabel('train_loss') # 设置纵坐标轴名称
    plt.show()

def plot_image(x, label, name):
    """
    显示6张手写数字图片以及对应的数字标签
    """
    for i in range(6):
        plt.subplot(2, 3, i   1)
        plt.tight_layout()
        plt.imshow(x[i][0] * 0.3081   0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(label, depth = 10):
    '''
    将数字编码标签label转换为one-hot编码y
    '''
    y = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    y.scatter_(dim = 1, index = idx, value = 1)
    return y

加载数据集

MNIST 是比较重要和经典的数据集,目前常用的机器学习和深度学习框架都内置了 MNIST 数据集,通过几行代码就可以自动下载、管理以及加载 MNIST 数据集。基于 PyTorch 有很多工具集,比如:处理自然语言的 torchtext,处理音频的 torchaudio 和 处理图像视频的 torchvision,这些工具集可以独立于 PyTorch 的使用。MNIST 数据集属于图像,我们可以在 torchvision.datasets 包中加载 MNIST。「加载的 MNIST 数据集是 ndarray 数组类型,因此我们需要将其转换成 Tensor。实验证明输入数据在 0 附近均匀分布,神经网络模型会有所提升(在本小节的神经网络模型架构下,对数据进行标准化准确率能够提升 10%),因此我们还需要对 MNIST 数据集进行标准化的转换,torchvision.transforms 包提供了这些转换方法。」

代码语言:txt复制
import torchvision

train_data = torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ]))

print(len(train_data)) # 60000
# 训练集中的第1张手写数字图片以及对应的标签
X_train_0, label_train_0 = train_data[0]
print(X_train_0.shape) # torch.Size([1, 28, 28])
print(label_train_0) # 5

在 torchvision.datasets 中有很多类似 MNIST 的数据集,下面来简单介绍 torchvision.datasets.MNIST 中的一些参数:

  • 'mnist_data':MNIST 数据集所在的文件夹,我直接设置在当前路径。如果你也传入 'mnist_data',你会在当前路径下发现一个 mnist_data 的文件夹;
  • train = True:可选参数。如果设置为 True,则从 ./mnist_data/MNIST/processed/training.pt 中加载训练集(使用 len(train_data) 可以看出共有 60000 张手写数字图片)。如果设置为 False,则从 ./mnist_data/MNIST/processed/test.pt 中加载测试集;
  • download = True:可选参数。如果设置为 True,且路径下没有 MNIST 数据集,则会从网络上下载 MNIST 数据集,如果路径下已经存在 MNIST 数据集,则不会再次下载;
  • transform = torchvision.transforms.Compose:transform 进行数据的预处理操作:
    • ToTensor:将 ndarray 数组转换为 Tensor 数据类型;
    • Normalize:进行数据的标准化,即减去均值除以方差,此时均值 0.1307 和方差 0.3081 是 MNIST 数据集计算好的数据,直接使用即可;

加载完了 MNIST 数据集中的训练集,我们可以设置 train = False 来加载 10000 张测试集。

代码语言:txt复制
import torchvision

test_data = torchvision.datasets.MNIST('mnist_data', train = False, download = True,
                                       transform=torchvision.transforms.Compose([
                                           torchvision.transforms.ToTensor(),
                                           torchvision.transforms.Normalize(
                                               (0.1307,), (0.3081,))
                               ]))

print(len(test_data)) # 10000
# 测试集中的第1张手写数字图片以及对应的标签
X_test_0, label_test_0 = test_data[0]
print(X_test_0.shape) # torch.Size([1, 28, 28])
print(label_test_0) # 7

至此 60000 张训练集以及 10000 张测试集都加载进来了,不过我们通常使用更为方便的数据集加载器 DataLoader,DataLoader 结合了数据集和取样器,提供了多个线程处理数据集,并且里面提供了很多方便处理数据集的功能。DataLoader 在 torch.utils.data 包下。

代码语言:txt复制
import torch
import utils # 加载我们自己写的工具类

batch_size = 512

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size = batch_size, # batch_size
                                           shuffle = True) # 是否打乱数据集
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size = batch_size,
                                          # 测试集只用于验证模型性能不需要打乱数据集
                                          shuffle = False) 
# 迭代器加载数据集,每次都加载batch_size个
# X: [batch_size, channel, width, hight]
# label: 数字编码
X, label = next(iter(train_loader))
print(X.shape, label.shape, X.min(), label.max())
utils.plot_image(X, label, 'image sample')
代码语言:txt复制
torch.Size([512, 1, 28, 28]) torch.Size([512]) tensor(-0.4242) tensor(9)

References: 1. 龙良曲深度学习与PyTorch入门实战:https://study.163.com/course/introduction/1208894818.htm

原文地址:https://mp.weixin.qq.com/s/JTMcPCUL-F8kd3CRnUvOUg

0 人点赞