机器学习-使用TF.learn识别手写的数字图像

2019-09-09 19:19:40 浏览数 (2)

背景介绍

我们今天要解决的问题是从MNIST数据集中分类手写数字,并且写一个简单的分类器,被认为是计算机视觉的Hello World。现在MNIST是一个多类别的分类问题。给出一个数字的图像,我们的工作将预测它是哪一个数字,我们使用Jputer Notebook编写相关代码。首先是介绍的内容的概述,展示如何下载数据集并可视化图像。接下来,我们将训练一个分类器,评估它,并用它来预测新的图像。然后我们将可视化分类器学习的权重获得对它如何在底层工作的直觉。让我们从安装TensorFlow开始,现在进入代码:

这意味着每个图像只包含一个数字。现在让我们谈谈我们将使用的功能。当我们处理图像时,我们使用原始像素作为要素。那是因为提取有用的功能从图像,如纹理和形状,很难。现在28乘28的图像有784像素,所以我们有784个特征。在这里,我们使用扁平表示图像:

平整图像意味着将其从2D阵列转换通过拆除行并将它们排成一行来形成一维数组。这就是为什么我们不得不重塑这个阵列先显示它。现在我们可以初始化分类器了,在这里,我们将使用线性分类器。我们将提供两个参数:第一个表示我们有多少个种类,并且有10个(0到9的手写数字),每种类型的数字一个:

第二个通知分类器关于我们将使用的特征。我们绘制了一个线性分类器的快速图表提供高级预览了解它的工作原理。你可以想到分类器加上图像的证据每种类型的数字。输入节点位于顶部,由Xes表示,输出节点位于Ys表示的底部。我们为图像中的每个要素或像素都有一个输入节点,每个数字一个输出节点图像可以代表。在这里,我们有784个输入和10个输出:

现在输入和输出完全连接,并且这些边缘中的每一个都具有权重:

当我们对图像进行分类时,您可以考虑每个像素正在进行一次干扰。首先,它流入其输入节点,然后,它沿着边缘移动。一路上,它乘以边缘的权重,并且输出节点收集证据我们正在分类的图像代表每种类型的数字。我们收集的证据越多,就八个输出而言,图像越有可能是8:

并计算我们有多少证据,我们将像素强度的值相乘按权重。然后我们可以预测图像属于输出证据最充分的节点。重要的部分是权重,过正确设置,我们可以获得准确的分类。

我们从随机权重开始,然后逐渐调整它们,为了更好的体现这发生在fit方法中。一旦我们有一个训练有素的模型,我们就可以对其使用evaluate方法它正确地分类了大约90%的测试集,我们还可以对单个图像进行预测。

现在我想告诉你如何可视化权重分类器学习。这里,正权重用红色绘制,负权重用蓝色绘制:

那么这些权重告诉我们什么呢?要理解这一点我们将展示四张数字为1的图片:

它们都略有不同,但看看中间的像素。请注意,它已填入每个图像。当填充该像素时,它就是证明我们正在看的图像是一个,所以我们期待在那条边:

现在让我们来看看四个零:

请注意,中间像素为空:

虽然有很多方法可以绘制零,如果填充了中间像素,这是反对图像为零的证据,所以我们期望在边缘有负权重。并且看着权重的图像,我们几乎可以看到绘制的数字的轮廓每个类别都是红色的。我们能够想象这些,因为我们开始了有784像素,我们学会了10个权重,一个对于每种类型的数字。然后我们将权重重塑为2D数组。

文中代码块
代码语言:javascript复制
#!/usr/bin/env python# coding: utf-8
# # 使用tf.contrib.learn训练预测MNIST数据集# # 此代码针对TensorFlow 0.10.0rc0进行了测试。# 这里是docker镜像的地址: https://hub.docker.com/r/tensorflow/tensorflow/
# In[22]:

import numpy as npimport matplotlib.pyplot as pltget_ipython().run_line_magic('matplotlib', 'inline')import tensorflow as tffrom matplotlib import rcParams#设置图表字体,防止中文乱码rcParams['font.family'] = 'Microsoft YaHei'rcParams['font.sans-serif'] = 'Microsoft YaHei'learn = tf.contrib.learntf.logging.set_verbosity(tf.logging.ERROR)

# ## 导入数据集
# In[23]:

mnist = learn.datasets.load_dataset('mnist')data = mnist.train.imageslabels = np.asarray(mnist.train.labels, dtype=np.int32)test_data = mnist.test.imagestest_labels = np.asarray(mnist.test.labels, dtype=np.int32)

# 这里有55000个样本正在进行训练, 10000个进行测试,您可能希望限制大小以更快地进行实验。
# In[24]:

max_examples = 10000data = data[:max_examples]labels = labels[:max_examples]

# ## 显示一些数字
# In[25]:

def display(i):    img = test_data[i]    plt.title('示例  %d. 标签为: %d' % (i, test_labels[i]))    plt.imshow(img.reshape((28,28)), cmap=plt.cm.gray_r)

# In[26]:

display(0)

# In[27]:

display(1)

# 这些数字清晰可见。
# In[28]:

display(8)

# 现在让我们来看看我们有多少特征。
# In[29]:

print(len(data[0]))

# ## 训练线性分类器# # 我们的目标是通过这个简单的分类器获得大约90%的准确度。有关这些如何工作的更多详细信息,请参阅https://www.tensorflow.org/versions/r0.10/tutorials/mnist/beginners/index.html#mnist-for-ml-beginners
# In[30]:

feature_columns = learn.infer_real_valued_columns_from_input(data)classifier = learn.LinearClassifier(feature_columns=feature_columns, n_classes=10)classifier.fit(data, labels, batch_size=100, steps=1000)

# ## 计算准确度
# In[31]:

classifier.evaluate(test_data, test_labels)print(classifier.evaluate(test_data, test_labels)["accuracy"])

# ## 可视化学习的权重# # # 让我们看看我们是否可以在TensorFlow Basic MNSIT中重现权重的图片 <a href="https://www.tensorflow.org/tutorials/mnist/beginners/index.html#mnist-for-ml-beginners">tutorial</a>.
# In[42]:

weights = classifier.get_variable_value("linear//weight/d/linear//weight/part_0/Ftrl_1")f, axes = plt.subplots(2, 5, figsize=(10,4))axes = axes.reshape(-1)for i in range(len(axes)):    a = axes[i]    a.imshow(weights.T[i].reshape(28, 28), cmap=plt.cm.seismic)    a.set_title(i)    a.set_xticks(()) # ticks be gone    a.set_yticks(())plt.show()

# # 扩展# # * TensorFlow Docker 镜像: https://hub.docker.com/r/tensorflow/tensorflow/ # * TF.Learn 快速学习指南: https://www.tensorflow.org/versions/r0.9/tutorials/tflearn/index.html# * MNIST 入门: https://www.tensorflow.org/tutorials/mnist/beginners/index.html# * 可视化 MNIST: http://colah.github.io/posts/2014-10-Visualizing-MNIST/# * TensorFlow Jupyter notebooks: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/docker/notebooks# * 更多关于linear classifiers分类器: https://www.tensorflow.org/versions/r0.10/tutorials/linear/overview.html#large-scale-linear-models-with-tensorflow# * 更多关于linear classifiers分类器: http://cs231n.github.io/linear-classify/# *  TF.Learn 示例: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/skflow

0 人点赞