目录
0.编程环境
1、下载并解压数据集
2、完整代码
3、数据准备
4、数据观察
4.1 查看变量mnist的方法和属性
4.2 对比三个集合
4.3 mnist.train.images观察
4.4 查看手写数字图
5、搭建神经网络
6、变量初始化
7、模型训练
9、模型测试
MNIST是Mixed National Institue of Standards and Technology database的简称,中文叫做美国国家标准与技术研究所数据库。
0.编程环境
安装tensorflow命令:pip install tensorflow
操作系统:Win10
python版本:3.6
集成开发环境:jupyter notebook
tensorflow版本:1.6
1、下载并解压数据集
MNIST数据集下载链接: https://pan.baidu.com/s/1fPbgMqsEvk2WyM9hy5Em6w 密码: wa9p 下载压缩文件MNIST_data.rar完成后,选择解压到当前文件夹,不要选择解压到MNIST_data。 文件夹结构如下图所示:
2、完整代码
此章给读者能够直接运行的完整代码,使读者有编程结果的感性认识。 如果下面一段代码运行成功,则说明安装tensorflow环境成功。 想要了解代码的具体实现细节,请阅读后面的章节。
代码语言:javascript复制import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
X_holder = tf.placeholder(tf.float32)
y_holder = tf.placeholder(tf.float32)
Weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([1,10]))
predict_y = tf.nn.softmax(tf.matmul(X_holder, Weights) biases)
loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
session = tf.Session()
init = tf.global_variables_initializer()
session.run(init)
for i in range(500):
images, labels = mnist.train.next_batch(batch_size)
session.run(train, feed_dict={X_holder:images, y_holder:labels})
if i % 25 == 0:
correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
accuracy_value = session.run(accuracy, feed_dict={X_holder:mnist.test.images, y_holder:mnist.test.labels})
print('step:%d accuracy:%.4f' %(i, accuracy_value))
上面一段代码的运行结果如下:
Extracting MNIST_datatrain-images-idx3-ubyte.gz Extracting MNIST_datatrain-labels-idx1-ubyte.gz Extracting MNIST_datat10k-images-idx3-ubyte.gz Extracting MNIST_datat10k-labels-idx1-ubyte.gz step:0 accuracy:0.4747 step:25 accuracy:0.8553 step:50 accuracy:0.8719 step:75 accuracy:0.8868 step:100 accuracy:0.8911 step:125 accuracy:0.8998 step:150 accuracy:0.8942 step:175 accuracy:0.9050 step:200 accuracy:0.9026 step:225 accuracy:0.9076 step:250 accuracy:0.9071 step:275 accuracy:0.9049 step:300 accuracy:0.9055 step:325 accuracy:0.9101 step:350 accuracy:0.9097 step:375 accuracy:0.9116 step:400 accuracy:0.9102 step:425 accuracy:0.9113 step:450 accuracy:0.9155 step:475 accuracy:0.9151
从上面的运行结果可以看出,经过500步训练,模型准确率到达0.9151左右。
3、数据准备
代码语言:javascript复制import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batch_size = 100
X_holder = tf.placeholder(tf.float32)
y_holder = tf.placeholder(tf.float32)
第1行代码导入warnings库,第2行代码表示不打印警告信息;
第3行代码导入tensorflow库,取别名tf;
第4行代码人从tensorflow.examples.tutorials.mnist库中导入input_data文件;
本文作者使用anaconda集成开发环境,input_data文件所在路径:C:ProgramDataAnaconda3Libsite-packagestensorflowexamplestutorialsmnist
,如下图所示:
第6行代码调用input_data文件的read_data_sets方法,需要2个参数,第1个参数的数据类型是字符串,是读取数据的文件夹名,第2个关键字参数ont_hot数据类型为布尔bool,设置为True,表示预测目标值是否经过One-Hot编码; 第7行代码定义变量batch_size的值为100; 第8、9行代码中placeholder中文叫做占位符,将每次训练的特征矩阵X和预测目标值y赋值给变量X_holder和y_holder。
4、数据观察
本章内容主要是了解变量mnist中的数据内容,并掌握变量mnist中的方法使用。
4.1 查看变量mnist的方法和属性
代码语言:javascript复制dir(mnist)[-10:]
上面一段代码的运行结果如下:
['_asdict', '_fields', '_make', '_replace', '_source', 'count', 'index', 'test', 'train', 'validation']
为了节省篇幅,只打印最后10个方法和属性。 我们会用到的是其中test、train、validation这3个方法。
4.2 对比三个集合
train对应训练集,validation对应验证集,test对应测试集。 查看3个集合中的样本数量,代码如下:
代码语言:javascript复制print(mnist.train.num_examples)
print(mnist.validation.num_examples)
print(mnist.test.num_examples)
上面一段代码的运行结果如下:
55000 5000 10000
对比3个集合的方法和属性
从上面的运行结果可以看出,3个集合的方法和属性基本相同。 我们会用到的是其中images、labels、next_batch这3个属性或方法。
4.3 mnist.train.images观察
查看mnist.train.images的数据类型和矩阵形状。
代码语言:javascript复制images = mnist.train.images
type(images), images.shape
上面一段代码的运行结果如下:
(numpy.ndarray, (55000, 784))
从上面的运行结果可以看出,在变量mnist.train中总共有55000个样本,每个样本有784个特征。
原图片形状为28*28,28*28=784
,每个图片样本展平后则有784维特征。
选取1个样本,用3种作图方式查看其图片内容,代码如下:
import matplotlib.pyplot as plt
image = mnist.train.images[1].reshape(-1, 28)
plt.subplot(131)
plt.imshow(image)
plt.axis('off')
plt.subplot(132)
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.subplot(133)
plt.imshow(image, cmap='gray_r')
plt.axis('off')
plt.show()
上面一段代码的运行结果如下图所示:
从上面的运行结果可以看出,调用plt.show方法时,参数cmap指定值为gray或gray_r符合正常的观看效果。
4.4 查看手写数字图
从训练集mnist.train中选取一部分样本查看图片内容,即调用mnist.train的next_batch方法随机获得一部分样本,代码如下:
代码语言:javascript复制import matplotlib.pyplot as plt
import math
import numpy as np
def drawDigit(position, image, title):
plt.subplot(*position)
plt.imshow(image.reshape(-1, 28), cmap='gray_r')
plt.axis('off')
plt.title(title)
def batchDraw(batch_size):
images,labels = mnist.train.next_batch(batch_size)
image_number = images.shape[0]
row_number = math.ceil(image_number ** 0.5)
column_number = row_number
plt.figure(figsize=(row_number, column_number))
for i in range(row_number):
for j in range(column_number):
index = i * column_number j
if index < image_number:
position = (row_number, column_number, index 1)
image = images[index]
title = 'actual:%d' %(np.argmax(labels[index]))
drawDigit(position, image, title)
batchDraw(196)
plt.show()
上面一段代码的运行结果如下图所示,本文作者对难以辨认的数字做了红色方框标注:
5、搭建神经网络
代码语言:javascript复制Weights = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([1,10]))
predict_y = tf.nn.softmax(tf.matmul(X_holder, Weights) biases)
loss = tf.reduce_mean(-tf.reduce_sum(y_holder * tf.log(predict_y), 1))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
该神经网络只有输入层和输出层,没有隐藏层。
第1行代码定义形状为784*10
的权重矩阵Weights;
第2行代码定义形状为1*10
的偏置矩阵biases;
第3行代码定义先通过矩阵计算,再使用激活函数softmax得出的每个分类的预测概率predict_y;
第4行代码定义损失函数loss,多分类问题使用交叉熵作为损失函数。
交叉熵的函数如下图所示,其中p(x)是实际值,q(x)是预测值。
第5行代码定义优化器optimizer,使用梯度下降优化器; 第6行代码定义训练步骤train,即最小化损失。
6、变量初始化
代码语言:javascript复制init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)
对于神经网络模型,重要是其中的W、b这两个参数。 开始神经网络模型训练之前,这两个变量需要初始化。 第1行代码调用tf.global_variables_initializer实例化tensorflow中的Operation对象。
第2行代码调用tf.Session方法实例化会话对象; 第3行代码调用tf.Session对象的run方法做变量初始化。
7、模型训练
代码语言:javascript复制for i in range(500):
images, labels = mnist.train.next_batch(batch_size)
session.run(train, feed_dict={X_holder:images, y_holder:labels})
if i % 25 == 0:
correct_prediction = tf.equal(tf.argmax(predict_y, 1), tf.argmax(y_holder, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
accuracy_value = session.run(accuracy, feed_dict={X_holder:mnist.test.images, y_holder:mnist.test.labels})
print('step:%d accuracy:%.4f' %(i, accuracy_value))
第1行代码表示模型迭代训练500次; 第2行代码调用mnist.train对象的next_batch方法,选出数量为batch_size的样本; 第3行代码是模型训练,每运行1次此行代码,即模型训练1次; 第4-8行代码是每隔25次训练打印模型准确率。 上面一段代码的运行结果如下:
step:0 accuracy:0.3161 step:25 accuracy:0.8452 step:50 accuracy:0.8668 step:75 accuracy:0.8860 step:100 accuracy:0.8906 step:125 accuracy:0.8948 step:150 accuracy:0.9008 step:175 accuracy:0.9027 step:200 accuracy:0.8956 step:225 accuracy:0.9102 step:250 accuracy:0.9022 step:275 accuracy:0.9097 step:300 accuracy:0.9039 step:325 accuracy:0.9076 step:350 accuracy:0.9137 step:375 accuracy:0.9111 step:400 accuracy:0.9069 step:425 accuracy:0.9097 step:450 accuracy:0.9150 step:475 accuracy:0.9105
9、模型测试
代码语言:javascript复制import math
import matplotlib.pyplot as plt
import numpy as np
def drawDigit2(position, image, title, isTrue):
plt.subplot(*position)
plt.imshow(image.reshape(-1, 28), cmap='gray_r')
plt.axis('off')
if not isTrue:
plt.title(title, color='red')
else:
plt.title(title)
def batchDraw2(batch_size):
images,labels = mnist.test.next_batch(batch_size)
predict_labels = session.run(predict_y, feed_dict={X_holder:images, y_holder:labels})
image_number = images.shape[0]
row_number = math.ceil(image_number ** 0.5)
column_number = row_number
plt.figure(figsize=(row_number 8, column_number 8))
for i in range(row_number):
for j in range(column_number):
index = i * column_number j
if index < image_number:
position = (row_number, column_number, index 1)
image = images[index]
actual = np.argmax(labels[index])
predict = np.argmax(predict_labels[index])
isTrue = actual==predict
title = 'actual:%dnpredict:%d' %(actual,predict)
drawDigit2(position, image, title, isTrue)
batchDraw2(100)
plt.show()
上面一段代码的运行结果如下图所示: