Data Preparation

    import tensorflow as tf
    import tensorflow.examples.tutorials.mnist.input_data as input_data

    # Download (if needed) and load MNIST with one-hot encoded labels
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    WARNING:tensorflow:From <ipython-input-1-6bfbaa60ed82>:3: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
    WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please write your own downloading logic.
    WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use urllib or similar directly.
    Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
    WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.data to implement this functionality.
    Extracting MNIST_data/train-images-idx3-ubyte.gz
    Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
    WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.data to implement this functionality.
    Extracting MNIST_data/train-labels-idx1-ubyte.gz
    WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.one_hot on tensors.
    Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
    Extracting MNIST_data/t10k-images-idx3-ubyte.gz
    Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
    Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
    WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
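
Building Shared Functions
Define the weight function

    def weight(shape):
        # Initialize weights from a truncated normal distribution (stddev 0.1)
        return tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='W')

Define the bias function

    def bias(shape):
        # Initialize every bias to the constant 0.1
        return tf.Variable(tf.constant(0.1, shape=shape), name='b')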
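
To make the weight initialization more concrete, here is a minimal NumPy sketch of what a truncated normal draw does: samples come from a normal distribution, and any value more than two standard deviations from the mean is drawn again. The helper name `truncated_normal` and the fixed seed are assumptions for illustration; this is not TensorFlow's implementation.

```python
import numpy as np

def truncated_normal(shape, stddev=0.1, seed=0):
    """Draw N(0, stddev^2) samples, redrawing any value beyond 2 * stddev."""
    rng = np.random.default_rng(seed)
    values = rng.normal(0.0, stddev, size=shape)
    out_of_range = np.abs(values) > 2 * stddev
    # Keep redrawing only the rejected entries until all are within range
    while out_of_range.any():
        values[out_of_range] = rng.normal(0.0, stddev, size=int(out_of_range.sum()))
        out_of_range = np.abs(values) > 2 * stddev
    return values

w = truncated_normal((5, 5, 1, 16))
print(w.shape, float(np.abs(w).max()))
```

With stddev=0.1, every initial weight ends up in the range [-0.2, 0.2], which avoids a few unusually large starting values.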
Define the conv2d function

    def conv2d(x, W):
        return tf.nn.conv2d(x, W,
                            strides=[1, 1, 1, 1],  # the filter moves one step at a time, left to right and top to bottom
                            padding='SAME')

Define the pooling function

    def max_pool_2x2(x):
        return tf.nn.max_pool(x,
                              ksize=[1, 2, 2, 1],    # sampling window size: height=2, width=2
                              strides=[1, 2, 2, 1],  # stride: two steps left to right and top to bottom
                              padding='SAME')
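
Building the Model
Input Layer
Parameters of x_image
- The first dimension is -1: the number of examples fed in through the placeholder is not fixed, so it is set to -1
- The second and third dimensions are 28, 28: each input digit image is 28*28 pixels
- The fourth dimension is 1 because the images are single-channel; for color images it would be 3

    with tf.name_scope('Input_Layer'):  # name of the input in the computation graph
        x = tf.placeholder("float", shape=[None, 784], name="x")
        x_image = tf.reshape(x, [-1, 28, 28, 1])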
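
The role of -1 in the reshape can be checked with plain NumPy, which follows the same convention of inferring one dimension from the total number of elements. The dummy batch of three images below is an assumption for illustration.

```python
import numpy as np

# Three flattened 28x28 images (dummy data standing in for a batch)
batch = np.zeros((3, 784), dtype=np.float32)

# -1 lets the batch dimension be inferred: 3 * 784 / (28 * 28 * 1) = 3
images = batch.reshape(-1, 28, 28, 1)
print(images.shape)
```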
Convolutional Layer 1
Parameters of W1
- The first and second dimensions are both 5: the filter size is 5*5
- The third dimension is 1: 1 for single-channel images, 3 for color images
- The fourth dimension is 16: the layer produces 16 output images (feature maps)

    with tf.name_scope('C1_Conv'):
        W1 = weight([5, 5, 1, 16])
        b1 = bias([16])
        Conv1 = conv2d(x_image, W1) + b1
        C1_Conv = tf.nn.relu(Conv1)

    WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
    Instructions for updating:
    Colocations handled automatically by placer.
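
Benefits of a pooling layer
- It reduces the number of data points that have to be processed
- It makes the result less sensitive to differences in image position
- It lowers the number of parameters and the amount of computation

    with tf.name_scope('C1_Pool'):
        C1_Pool = max_pool_2x2(C1_Conv)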
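
As an illustration of how 2x2 max pooling with stride 2 cuts the number of data points to a quarter, here is a small NumPy sketch on a single 2-D array. The helper `max_pool_2x2_np` is hypothetical and only handles even height and width; it is not the TensorFlow op.

```python
import numpy as np

def max_pool_2x2_np(image):
    """2x2 max pooling with stride 2 on a 2-D array with even height and width."""
    h, w = image.shape
    # Split the image into non-overlapping 2x2 blocks, then take each block's maximum
    blocks = image.reshape(h // 2, 2, w // 2, 2)
    return blocks.max(axis=(1, 3))

image = np.array([[1, 2, 5, 6],
                  [3, 4, 7, 8],
                  [9, 8, 1, 0],
                  [7, 6, 3, 2]])
pooled = max_pool_2x2_np(image)
print(pooled)
```

The 4x4 input becomes a 2x2 output, and shifting a bright pixel by one position inside a block leaves the pooled value unchanged.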
Convolutional Layer 2

    with tf.name_scope('C2_Conv'):
        W2 = weight([5, 5, 16, 36])  # turn the 16 images from the previous layer into 36
        b2 = bias([36])
        Conv2 = conv2d(C1_Pool, W2) + b2
        C2_Conv = tf.nn.relu(Conv2)

    with tf.name_scope('C2_Pool'):
        C2_Pool = max_pool_2x2(C2_Conv)
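
Fully Connected Layer
Parameters of D_Flat
- C2_Pool: the tensor to be reshaped
- First dimension of the list, -1: the number of training examples passed in is not fixed
- Second dimension of the list, 1764: the incoming tensor holds 36 images of 7*7 each (36 * 7 * 7 = 1764)

    with tf.name_scope('D_Flat'):
        D_Flat = tf.reshape(C2_Pool, [-1, 1764])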
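
The value 1764 can be derived from the layer settings. With padding='SAME', the output length along each spatial dimension is the input length divided by the stride, rounded up. The helper name below is an assumption; the sizes and strides come from the layers defined above.

```python
import math

def same_padding_output(size, stride):
    """Output length along one dimension for padding='SAME'."""
    return math.ceil(size / stride)

size = 28
size = same_padding_output(size, 1)  # C1_Conv keeps 28
size = same_padding_output(size, 2)  # C1_Pool halves it to 14
size = same_padding_output(size, 1)  # C2_Conv keeps 14
size = same_padding_output(size, 2)  # C2_Pool halves it to 7

flat_size = size * size * 36         # 36 feature maps of 7x7
print(size, flat_size)
```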

    with tf.name_scope('D_Hidden_Layer'):
        W3 = weight([1764, 128])  # the hidden layer has 128 neurons
        b3 = bias([128])
        D_Hidden = tf.nn.relu(tf.matmul(D_Flat, W3) + b3)
        D_Hidden_Dropout = tf.nn.dropout(D_Hidden, keep_prob=0.8)  # fraction of neurons to keep

    WARNING:tensorflow:From <ipython-input-12-b635345e166c>:7: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
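
A rough NumPy sketch of dropout with keep_prob=0.8 (equivalently rate = 1 - 0.8 = 0.2): each activation is kept with probability 0.8 and scaled by 1 / 0.8 so that the expected value stays the same, otherwise it is set to zero. The function `dropout_np` and the seed are assumptions for illustration only.

```python
import numpy as np

def dropout_np(x, keep_prob, seed=0):
    """Keep each element with probability keep_prob and rescale the kept ones."""
    rng = np.random.default_rng(seed)
    mask = rng.random(x.shape) < keep_prob
    return np.where(mask, x / keep_prob, 0.0)

activations = np.ones(1000)
dropped = dropout_np(activations, keep_prob=0.8)
print(np.unique(dropped))   # only 0.0 and the rescaled value appear
print(dropped.mean())       # close to the original mean of 1.0
```

Because keep_prob is hard-coded in the graph above rather than fed through a placeholder, the same random dropping presumably also happens when the model is evaluated.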

Output Layer

    with tf.name_scope('Output_Layer'):
        W4 = weight([128, 10])
        b4 = bias([10])
        y_predict = tf.nn.softmax(tf.matmul(D_Hidden_Dropout, W4) + b4)
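
With all four weighted layers defined, the number of trainable parameters follows directly from the weight and bias shapes. The dictionary below simply restates those shapes; the tally is plain arithmetic.

```python
from math import prod

# (weight shape, bias size) for each layer, copied from the model definition
layers = {
    "C1_Conv": ([5, 5, 1, 16], 16),
    "C2_Conv": ([5, 5, 16, 36], 36),
    "D_Hidden_Layer": ([1764, 128], 128),
    "Output_Layer": ([128, 10], 10),
}

counts = {name: prod(shape) + n_bias for name, (shape, n_bias) in layers.items()}
total = sum(counts.values())

for name, n in counts.items():
    print(name, n)
print("total", total)
```

Most of the parameters sit in the fully connected hidden layer, which is one reason the pooling layers before it matter for keeping the model small.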

Setting Up the Training Optimization Step

    with tf.name_scope("optimizer"):
        y_label = tf.placeholder("float", shape=[None, 10], name="y_label")
        # Note: y_predict is already a softmax output, and this op applies
        # softmax to its `logits` argument again before computing the loss.
        loss_function = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict,
                                                       labels=y_label))
        optimizer = tf.train.AdamOptimizer(learning_rate=0.0001) \
                            .minimize(loss_function)
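
Because the loss op applies softmax internally, feeding it probabilities instead of raw scores puts a floor under the loss. The NumPy sketch below, with hypothetical helper names, computes that floor for 10 classes: even a perfect probability vector yields a loss of about 1.461, which may explain why the losses reported in the training log level off just above that value.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def cross_entropy_with_logits(logits, one_hot_label):
    """Softmax followed by cross-entropy, for a single example."""
    return float(-np.sum(one_hot_label * np.log(softmax(logits))))

label = np.eye(10)[7]

# Case 1: a perfect probability vector passed where logits are expected
perfect_probs = np.eye(10)[7]
floor = cross_entropy_with_logits(perfect_probs, label)

# Case 2: large raw scores (pre-softmax) for the correct class
confident_logits = 20.0 * np.eye(10)[7]
near_zero = cross_entropy_with_logits(confident_logits, label)

print(floor, near_zero)
```

Passing the pre-softmax `tf.matmul(...) + b4` value as `logits` is the usual way to avoid this; the code above is left as the author wrote it so that it matches the reported results.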

Setting Up Model Evaluation

    with tf.name_scope("evaluate_model"):
        # A prediction is correct when the most probable class matches the label
        correct_prediction = tf.equal(tf.argmax(y_predict, 1),
                                      tf.argmax(y_label, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
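
The same accuracy computation can be traced by hand in NumPy. The four probability rows and one-hot labels below are made-up values for illustration.

```python
import numpy as np

# Hypothetical predicted probabilities for 4 examples and 3 classes
probs = np.array([[0.1, 0.7, 0.2],
                  [0.8, 0.1, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.2, 0.5, 0.3]])
# Hypothetical one-hot labels
labels = np.array([[0, 1, 0],
                   [1, 0, 0],
                   [0, 1, 0],
                   [0, 1, 0]])

# Compare the index of the largest probability with the index of the 1 in each label
correct = np.argmax(probs, axis=1) == np.argmax(labels, axis=1)
accuracy_value = float(correct.astype(np.float32).mean())
print(correct, accuracy_value)
```

Three of the four predictions match, so the mean of the 0/1 values is 0.75.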

Training the Model

    from time import time

    trainEpochs = 10
    batchSize = 100
    totalBatchs = int(mnist.train.num_examples / batchSize)
    epoch_list = []; accuracy_list = []; loss_list = []
    startTime = time()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    for epoch in range(trainEpochs):
        for i in range(totalBatchs):
            batch_x, batch_y = mnist.train.next_batch(batchSize)
            sess.run(optimizer, feed_dict={x: batch_x, y_label: batch_y})
        # After each epoch, measure loss and accuracy on the validation set
        loss, acc = sess.run([loss_function, accuracy],
                             feed_dict={x: mnist.validation.images,
                                        y_label: mnist.validation.labels})
        epoch_list.append(epoch)
        loss_list.append(loss); accuracy_list.append(acc)
        print("Train Epoch:", '%02d' % (epoch + 1),
              "Loss=", "{:.9f}".format(loss), " Accuracy=", acc)

    duration = time() - startTime
    print("Train Finished takes:", duration)

    Train Epoch: 01 Loss= 1.656932473 Accuracy= 0.827
    Train Epoch: 02 Loss= 1.613922596 Accuracy= 0.8558
    Train Epoch: 03 Loss= 1.598174453 Accuracy= 0.8692
    Train Epoch: 04 Loss= 1.510785699 Accuracy= 0.9574
    Train Epoch: 05 Loss= 1.500687838 Accuracy= 0.9658
    Train Epoch: 06 Loss= 1.495839953 Accuracy= 0.9684
    Train Epoch: 07 Loss= 1.491830468 Accuracy= 0.9726
    Train Epoch: 08 Loss= 1.489337087 Accuracy= 0.9742
    Train Epoch: 09 Loss= 1.486868739 Accuracy= 0.9774
    Train Epoch: 10 Loss= 1.484916449 Accuracy= 0.9792
    Train Finished takes: 720.7906568050385
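
The structure of one training epoch can be sketched without TensorFlow: shuffle the example indices, then walk through them in fixed-size batches. The generator name is hypothetical, and the figure of 55,000 training examples is an assumption based on this loader's usual train/validation split; `next_batch` in the original code handles this internally.

```python
import numpy as np

def iterate_minibatches(num_examples, batch_size, seed=0):
    """Yield index arrays that cover a shuffled dataset once, batch by batch."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(num_examples)
    for start in range(0, num_examples - batch_size + 1, batch_size):
        yield order[start:start + batch_size]

num_examples = 55000   # assumed size of mnist.train
batch_size = 100
batches = list(iterate_minibatches(num_examples, batch_size))
print(len(batches), len(batches[0]))
```

Under that assumption, each of the 10 epochs performs 550 optimizer steps.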

    %matplotlib inline
    import matplotlib.pyplot as plt

    fig = plt.gcf()
    fig.set_size_inches(4, 2)
    plt.plot(epoch_list, loss_list, label='loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['loss'], loc='upper left')

    <matplotlib.legend.Legend at 0x7fd9d2ab8b38>

    plt.plot(epoch_list, accuracy_list, label="accuracy")
    fig = plt.gcf()
    fig.set_size_inches(4, 2)
    plt.ylim(0.8, 1)
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend()
    plt.show()

Evaluating Model Accuracy

    len(mnist.test.images)

    10000

    print("Accuracy:",
          sess.run(accuracy, feed_dict={x: mnist.test.images,
                                        y_label: mnist.test.labels}))

    Accuracy: 0.9792

    print("Accuracy:",
          sess.run(accuracy, feed_dict={x: mnist.test.images[:5000],
                                        y_label: mnist.test.labels[:5000]}))

    Accuracy: 0.968

    print("Accuracy:",
          sess.run(accuracy, feed_dict={x: mnist.test.images[5000:],
                                        y_label: mnist.test.labels[5000:]}))

    Accuracy: 0.9886
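
A quick arithmetic check on the three figures above: the two halves contain 5,000 images each, so under a fully deterministic evaluation their mean would equal the accuracy on all 10,000 images. The snippet only recombines the reported numbers.

```python
full = 0.9792          # reported accuracy on all 10,000 test images
first_half = 0.968     # reported accuracy on images [:5000]
second_half = 0.9886   # reported accuracy on images [5000:]

mean_of_halves = (first_half + second_half) / 2
gap = full - mean_of_halves
print(round(mean_of_halves, 4), round(gap, 4))
```

The small gap is consistent with dropout still being active at evaluation time (keep_prob is fixed at 0.8 in the graph), so each run of the accuracy op can give a slightly different result; this is an interpretation, not something the original text states.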

Predicted Probabilities

    # Note: this rebinds y_predict from the graph tensor to a NumPy array of probabilities
    y_predict = sess.run(y_predict, feed_dict={x: mnist.test.images[:5000]})

    y_predict[:5]

    array([[4.05578522e-12, 6.15486123e-14, 5.71559293e-12, 1.74847949e-11,
            2.71332728e-17, 8.90746643e-11, 5.53451119e-21, 1.00000000e+00,
            1.10875556e-13, 8.30471913e-10],
           [9.93732328e-07, 4.50552989e-06, 9.99993682e-01, 8.04418278e-07,
            2.64564185e-14, 1.46194583e-14, 1.56614929e-10, 6.01911912e-14,
            3.10939221e-08, 2.34203085e-15],
           [1.31195605e-08, 9.99897718e-01, 4.33765905e-07, 2.02467453e-11,
            9.89620676e-05, 6.53400056e-10, 6.65772149e-08, 2.67320161e-06,
            5.65030227e-08, 4.94121100e-09],
           [9.99993682e-01, 6.49839280e-11, 3.86714616e-09, 2.97008674e-13,
            8.59991689e-10, 1.07891083e-11, 6.35852575e-06, 1.65313943e-10,
            2.73128520e-10, 5.31917976e-08],
           [1.19434844e-06, 9.74953984e-09, 2.05678519e-09, 2.03244167e-14,
            9.99985814e-01, 1.02013356e-10, 7.86321621e-08, 3.65643515e-08,
            6.86227242e-10, 1.28641732e-05]], dtype=float32)

Prediction Results

    prediction_result = sess.run(tf.argmax(y_predict, 1),
                                 feed_dict={x: mnist.test.images,
                                            y_label: mnist.test.labels})

    prediction_result[:10]

    array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])

    import numpy as np

    def show_images_labels_predict(images, labels, prediction_result):
        # Show the first 10 images with their true label and predicted digit
        fig = plt.gcf()
        fig.set_size_inches(8, 10)
        for i in range(0, 10):
            ax = plt.subplot(5, 5, 1 + i)
            ax.imshow(np.reshape(images[i], (28, 28)), cmap='binary')
            ax.set_title("label=" + str(np.argmax(labels[i])) +
                         ",predict=" + str(prediction_result[i]), fontsize=9)
        plt.show()

    show_images_labels_predict(mnist.test.images, mnist.test.labels, prediction_result)

Finding Misclassified Examples

    for i in range(500):
        if prediction_result[i] != np.argmax(mnist.test.labels[i]):
            print("i=" + str(i) + " label=", np.argmax(mnist.test.labels[i]),
                  "predict=", prediction_result[i])

    i=247 label= 4 predict= 2
    i=259 label= 6 predict= 0
    i=290 label= 8 predict= 4
    i=320 label= 9 predict= 1
    i=321 label= 2 predict= 7
    i=340 label= 5 predict= 3
    i=445 label= 6 predict= 0
    i=495 label= 8 predict= 0
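
The same search can be written without an explicit loop by comparing the two label arrays element-wise. The arrays below are small made-up examples, not the actual test-set results.

```python
import numpy as np

# Hypothetical predicted and true digits for 10 images
predicted = np.array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9])
true = np.array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])

# Indices where the prediction differs from the label
wrong = np.nonzero(predicted != true)[0]
for i in wrong:
    print("i=%d label=%d predict=%d" % (i, true[i], predicted[i]))
```

With one-hot labels, as in this tutorial, the `true` array would first be obtained with `np.argmax(labels, axis=1)`.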

    def show_images_labels_predict_error(images, labels, prediction_result):
        # Show the first 10 misclassified images: j = index, l = label, p = prediction
        fig = plt.gcf()
        fig.set_size_inches(8, 10)
        i = 0; j = 0
        while i < 10:
            if prediction_result[j] != np.argmax(labels[j]):
                ax = plt.subplot(5, 5, 1 + i)
                ax.imshow(np.reshape(images[j], (28, 28)), cmap='binary')
                ax.set_title("j=" + str(j) +
                             ",l=" + str(np.argmax(labels[j])) +
                             ",p=" + str(prediction_result[j]), fontsize=9)
                i = i + 1
            j = j + 1
        plt.show()

    show_images_labels_predict_error(mnist.test.images, mnist.test.labels, prediction_result)

Saving the Model

    saver = tf.train.Saver()

    save_path = saver.save(sess, "saveModel/CNN_model1")

    print("Model saved in file: %s" % save_path)

    Model saved in file: saveModel/CNN_model1

Starting TensorBoard
- tensorboard --logdir=c:pyhonworklogCNN
- Open https://localhost:6006/ in a browser

    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('log/CNN', sess.graph)

    sess.close()