Recognizing Handwritten Digits with TensorFlow

2020-06-12 09:11:41

Data Preparation

```python
# TensorFlow 1.x only: tensorflow.examples.tutorials was removed in TensorFlow 2.x
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
```
```
WARNING:tensorflow:From <ipython-input-1-6bfbaa60ed82>:3: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
```
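With `one_hot=True`, each label is stored as a 10-element vector holding a 1 at the index of the digit and 0 elsewhere. The plain-Python sketch below shows that encoding and how the digit is recovered by taking the index of the largest entry; `to_one_hot` and `from_one_hot` are hypothetical helpers for illustration, not part of the TensorFlow API.

```python
def to_one_hot(label, num_classes=10):
    """Hypothetical helper: 1.0 at position `label`, 0.0 everywhere else."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def from_one_hot(vec):
    """Hypothetical helper: index of the largest entry (what np.argmax does)."""
    return max(range(len(vec)), key=lambda i: vec[i])


encoded = to_one_hot(7)
print(encoded)                # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
print(from_one_hot(encoded))  # 7
```

This is the same idea the later code relies on when it calls `np.argmax(labels[i])` to turn a one-hot label back into a digit.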

Defining Shared Helper Functions

Define the weight function:

```python
def weight(shape):
    # Weights start as small random values from a truncated normal distribution
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='W')
```
Define the bias function:

```python
def bias(shape):
    # Biases start at a small positive constant (0.1)
    return tf.Variable(tf.constant(0.1, shape=shape), name='b')
```
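`tf.truncated_normal` samples from a normal distribution but drops and re-picks any value more than two standard deviations from the mean, so no initial weight is extreme. The sketch below imitates that with rejection sampling in plain Python; `truncated_normal_sample`, `weight_values` and `bias_values` are hypothetical helpers, and the constant 0.1 simply mirrors `bias()` above (a slightly positive bias is often suggested so that ReLU units do not start out inactive).

```python
import random


def truncated_normal_sample(mean=0.0, stddev=0.1, rng=random):
    """Hypothetical helper: draw from N(mean, stddev^2), redrawing values beyond 2 stddev."""
    while True:
        value = rng.gauss(mean, stddev)
        if abs(value - mean) <= 2 * stddev:
            return value


def weight_values(n, stddev=0.1, seed=0):
    """Hypothetical helper: n initial weights, similar in spirit to weight([n])."""
    rng = random.Random(seed)
    return [truncated_normal_sample(0.0, stddev, rng) for _ in range(n)]


def bias_values(n, value=0.1):
    """Hypothetical helper: n biases set to a small positive constant, like bias([n])."""
    return [value] * n


w = weight_values(1000)
print(max(abs(v) for v in w) <= 0.2)  # True: nothing lies beyond 2 * stddev
print(bias_values(3))                 # [0.1, 0.1, 0.1]
```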
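Define the conv2d function:

```python
def conv2d(x, W):
    return tf.nn.conv2d(x, W,
                        strides=[1, 1, 1, 1],  # the filter moves one step at a time, left to right and top to bottom
                        padding='SAME')        # zero-pad so the output keeps the input's height and width
```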
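The four entries of `strides` refer to the batch, height, width and channel dimensions, so `[1, 1, 1, 1]` moves the filter one pixel at a time. With `padding='SAME'`, the input is zero-padded so a stride-1 convolution returns a map of the same height and width. The plain-Python sketch below does this for one channel and one filter; like `tf.nn.conv2d`, it computes a cross-correlation (the filter is not flipped). `conv2d_same` is a hypothetical helper and assumes an odd, square filter such as 5x5.

```python
def conv2d_same(image, kernel):
    """Hypothetical helper: single-channel 2-D cross-correlation, stride 1, zero 'SAME' padding.

    Assumes an odd-sized square kernel, so the padding is (k - 1) // 2 on every side.
    """
    h, w = len(image), len(image[0])
    k = len(kernel)
    pad = (k - 1) // 2
    out = [[0.0] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            total = 0.0
            for i in range(k):
                for j in range(k):
                    rr, cc = r + i - pad, c + j - pad
                    # Positions outside the image count as zeros (the padding).
                    if 0 <= rr < h and 0 <= cc < w:
                        total += image[rr][cc] * kernel[i][j]
            out[r][c] = total
    return out


image = [[1, 2, 3, 0],
         [4, 5, 6, 0],
         [7, 8, 9, 0],
         [0, 0, 0, 0]]
kernel = [[0, 0, 0],
          [0, 1, 0],
          [0, 0, 0]]  # identity filter: copies the input
result = conv2d_same(image, kernel)
print(len(result), len(result[0]))  # 4 4: same size as the input
```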
Define the pooling function:

```python
def max_pool_2x2(x):
    return tf.nn.max_pool(x,
                          ksize=[1, 2, 2, 1],    # sampling window size: height=2, width=2
                          strides=[1, 2, 2, 1],  # stride: two steps left to right and top to bottom
                          padding='SAME')
```

Building the Model

Input Layer

Parameters of the x_image reshape

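  • First dimension, -1: the number of images fed in through the placeholder is not fixed, so -1 tells TensorFlow to infer it
  • Second and third dimensions, 28 and 28: each input digit image is 28x28 pixels
  • Fourth dimension, 1: the images are single-channel (grayscale); a color image would use 3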
```python
with tf.name_scope('Input_Layer'):  # name of this block in the computation graph
    x = tf.placeholder("float", shape=[None, 784], name="x")
    x_image = tf.reshape(x, [-1, 28, 28, 1])
```
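`tf.reshape` does not move any data: it reinterprets each 784-value row in row-major order, so flat index `i` becomes pixel `(i // 28, i % 28)`. The plain-Python sketch below performs the same reinterpretation for one image; `reshape_28x28` is a hypothetical helper, and the trailing channel dimension of size 1 is left out.

```python
def reshape_28x28(flat):
    """Hypothetical helper: turn a flat list of 784 pixels into 28 rows of 28 (row-major)."""
    if len(flat) != 28 * 28:
        raise ValueError("expected 784 values, got %d" % len(flat))
    return [flat[row * 28:(row + 1) * 28] for row in range(28)]


flat = list(range(784))  # stand-in pixel values: each value equals its flat index
grid = reshape_28x28(flat)
print(len(grid), len(grid[0]))  # 28 28
print(grid[1][0])               # 28: the 29th value starts the second row
```

With `-1` in the first position, a batch of 100 flat images (100 x 784 values) becomes a 100 x 28 x 28 x 1 tensor.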
Convolutional Layer 1

Explanation of the W1 shape

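  • First and second dimensions, 5 and 5: the filter size is 5x5
  • Third dimension, 1: the number of input channels (1 for grayscale, 3 for color)
  • Fourth dimension, 16: the number of filters, so the layer produces 16 feature maps (images)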
```python
with tf.name_scope('C1_Conv'):
    W1 = weight([5, 5, 1, 16])
    b1 = bias([16])
    Conv1 = conv2d(x_image, W1) + b1
    C1_Conv = tf.nn.relu(Conv1)
```
```
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
```

Why use a pooling layer

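  • It reduces the number of data points that later layers must process
  • It makes the network less sensitive to small shifts in where a feature appears in the image
  • It lowers the amount of computation, and the number of parameters in the fully connected layer that follows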
```python
with tf.name_scope('C1_Pool'):
    C1_Pool = max_pool_2x2(C1_Conv)
```
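`max_pool_2x2` keeps the largest value in each non-overlapping 2x2 window, halving the height and width. The plain-Python sketch below (a hypothetical helper named `max_pool_2x2_py` so it does not shadow the TensorFlow helper, assuming even input sizes so 'SAME' padding adds nothing) illustrates two of the benefits listed above: fewer data points, and the same pooled output when a feature moves by one pixel within its window.

```python
def max_pool_2x2_py(feature_map):
    """Hypothetical helper: 2x2 max pooling, stride 2, single channel, even side lengths."""
    h, w = len(feature_map), len(feature_map[0])
    return [[max(feature_map[r][c], feature_map[r][c + 1],
                 feature_map[r + 1][c], feature_map[r + 1][c + 1])
             for c in range(0, w, 2)]
            for r in range(0, h, 2)]


a = [[9, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]]
b = [[0, 0, 0, 0],
     [0, 9, 0, 0],  # same feature, shifted one pixel down and to the right
     [0, 0, 0, 0],
     [0, 0, 0, 0]]
print(max_pool_2x2_py(a))                        # [[9, 0], [0, 0]]
print(max_pool_2x2_py(a) == max_pool_2x2_py(b))  # True

# For C1: 28 x 28 x 16 values before pooling, 14 x 14 x 16 after.
print(28 * 28 * 16, 14 * 14 * 16)  # 12544 3136
```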
Convolutional Layer 2
```python
with tf.name_scope('C2_Conv'):
    W2 = weight([5, 5, 16, 36])  # turn the 16 input feature maps into 36
    b2 = bias([36])
    Conv2 = conv2d(C1_Pool, W2) + b2
    C2_Conv = tf.nn.relu(Conv2)
```
```python
with tf.name_scope('C2_Pool'):
    C2_Pool = max_pool_2x2(C2_Conv)
```
Fully Connected Layer

Explanation of the D_Flat reshape parameters

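  • C2_Pool: the tensor to reshape
  • First element of the shape list, -1: the number of input samples is not fixed
  • Second element, 1764: C2_Pool holds 36 feature maps of 7x7, and 36 x 7 x 7 = 1764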
```python
with tf.name_scope('D_Flat'):
    D_Flat = tf.reshape(C2_Pool, [-1, 1764])
```
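The 7x7 size comes from tracing the spatial dimensions through the network. With 'SAME' padding, a layer with stride s produces ceil(n / s) rows and columns, so the stride-1 convolutions keep the size and each stride-2 pooling halves it. A short sketch of that bookkeeping (`same_out` is a hypothetical helper):

```python
import math


def same_out(n, stride):
    """Hypothetical helper: output length along one axis with TensorFlow 'SAME' padding."""
    return math.ceil(n / stride)


size = 28
size = same_out(size, 1)  # C1_Conv: 28
size = same_out(size, 2)  # C1_Pool: 14
size = same_out(size, 1)  # C2_Conv: 14
size = same_out(size, 2)  # C2_Pool: 7
channels = 36             # number of filters in C2_Conv
flat_len = size * size * channels
print(size, flat_len)     # 7 1764
```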
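```python
with tf.name_scope('D_Hidden_Layer'):
    W3 = weight([1764, 128])  # 128 neurons in the hidden layer
    b3 = bias([128])
    D_Hidden = tf.nn.relu(tf.matmul(D_Flat, W3) + b3)
    # keep_prob is the fraction of neurons to keep. As a constant it also applies
    # during evaluation and prediction; feeding it through a placeholder (1.0 at
    # test time) is a common alternative.
    D_Hidden_Dropout = tf.nn.dropout(D_Hidden, keep_prob=0.8)
```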
```
WARNING:tensorflow:From <ipython-input-12-b635345e166c>:7: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
```
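`tf.nn.dropout` zeroes each element with probability `1 - keep_prob` and scales the surviving elements by `1 / keep_prob`, which keeps the expected value of each activation unchanged; the warning above refers to the newer `rate = 1 - keep_prob` argument. A plain-Python sketch of that behaviour, where `dropout_sketch` and its `seed` parameter are hypothetical:

```python
import random


def dropout_sketch(values, keep_prob=0.8, seed=None):
    """Hypothetical helper: keep each value with probability keep_prob and scale it by 1/keep_prob."""
    rng = random.Random(seed)
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]


activations = [1.0] * 10000
dropped = dropout_sketch(activations, keep_prob=0.8, seed=42)
kept = sum(1 for v in dropped if v != 0.0)
print(kept / len(dropped))          # close to 0.8
print(sum(dropped) / len(dropped))  # close to 1.0: the expected value is preserved
rate = 1 - 0.8                      # the equivalent value for the newer `rate` argument
```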
Output Layer

```python
with tf.name_scope('Output_Layer'):
    W4 = weight([128, 10])
    b4 = bias([10])
    y_predict = tf.nn.softmax(tf.matmul(D_Hidden_Dropout, W4) + b4)
```
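Counting trainable parameters layer by layer (weights plus biases) shows where the model's capacity sits: most of it is in the first fully connected layer. The arithmetic as a short sketch, with the shapes copied from the layer definitions above and `count_params` a hypothetical helper:

```python
# (layer name, weight shape, bias length) for each layer defined above
layers = [
    ("C1_Conv", (5, 5, 1, 16), 16),
    ("C2_Conv", (5, 5, 16, 36), 36),
    ("D_Hidden_Layer", (1764, 128), 128),
    ("Output_Layer", (128, 10), 10),
]


def count_params(weight_shape, bias_len):
    """Hypothetical helper: product of the weight shape plus the number of biases."""
    n = 1
    for dim in weight_shape:
        n *= dim
    return n + bias_len


counts = {name: count_params(shape, b) for name, shape, b in layers}
total = sum(counts.values())
print(counts)  # {'C1_Conv': 416, 'C2_Conv': 14436, 'D_Hidden_Layer': 225920, 'Output_Layer': 1290}
print(total)   # 242062
```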

Setting Up the Training Optimizer

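```python
with tf.name_scope("optimizer"):
    y_label = tf.placeholder("float", shape=[None, 10], name="y_label")
    # Caution: softmax_cross_entropy_with_logits_v2 expects unnormalized logits and
    # applies softmax itself, but y_predict is already a softmax output, so softmax
    # is applied twice. Passing the pre-softmax tensor as logits is the usual choice.
    loss_function = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict,
                                                   labels=y_label))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss_function)
```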
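Because `y_predict` has already passed through `tf.nn.softmax`, and `softmax_cross_entropy_with_logits_v2` applies softmax again, the loss cannot reach 0 even for a perfect prediction. The plain-Python sketch below computes the smallest value possible in this setup, when the model outputs exactly 1.0 for the correct class and 0.0 elsewhere. It comes out at log(9 + e) - 1, about 1.4612, which may explain why the validation loss in the training log below levels off around 1.48 while accuracy keeps improving. `softmax` and `cross_entropy` here are hypothetical helpers.

```python
import math


def softmax(z):
    """Hypothetical helper: numerically stable softmax of a list of numbers."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def cross_entropy(probs, one_hot):
    """Hypothetical helper: -sum(label * log(prob)) for a single example."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs) if t > 0)


label = [0.0] * 10
label[3] = 1.0
perfect = label[:]  # a softmax output that is exactly right

# What the graph above computes: softmax applied a second time to probabilities.
loss_double = cross_entropy(softmax(perfect), label)
# Cross-entropy on the probabilities themselves, for comparison (equals 0 here).
loss_single = cross_entropy(perfect, label)
print(round(loss_double, 4))  # 1.4612
```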

Defining Model Evaluation

```python
with tf.name_scope("evaluate_model"):
    correct_prediction = tf.equal(tf.argmax(y_predict, 1),
                                  tf.argmax(y_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
```
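The evaluation graph compares the index of the largest predicted probability with the index of the 1 in the one-hot label, then averages the matches. The same computation in plain Python on a made-up three-class batch; `argmax_py` and `accuracy_py` are hypothetical names chosen so they do not shadow the TensorFlow `accuracy` tensor.

```python
def argmax_py(values):
    """Hypothetical helper: index of the largest value (first one on ties), like tf.argmax on one row."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy_py(predictions, labels):
    """Hypothetical helper: fraction of rows whose argmax matches the one-hot label's argmax."""
    correct = [argmax_py(p) == argmax_py(t) for p, t in zip(predictions, labels)]
    return sum(correct) / len(correct)


# Made-up batch with 3 classes: the first two rows are right, the third is wrong.
preds = [[0.1, 0.8, 0.1],
         [0.7, 0.2, 0.1],
         [0.3, 0.3, 0.4]]
labels = [[0, 1, 0],
          [1, 0, 0],
          [0, 1, 0]]
print(accuracy_py(preds, labels))  # 0.6666666666666666
```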

Training the Model

```python
from time import time

trainEpochs = 10
batchSize = 100
totalBatchs = int(mnist.train.num_examples / batchSize)
epoch_list = []; accuracy_list = []; loss_list = []

startTime = time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
```
```python
for epoch in range(trainEpochs):
    for i in range(totalBatchs):
        batch_x, batch_y = mnist.train.next_batch(batchSize)
        sess.run(optimizer, feed_dict={x: batch_x, y_label: batch_y})

    # After each epoch, measure loss and accuracy on the validation set
    loss, acc = sess.run([loss_function, accuracy],
                         feed_dict={x: mnist.validation.images,
                                    y_label: mnist.validation.labels})
    epoch_list.append(epoch)
    loss_list.append(loss)
    accuracy_list.append(acc)
    print("Train Epoch:", '%02d' % (epoch + 1),
          "Loss=", "{:.9f}".format(loss), " Accuracy=", acc)

duration = time() - startTime
print("Train Finished takes:", duration)
```
```
Train Epoch: 01 Loss= 1.656932473  Accuracy= 0.827
Train Epoch: 02 Loss= 1.613922596  Accuracy= 0.8558
Train Epoch: 03 Loss= 1.598174453  Accuracy= 0.8692
Train Epoch: 04 Loss= 1.510785699  Accuracy= 0.9574
Train Epoch: 05 Loss= 1.500687838  Accuracy= 0.9658
Train Epoch: 06 Loss= 1.495839953  Accuracy= 0.9684
Train Epoch: 07 Loss= 1.491830468  Accuracy= 0.9726
Train Epoch: 08 Loss= 1.489337087  Accuracy= 0.9742
Train Epoch: 09 Loss= 1.486868739  Accuracy= 0.9774
Train Epoch: 10 Loss= 1.484916449  Accuracy= 0.9792
Train Finished takes: 720.7906568050385
```
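Each epoch walks through the training set in consecutive mini-batches of `batchSize` images and then evaluates on the validation set. Assuming `read_data_sets` kept its default 5,000-image validation split, `mnist.train.num_examples` is 55,000, so each epoch runs 550 optimizer steps and the 10 epochs run 5,500 in total. A sketch of that bookkeeping, with `batch_ranges` a hypothetical helper:

```python
def batch_ranges(num_examples, batch_size):
    """Hypothetical helper: yield (start, end) pairs for consecutive mini-batches.

    Like int(num_examples / batchSize) in the training loop, a final partial batch is dropped.
    """
    for b in range(num_examples // batch_size):
        start = b * batch_size
        yield start, start + batch_size


num_train = 55000  # assumed: 60,000 MNIST training images minus 5,000 for validation
batch_size = 100
train_epochs = 10

steps_per_epoch = len(list(batch_ranges(num_train, batch_size)))
print(steps_per_epoch)                 # 550
print(steps_per_epoch * train_epochs)  # 5500
```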
```python
%matplotlib inline
import matplotlib.pyplot as plt

fig = plt.gcf()
fig.set_size_inches(4, 2)
plt.plot(epoch_list, loss_list, label='loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')
```
```python
plt.plot(epoch_list, accuracy_list, label="accuracy")
fig = plt.gcf()
fig.set_size_inches(4, 2)
plt.ylim(0.8, 1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()
```

Evaluating Model Accuracy

```python
len(mnist.test.images)
```

```
10000
```
```python
print("Accuracy:",
      sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_label: mnist.test.labels}))
```

```
Accuracy: 0.9792
```
```python
print("Accuracy:",
      sess.run(accuracy, feed_dict={x: mnist.test.images[:5000],
                                    y_label: mnist.test.labels[:5000]}))
```

```
Accuracy: 0.968
```
```python
print("Accuracy:",
      sess.run(accuracy, feed_dict={x: mnist.test.images[5000:],
                                    y_label: mnist.test.labels[5000:]}))
```

```
Accuracy: 0.9886
```
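Since both halves contain 5,000 images, the accuracy on the full test set should equal the plain average of the two half accuracies. Checking that arithmetic with the figures printed above:

```python
first_half = 0.968    # accuracy on mnist.test.images[:5000], from the output above
second_half = 0.9886  # accuracy on mnist.test.images[5000:], from the output above

combined = (first_half * 5000 + second_half * 5000) / 10000
print(round(combined, 4))  # 0.9783
```

This is slightly lower than the 0.9792 measured on all 10,000 images in a single run. One plausible explanation is that `tf.nn.dropout` with a fixed `keep_prob=0.8` is still active during evaluation, so every `sess.run` call samples a different dropout mask and gives slightly different results.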

Predicted Probabilities

```python
# Store the NumPy result under a new name so the y_predict tensor stays usable below
y_predict_prob = sess.run(y_predict,
                          feed_dict={x: mnist.test.images[:5000]})
```

```python
y_predict_prob[:5]
```

```
array([[4.05578522e-12, 6.15486123e-14, 5.71559293e-12, 1.74847949e-11,
        2.71332728e-17, 8.90746643e-11, 5.53451119e-21, 1.00000000e+00,
        1.10875556e-13, 8.30471913e-10],
       [9.93732328e-07, 4.50552989e-06, 9.99993682e-01, 8.04418278e-07,
        2.64564185e-14, 1.46194583e-14, 1.56614929e-10, 6.01911912e-14,
        3.10939221e-08, 2.34203085e-15],
       [1.31195605e-08, 9.99897718e-01, 4.33765905e-07, 2.02467453e-11,
        9.89620676e-05, 6.53400056e-10, 6.65772149e-08, 2.67320161e-06,
        5.65030227e-08, 4.94121100e-09],
       [9.99993682e-01, 6.49839280e-11, 3.86714616e-09, 2.97008674e-13,
        8.59991689e-10, 1.07891083e-11, 6.35852575e-06, 1.65313943e-10,
        2.73128520e-10, 5.31917976e-08],
       [1.19434844e-06, 9.74953984e-09, 2.05678519e-09, 2.03244167e-14,
        9.99985814e-01, 1.02013356e-10, 7.86321621e-08, 3.65643515e-08,
        6.86227242e-10, 1.28641732e-05]], dtype=float32)
```

Prediction Results

```python
prediction_result = sess.run(tf.argmax(y_predict, 1),
                             feed_dict={x: mnist.test.images,
                                        y_label: mnist.test.labels})
```

```python
prediction_result[:10]
```

```
array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
```
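`tf.argmax(y_predict, 1)` picks, for each image, the digit with the highest predicted probability. Applying the same rule in plain Python to the five probability rows printed earlier reproduces the first five predictions, 7, 2, 1, 0, 4. The rows below are copied from that output and rounded to three significant digits, which does not change which entry is largest.

```python
# The five rows of y_predict[:5] from the output above, rounded to 3 significant digits
probs = [
    [4.06e-12, 6.15e-14, 5.72e-12, 1.75e-11, 2.71e-17, 8.91e-11, 5.53e-21, 1.00e+00, 1.11e-13, 8.30e-10],
    [9.94e-07, 4.51e-06, 1.00e+00, 8.04e-07, 2.65e-14, 1.46e-14, 1.57e-10, 6.02e-14, 3.11e-08, 2.34e-15],
    [1.31e-08, 1.00e+00, 4.34e-07, 2.02e-11, 9.90e-05, 6.53e-10, 6.66e-08, 2.67e-06, 5.65e-08, 4.94e-09],
    [1.00e+00, 6.50e-11, 3.87e-09, 2.97e-13, 8.60e-10, 1.08e-11, 6.36e-06, 1.65e-10, 2.73e-10, 5.32e-08],
    [1.19e-06, 9.75e-09, 2.06e-09, 2.03e-14, 1.00e+00, 1.02e-10, 7.86e-08, 3.66e-08, 6.86e-10, 1.29e-05],
]

# For each row, the predicted digit is the index of the largest probability.
predicted = [max(range(10), key=lambda d: row[d]) for row in probs]
print(predicted)  # [7, 2, 1, 0, 4]
```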
```python
import numpy as np

def show_images_labels_predict(images, labels, prediction_result):
    fig = plt.gcf()
    fig.set_size_inches(8, 10)
    for i in range(0, 10):
        ax = plt.subplot(5, 5, 1 + i)
        ax.imshow(np.reshape(images[i], (28, 28)), cmap='binary')
        # Title shows the true label and the model's prediction
        ax.set_title("label=" + str(np.argmax(labels[i]))
                     + ",predict=" + str(prediction_result[i]),
                     fontsize=9)
    plt.show()
```
```python
show_images_labels_predict(mnist.test.images, mnist.test.labels, prediction_result)
```

Finding Misclassified Images

```python
# List the misclassified images among the first 500 test images
for i in range(500):
    if prediction_result[i] != np.argmax(mnist.test.labels[i]):
        print("i=" + str(i) + "   label=", np.argmax(mnist.test.labels[i]),
              "predict=", prediction_result[i])
```

```
i=247   label= 4 predict= 2
i=259   label= 6 predict= 0
i=290   label= 8 predict= 4
i=320   label= 9 predict= 1
i=321   label= 2 predict= 7
i=340   label= 5 predict= 3
i=445   label= 6 predict= 0
i=495   label= 8 predict= 0
```
```python
def show_images_labels_predict_error(images, labels, prediction_result):
    fig = plt.gcf()
    fig.set_size_inches(8, 10)
    i = 0; j = 0
    # Show up to 10 misclassified images; stop early if the predictions run out
    while i < 10 and j < len(prediction_result):
        if prediction_result[j] != np.argmax(labels[j]):
            ax = plt.subplot(5, 5, 1 + i)
            ax.imshow(np.reshape(images[j], (28, 28)), cmap='binary')
            # j = index in the test set, l = true label, p = prediction
            ax.set_title("j=" + str(j)
                         + ",l=" + str(np.argmax(labels[j]))
                         + ",p=" + str(prediction_result[j]),
                         fontsize=9)
            i = i + 1
        j = j + 1
    plt.show()
```
```python
show_images_labels_predict_error(mnist.test.images, mnist.test.labels, prediction_result)
```

Saving the Model

```python
saver = tf.train.Saver()
save_path = saver.save(sess, "saveModel/CNN_model1")
print("Model saved in file: %s" % save_path)
```

```
Model saved in file: saveModel/CNN_model1
```

Launching TensorBoard

  • Start TensorBoard from a terminal: tensorboard --logdir=c:\pyhonwork\log\CNN
  • Open http://localhost:6006/ in a browser
```python
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter('log/CNN', sess.graph)
train_writer.close()  # flush the graph to the event file on disk
```

```python
sess.close()
```
