TensorFlow-- Chapter06 MNIST手写数字识别
TensorFlow-- Chapter06 MNIST手写数字识别,tensorboard的使用。 作者:北山啦
文章目录
- TensorFlow-- Chapter06 MNIST手写数字识别
- 理论部分
- MNIST手写数字识别数据集
- 数据集的划分
- 拆分数据
- 工作流程
- 新的工作流程
- 逻辑回归
- Sigmod函数
- 损失函数
- 多元分类
- 实战代码
- TensorBoard可视化
- 利用TensorBoard可视化TensorFlow运行状态
- 产生日志文件
- 启动TensorBoard
- TensorBoard常用API总结
理论部分
MNIST手写数字识别数据集
其中包含了训练集 55000,验证集 5000,测试集 10000
数据集的划分
拆分数据
工作流程
新的工作流程
逻辑回归
Sigmod函数
损失函数
多元分类
softmax思想
实战代码
代码语言:javascript复制import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
代码语言:javascript复制import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets(r".dataMNIST_data", one_hot=True)
代码语言:javascript复制mnist[0]
代码语言:javascript复制<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet at 0x23dfc542a20>
代码语言:javascript复制print(mnist.train.num_examples)
mnist.test.num_examples
代码语言:javascript复制55000
10000
代码语言:javascript复制mnist.train.labels
代码语言:javascript复制array([[0., 0., 0., ..., 1., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.]])
代码语言:javascript复制print(mnist.train.images.shape)
mnist.test.images.shape
代码语言:javascript复制(55000, 784)
(10000, 784)
代码语言:javascript复制mnist.train.images[0]
代码语言:javascript复制array([0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.3803922 , 0.37647063, 0.3019608 ,
0.46274513, 0.2392157 , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.3529412 , 0.5411765 , 0.9215687 ,
0.9215687 , 0.9215687 , 0.9215687 , 0.9215687 , 0.9215687 ,
0.9843138 , 0.9843138 , 0.9725491 , 0.9960785 , 0.9607844 ,
0.9215687 , 0.74509805, 0.08235294, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.54901963,
0.9843138 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.7411765 , 0.09019608, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.8862746 , 0.9960785 , 0.81568635,
0.7803922 , 0.7803922 , 0.7803922 , 0.7803922 , 0.54509807,
0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 , 0.2392157 ,
0.5019608 , 0.8705883 , 0.9960785 , 0.9960785 , 0.7411765 ,
0.08235294, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.14901961, 0.32156864, 0.0509804 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.13333334,
0.8352942 , 0.9960785 , 0.9960785 , 0.45098042, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.32941177, 0.9960785 ,
0.9960785 , 0.9176471 , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.32941177, 0.9960785 , 0.9960785 , 0.9176471 ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.4156863 , 0.6156863 ,
0.9960785 , 0.9960785 , 0.95294124, 0.20000002, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.09803922, 0.45882356, 0.8941177 , 0.8941177 ,
0.8941177 , 0.9921569 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.9960785 , 0.94117653, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.26666668, 0.4666667 , 0.86274517,
0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 ,
0.9960785 , 0.9960785 , 0.9960785 , 0.9960785 , 0.5568628 ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.14509805, 0.73333335,
0.9921569 , 0.9960785 , 0.9960785 , 0.9960785 , 0.8745099 ,
0.8078432 , 0.8078432 , 0.29411766, 0.26666668, 0.8431373 ,
0.9960785 , 0.9960785 , 0.45882356, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.4431373 , 0.8588236 , 0.9960785 , 0.9490197 , 0.89019614,
0.45098042, 0.34901962, 0.12156864, 0. , 0. ,
0. , 0. , 0.7843138 , 0.9960785 , 0.9450981 ,
0.16078432, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.6627451 , 0.9960785 ,
0.6901961 , 0.24313727, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.18823531,
0.9058824 , 0.9960785 , 0.9176471 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.07058824, 0.48627454, 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.32941177, 0.9960785 , 0.9960785 ,
0.6509804 , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.54509807, 0.9960785 , 0.9333334 , 0.22352943, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.8235295 , 0.9803922 , 0.9960785 ,
0.65882355, 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.9490197 , 0.9960785 , 0.93725497, 0.22352943, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.34901962, 0.9843138 , 0.9450981 ,
0.3372549 , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.01960784,
0.8078432 , 0.96470594, 0.6156863 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01568628, 0.45882356, 0.27058825,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. ], dtype=float32)
可视化
代码语言:javascript复制import matplotlib.pyplot as plt
def plot_image(image):
plt.imshow(image.reshape(28, 28), cmap = 'binary')
plt.show()
代码语言:javascript复制plot_image(mnist.train.images[288])
代码语言:javascript复制x = tf.placeholder(tf.float32, [None, 784], name= "X")
y = tf.placeholder(tf.float32, [None, 10], name= "Y")
代码语言:javascript复制H1_NN = 256
W1 = tf.Variable(tf.random_normal([784, H1_NN]))
b1 = tf.Variable(tf.zeros([H1_NN]))
Y1 = tf.nn.relu(tf.matmul(x, W1) b1)
代码语言:javascript复制W2 = tf.Variable(tf.random_normal([H1_NN, 10]))
b2 = tf.Variable(tf.zeros([10]))
forward = tf.matmul(Y1, W2) b2
pred = tf.nn.softmax(forward)
代码语言:javascript复制loss_fuction = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=forward ,labels=y))
代码语言:javascript复制WARNING:tensorflow:From <ipython-input-13-1127016930ab>:1: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.
See @{tf.nn.softmax_cross_entropy_with_logits_v2}.
代码语言:javascript复制train_epochs = 40
batch_size = 50
total_batch = int(mnist.train.num_examples/batch_size)
display_step = 1
learning_rate = 0.01
代码语言:javascript复制optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_fuction)
代码语言:javascript复制correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
代码语言:javascript复制correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
image_shaped_input = tf.reshape(x,[-1,28,28,1])
tf.summary.image('input',image_shaped_input,10)
tf.summary.histogram('forward',forward)
tf.summary.scalar('loss',loss_fuction)
tf.summary.scalar('accuracy',accuracy)
代码语言:javascript复制<tf.Tensor 'accuracy:0' shape=() dtype=string>
代码语言:javascript复制from time import time
startTime = time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
merged_summary_op = tf.summary.merge_all()
writer = tf.summary.FileWriter('log/',sess.graph)
for epoch in range(train_epochs):
for batch in range(total_batch):
xs, ys = mnist.train.next_batch(batch_size)
sess.run(optimizer, feed_dict={x: xs, y: ys})
summary_str = sess.run(merged_summary_op,feed_dict={x:xs,y:ys})
writer.add_summary(summary_str,epoch)
loss,acc = sess.run([loss_fuction,accuracy], feed_dict={x: mnist.validation.images,
y: mnist.validation.labels})
if (epoch 1) % display_step == 0:
print("Train Epoch", 'd' % (epoch 1),
"Loss=", "{:.9f}".format(loss), "Accuracy=", "{:.4f}".format(acc))
duration = time() - startTime
print("Train Finished takes:" "{:.2f}".format(duration))
代码语言:javascript复制Train Epoch 01 Loss= 1.259438753 Accuracy= 0.9368
Train Epoch 02 Loss= 0.717698812 Accuracy= 0.9446
Train Epoch 03 Loss= 0.575311124 Accuracy= 0.9472
Train Epoch 04 Loss= 0.448238075 Accuracy= 0.9552
Train Epoch 05 Loss= 0.413602978 Accuracy= 0.9506
Train Epoch 06 Loss= 0.428873390 Accuracy= 0.9518
Train Epoch 07 Loss= 0.398006409 Accuracy= 0.9592
Train Epoch 08 Loss= 0.290548950 Accuracy= 0.9694
Train Epoch 09 Loss= 0.370046228 Accuracy= 0.9640
Train Epoch 10 Loss= 0.360535949 Accuracy= 0.9634
Train Epoch 11 Loss= 0.458259851 Accuracy= 0.9576
Train Epoch 12 Loss= 0.346073866 Accuracy= 0.9626
Train Epoch 13 Loss= 0.486990929 Accuracy= 0.9626
TensorBoard可视化
代码语言:javascript复制#x = tf.placeholder(tf.float32,[None,784],name="X")
image_shaped_input = tf.reshape(x,[-1,28,28,1])
代码语言:javascript复制tf.summary.image("input",image_shaped_input,10)
代码语言:javascript复制tf.summary.histogram("forward",x)
将loss损失以标量显示
代码语言:javascript复制tf.summary.scalar("loss",loss)
将accruacy标准率以标量显示
代码语言:javascript复制tf.summary.scalar("accuracy",accuracy)
训练模型
代码语言:javascript复制sess = tf.Session()
sess.run(tf.global_variables_initializer())
合并所有的summary
代码语言:javascript复制merged_summary_op = tf.summary.merge_all()
writer = tf.summary.FileWriter("log/",tf.get_default_graph())
TensorBoard
代码语言:javascript复制tf.reset_default_graph()
for epoch in range(train_epochs):
for batch in range(total_batch):
xs,ys = mnist.train.next_batch(batch_size)
sess.run(optimizer,feed_dict = {x:xs,y:ys})
summay_str = sess.run(merged_summary_op,feed_dict={x:xs,y:ys})
writer.add_summary(summary_str,eopch)
loss,acc = sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,
y:mnist.validation.lables})
代码语言:javascript复制
利用TensorBoard可视化TensorFlow运行状态
- TensorBoard是TensorFlow的可视化工具
- 通过Tensor Flow程序运行过程中输出的日志文件可视化TensorFlow程序的运行状态
- TensorBoard和TensorFlow程序跑在不同的进程中
产生日志文件
- tf.reset_default_graph():清除default graph和不断增加的节点
启动TensorBoard
- 在Anaconda Prompt中进入日志存放的目录
- 运行TensorBoard 将日志的地址只想程序日志输出的地址
tensorboard --logdir=D:log
3. 通过给定的网址,进入即可
TensorBoard常用API总结
到这里就结束了,如果对你有帮助,欢迎点赞关注评论,你的点赞对我很重要。作者:北山啦