PyTorch入门笔记-手写数字实战02

2021-03-28 22:42:39 浏览数 (1)

构建模型

自定义一个模型可以通过继承 torch.nn.Moudle 类来实现,在 __init__ 构造函数中来定义声明模型中的各个层,在 forward 方法中构建各个层的连接关系实现模型前向传播的过程。在 PyTorch 这种高级的深度学习框架中帮我们实现了很多常见的网络层以及激活函数。PyTorch 中的网络层通常在 torch.nn 包下,而激活函数通常在 torch.nn.functional 包下。

代码语言:txt复制
from torch import nn
from torch.nn import functional as F

# 2. 构建模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # 使用PyTorch提供的Linear线性层
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
        
    def forward(self, X):
        # X: [batch_size, 1, 28, 28]
        # H1 = relu(XW1   b1)
        X = F.relu(self.fc1(X))
        # H2 = relu(H1W2   b2)
        X = F.relu(self.fc2(X))
        # H3 = H2W3   b3
        X = self.fc3(X)
        
        return X

此时构建的是一个四层的全连接神经网络,由于我们将 (28 x 28) 的手写数字图片像素矩阵打平成了 (784, ) 的特征向量,并且将对应的数字标签转换成了 one-hot 十个维度的向量,因此全连接神经网络的输入层和输出层的节点数都是固定的分别为 784 和 10。「通常我们把输入层和输出层之外的层称为隐藏层,而隐藏层的层数以及每一层的节点个数都是需要我们人为指定的超参数。」

训练模型

代码语言:txt复制
import utils # 加载我们自己写的工具类
from torch import optim # 其中包含各种优化算法

epochs = 3 # 迭代的轮数
train_loss = [] # 用于存储训练过程中的损失值,方便可视化

net = Net() # 实例化模型
# SGD随机梯度下降法
optimizer = optim.SGD(net.parameters(), lr = 0.01, momentum = 0.9)

for epoch in range(epochs): # 对整个数据集迭代3遍
    # 每一次for循环都会获取batch_size个样本
    for batch_idx, (X_train, label_train) in enumerate(train_loader):
        # X_train.shape = torch.Size([batch_size, 1, 28, 28])
        # label_train.shape = torch.Size([batch_size])
        # 根据前面的学习,需要:
        #   1. X_train打平成[batch_size, 784](X_train.shape[0] = batch_size)
        X_train = X_train.reshape(X_train.shape[0], 28 * 28)
        #   2. label_train数字编码转换为one_hot编码
        y_train = utils.one_hot(label_train)

        # 前向传播过程
        # X_train: [batch_size, 784] -> out: [batch_size, 10]
        out = net(X_train)

        # 计算当前损失值,由于输出节点没有使用任何激活函数
        # 因此使用简单的均方差MSE
        loss = F.mse_loss(out, y_train)
        # 由于PyTorch会把计算的梯度值进行累加,因此每次循环需要将梯度值置为0
        optimizer.zero_grad()
        # 计算loss关于net.parameters()的梯度
        loss.backward()
        # 使用梯度下降算法更新net.parameters()参数值:
        #   theta' = theta - 学习率 * 梯度
        optimizer.step()

        train_loss.append(loss) # 为了可视化,保存当前loss值

        if batch_idx % 10 == 0: # 每隔10个batch打印一次
            print("epoch: ", epoch, "batch_idx: ", batch_idx, "loss: ", loss.item())
代码语言:txt复制
epoch:  0 batch_idx:  0 loss:  0.12745174765586853
epoch:  0 batch_idx:  10 loss:  0.09659640491008759
...
epoch:  2 batch_idx:  100 loss:  0.03306906297802925
epoch:  2 batch_idx:  110 loss:  0.03607671707868576

每一步都有非常详细的代码注释,这里不再过多赘述。「torch.optim 包中实现了各种优化算法,SGD 是随机梯度下降法」简单来看看 torch.optim.SGD(net.parameters(), lr = 0.01, momentum = 0.9) 中的三个参数:

  • net.parameters():模型网络中的所有待优化参数,由于使用 PyTorch 提供的 Linear 层,其中的优化参数都为我们定义好了。如果使用我们自己定义的层,需要在定义待优化参数的时候将 required_grad 参数指定为 True;
  • lr = 0.01:指定学习率为 0.01;
  • momentum = 0.9:动量因子,简单来说给梯度一个冲量帮助跳出局部极小值点或者一些梯度等于 0 的点。具体可以看推荐阅读中的文章;

为了可视化将训练过程中的 loss 值保存在 train_loss 列表中,只需要调用我们自己实现的工具类中的 utils.plot_curve(train_loss) 方法即可绘制训练过程中的 loss 值曲线。

代码语言:txt复制
utils.plot_curve(train_loss)

测试模型

接下来用测试集来对训练好的模型进行评估。评估模型非常简单,只需要将测试集中的手写数字图片矩阵打平之后输入到训练好的模型中,对于每个测试集样本,模型都会输出一个十维的向量,使用 argmax 方法输出十维向量 10 个值中最大值所在位置的索引。

代码语言:txt复制
# 使用测试集来评估训练好的模型
total_correct = 0
for X_test, label_test in test_loader:
    # 打平X_test
    X_test = X_test.reshape(X_test.shape[0], 28 * 28)
    out = net(X_test) # 训练好的模型前向传播过程
    # 获取out中[batch_size, 10]中10个值的最大值所在位置的索引
    # out:[batch_size, 10] => pred: [b]
    pred = out.argmax(dim = 1)
    # 获取预测正确的样本个数,由于pred为最大值所在位置的索引,
    # 因此不需要将label_test转换为one_hot编码
    # 当tensor为标量的时候,tensor.item()可以将其转换为ndarray类型
    correct = pred.eq(label_test).sum().float().item()
    total_correct  = correct # 所有预测正确的样本个数

# 获取整个测试集的样本个数
total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("test acc: ", acc)

最终模型的准确率为:0.8837。仅仅是一个四层网络就能够达到 88% 的准确率,可见深度学习的强大,当然这并不是 MNIST 手写数字识别所能达到的最高准确率,我们可以调整神经网络的层数、每层的神经元个数或者使用卷积神经网络等等以达到更高的准确率。

References: 1. 龙良曲深度学习与PyTorch入门实战:https://study.163.com/course/introduction/1208894818.htm

原文地址:https://mp.weixin.qq.com/s/svZd71Y2-G8qwGLviJ3YIQ

0 人点赞