注:本文选自机械工业出版社出版的《从零开始构建深度前馈神经网络(Python TensorFlow 2.x)》一书,略有改动。经出版社授权刊登于此。
MNIST是经典的手写数字(handwritten digits)图像数据集。其中,训练数据集(training set,简称训练集)包含60 000个样本,测试数据集(test set,简称测试集)包含10 000个样本。
图1展示了MNIST训练集的前15个样本。每幅图像代表一个手写数字,每个方框下方的数字是这个图像对应的标签(label)。
一幅图像及其对应的标签构成了一个输入/输出对,例如,图1左上角的图像与其正下方的5构成了一个输入/输出对,我们把这个输入/输出对称为一个样本(sample/example)。输入通常由特征向量(feature vector)表示。例如,图1左上角的图像的原始数据是一个784维的特征向量。
图1 MNIST训练集的前15个样本
本章将训练一个k-NN模型,其输入是784维的特征向量,输出为相应标签的预测值,即,给定任意一个表示手写数字的784维向量,预测它是0~9中的哪一个。
使用TensorFlow加载MNIST
先来看一段示例代码:
%matplotlib inline
import matplotlib.pyplot as plt
from tensorflow import keras
def ds_imshow(im_data, im_label):
plt.figure(figsize=(10,10))
for i in range(len(im_data)):
plt.subplot(5,5,i 1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(im_data[i], cmap=plt.cm.binary)
plt.xlabel(im_label[i])
plt.show()
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
ds_imshow(x_train[:15].reshape((15,28,28)), y_train[:15])
上面的代码在导入必要的模块后定义了一个名为ds_imshow()的函数,然后加载数据集,最后将加载的数据作为参数并调用ds_imshow()函数显示图像。
其中,%matplotlib inline需要在新建Notebook后且首次调用plt.show()之前运行,仅需运行一次即可作用于整个Notebook。
ds_imshow()函数将传入的NumPy数组显示为图像,参数im_data用于接收图像数组,每幅图像表示一个样本特征,im_label是与之对应的标签。
keras.datasets.mnist.load_data()方法用于加载数据集,首次运行时需要用十几秒到几分钟的时间进行远程下载,再次使用时将从本地加载。
数组x_train表示训练集中60 000个像素为28×28的手写数字灰度图像,y_train表示与之对应的标签集合;x_test表示测试集中10 000个像素为28×28的手写数字灰度图像,y_test表示与之对应的标签集合。
示例中的最后一行代码是调用ds_imshow()函数将训练集中的前15个样本绘制为图像,并在每幅图像的正下方显示与之对应的标签。例如,y_train[0]为5,表示与之对应的x_train[0]是手写数字5的灰度图像,即位于图2左上角的样本。
Keras默认是将数据集文件(mnist.npz)存储在用户家目录下的.kerasdatasets中。在Windows运行窗口中输入以下命令,如图2所示。
%HOMEPATH%.kerasdatasets
回车或单击OK按钮即可以查看该目录。
使用scikit-learn加载MNIST
与keras.datasets.mnist.load_data()方法类似,scikit-learn也提供了加载MNIST数据集的方法,通过以下代码可以导入datasets模块。
from sklearn import datasets
以下两行代码用于加载MNIST数据集,并将数据集中的前15个样本绘制为图像:
mnist = datasets.load_digits()
ds_imshow(mnist.data[:15].reshape((15,8,8)), mnist.target[:15])
程序运行结果如图3所示。
细心的读者可能已经发现了一个问题,MNIST的每个样本的像素是28×28,而代码中却将每个样本调整为(8,8)。这是因为datasets.load_digits()加载的样本像素并非是28×28,而是8×8,所以图像显得比较模糊。
尽管scikit-learn也提供了以下方法用于加载28×28像素版本的MNIST:
from sklearn.datasets import fetch_openml
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=
False)
但是偶尔会遭遇加载缓慢甚至失败。因此建议读者使用keras.datasets.mnist.load_data()方法加载28×28像素版本的MNIST。
图3 运行结果