卷积神经网络处理图像识别(三)

2019-11-25 16:47:41 浏览数 (1)

本篇接着上一篇来介绍卷积神经网络的训练(即反向传播)和应用。

训练神经网络和保存训练结果的代码如下:

代码语言:python
import  tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import os
import numpy as np
import CNN_MNIST_inference

# Checkpoint location; the evaluation script reads MODEL_SAVE_PATH back via CNN_MNIST_train.
MODEL_SAVE_PATH = "E:/Python36/my tensorflow/CNN/model_path/"
MODEL_NAME = "MNIST_CNNmodel.ckpt"
print(os.path.join(MODEL_SAVE_PATH, MODEL_NAME))

# Mini-batch size and number of optimisation steps.
BATCH_SIZE = 100
# Exponentially decayed learning rate: base value and per-decay-period multiplier.
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
# Weight of the L2 penalty passed to tf.contrib.layers.l2_regularizer.
REGULARIZATION_RATE = 0.0001
# Decay used by tf.train.ExponentialMovingAverage.
MOVING_AVERAGE_DECAY = 0.99
TRAINING_STEPS = 20000

def _to_nhwc(images):
    '''Reshape flat MNIST image rows into the [N, H, W, C] layout of the x placeholder.'''
    return np.reshape(images,
                      [-1,
                       CNN_MNIST_inference.IMAGE_HEIGHT,
                       CNN_MNIST_inference.IMAGE_WIDTH,
                       CNN_MNIST_inference.NUM_CHANNELS])


def _plot_training_curves(steps, losses, accs):
    '''Plot sampled training-batch loss (top) and validation accuracy in percent (bottom).'''
    from matplotlib import pyplot as plt
    import matplotlib.ticker as mtick
    plt.subplot(211)
    plt.plot(steps, losses, color="red")
    plt.scatter(steps, losses, s=20, color="red")
    plt.xlabel("训练的步数(Batch数)"); plt.ylabel("训练batch上的Loss(含L2正则Loss)")
    plt.subplot(212)
    plt.plot(steps, accs, color="green")
    plt.scatter(steps, accs, s=20, color="green")
    plt.gca().yaxis.set_major_formatter(mtick.FormatStrFormatter("%.3f%%"))
    plt.xlabel("step"); plt.ylabel("验证集上的预测准确率")
    plt.show()


def train(mnist):
    '''Train the CNN on MNIST, checkpoint it, report test accuracy and plot the curves.

    Args:
        mnist: dataset returned by input_data.read_data_sets(..., one_hot=True); its
            train.next_batch, validation.images/labels and test.images/labels are used.

    Side effects:
        Writes checkpoints under MODEL_SAVE_PATH every 25 steps and at the end, writes a
        TensorBoard graph to E://TensorBoard//test, saves one extra checkpoint, and shows
        a matplotlib window.
    '''
    x = tf.placeholder(tf.float32,
                       [None,
                        CNN_MNIST_inference.IMAGE_HEIGHT,
                        CNN_MNIST_inference.IMAGE_WIDTH,
                        CNN_MNIST_inference.NUM_CHANNELS], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, CNN_MNIST_inference.OUTPUT_NODE], name='y-input')

    # L2 regularization. Presumably inference() adds the weight penalties to the 'losses'
    # collection that is summed below -- TODO confirm in CNN_MNIST_inference.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = CNN_MNIST_inference.inference(x, True, regularizer, None, reuse=False)
    global_step = tf.Variable(0, trainable=False)

    # Exponential moving average of every trainable variable; global_step is the num_updates arg.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # Second forward pass sharing the same variables (reuse=True) and handed the averaging
    # object; only used for the accuracy metric ("average model").
    # NOTE(review): this pass also gets train=True and the regularizer -- confirm inference()
    # neither adds the L2 penalty twice nor applies training-only behaviour here.
    average_y = CNN_MNIST_inference.inference(x, True, regularizer, variable_averages, reuse=True)

    # Labels are one-hot, so argmax turns them into the class indices the sparse loss expects.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    tf.add_to_collection('losses', tf.reduce_mean(cross_entropy))
    loss = tf.add_n(tf.get_collection('losses'))

    # Decay the learning rate once per epoch (num_examples / BATCH_SIZE steps), stepwise.
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step,
                                               mnist.train.num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY, staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # One op that performs the gradient step and then refreshes the moving averages.
    train_op = tf.group(train_step, variables_averages_op)
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()
    ckpt_path = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)
    extra_ckpt_path = "E:/Python36/my tensorflow/ckpt files/mode_mnist.ckpt"
    # Saver.save refuses to write into a directory that does not exist.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    os.makedirs(os.path.dirname(extra_ckpt_path), exist_ok=True)

    steps, accs, losses = [], [], []  # sampled every 25 steps, only for the plot
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        validate_feed = {x: _to_nhwc(mnist.validation.images), y_: mnist.validation.labels}
        test_feed = {x: _to_nhwc(mnist.test.images), y_: mnist.test.labels}

        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step],
                                           feed_dict={x: _to_nhwc(xs), y_: ys})
            if i % 25 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                steps.append(step)
                accs.append(validate_acc * 100)
                losses.append(loss_value)
                # loss_value is the loss on the *training* batch just processed.
                print("After %d training steps, validation dataset accuracy after this batch is %g%%, training loss on this batch is %g"
                      % (step, validate_acc * 100, loss_value))
                saver.save(sess, ckpt_path, global_step=global_step)

        saver.save(sess, ckpt_path, global_step=global_step)
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training steps, test accuracy using average model is %g%%"
              % (TRAINING_STEPS, test_acc * 100))
        writer = tf.summary.FileWriter("E://TensorBoard//test", sess.graph)
        writer.close()  # flush the graph event file to disk
        saver.save(sess, extra_ckpt_path)

    _plot_training_curves(steps, losses, accs)
 
def main(argv=None):
    '''Entry point called by tf.app.run(): load MNIST with one-hot labels and train.

    Args:
        argv: leftover command-line arguments passed by tf.app.run(); unused.
    '''
    mnist = input_data.read_data_sets("E:/Python36/my tensorflow/MNIST_data", one_hot=True)
    train(mnist)


if __name__ == "__main__":
    tf.app.run()  # parses flags, then calls main()

下面是训练Batch上的总Loss(含L2正则项)和验证集上的准确率的收敛趋势图。由于我的电脑性能不好,所以我大幅度削减了待训练参数个数。尽管如此,2000轮训练之后,在验证集上5000个图片的预测正确率已达98.3%。如若不削减参数,准确率可达99.4%。

下面的代码利用训练好的卷积神经网络模型来评估它在验证集上的准确率(可以在正式训练时不评估,从而节省训练时间),以及用它来识别单张图片。

代码语言:python
import  tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import os
import numpy as np
import CNN_MNIST_inference
import CNN_MNIST_train
import matplotlib.pyplot as plt

def evaluate(mnist):
    """Restore the newest checkpoint's moving-average weights and report validation accuracy.

    Args:
        mnist: dataset returned by input_data.read_data_sets(..., one_hot=True).

    Returns:
        The validation accuracy (fraction in [0, 1]), or None when no checkpoint is found
        under CNN_MNIST_train.MODEL_SAVE_PATH.
    """
    # A fresh graph keeps this evaluation isolated from whatever is in the default graph.
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32,
                           [None,
                            CNN_MNIST_inference.IMAGE_HEIGHT,
                            CNN_MNIST_inference.IMAGE_WIDTH,
                            CNN_MNIST_inference.NUM_CHANNELS], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, CNN_MNIST_inference.OUTPUT_NODE], name='y-input')
        validation_set = np.reshape(mnist.validation.images,
                                    [-1,
                                     CNN_MNIST_inference.IMAGE_HEIGHT,
                                     CNN_MNIST_inference.IMAGE_WIDTH,
                                     CNN_MNIST_inference.NUM_CHANNELS])
        validate_feed = {x: validation_set, y_: mnist.validation.labels}

        # Inference-mode forward pass: train flag False, no regularizer, no averaging object.
        y = CNN_MNIST_inference.inference(x, False, None, None, reuse=False)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Map each model variable to its moving-average shadow name so that restoring
        # loads the averaged weights into the plain variables.
        variable_averages = tf.train.ExponentialMovingAverage(CNN_MNIST_train.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())
        with tf.Session() as sess:
            # Locate the most recent checkpoint in the training output directory.
            ckpt = tf.train.get_checkpoint_state(CNN_MNIST_train.MODEL_SAVE_PATH)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return None
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoint files are named "<name>-<global_step>"; basename handles both
            # "/" and "\" separators. The step stays a string (printed with %s).
            global_step = os.path.basename(ckpt.model_checkpoint_path).split("-")[-1]
            accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
            print("After %s training steps, validation accuary = %g" % (global_step, accuracy_score))
            return accuracy_score
                
 #把所有输入数据input_data、声明的常量放进with tf.Graph().as_default(): 里面就行了,就可以统一到同一个graph了,
#不然input_data是放到系统默认创建的Graph,跟你又重新with tf.Graph().as_default():不是同一个Graph()就会报错           
def recognize(input_x):
    g = tf.get_default_graph() # 因为 input_x 默认的图中,所以可把下面的计算也默认的图中
    with g.as_default():
        y = CNN_MNIST_inference.inference(input_x, False, None, None, reuse = False)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            #找到目录中最新的模型文件
            ckpt = tf.train.get_checkpoint_state(CNN_MNIST_train.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                #加载模型
                saver.restore(sess, ckpt.model_checkpoint_path)
                predicted_label = tf.argmax(y, 1)
                print("predicted_label: ", sess.run(predicted_label)[0])
            else:
                print('No checkpoint file found')
                return
                
def plotImage(path):
    """Display the JPEG at `path` in grayscale (visual aid only; not used for prediction).

    Args:
        path: filesystem path of the JPEG image.
    """
    # Read inside a context manager so the file handle is always released.
    with tf.gfile.GFile(path, "rb") as f:
        image_rawdata = f.read()
    # channels=1: the 2-D reshape below needs a single channel; an RGB JPEG is converted
    # to grayscale instead of making the reshape fail.
    img_data = tf.image.decode_jpeg(image_rawdata, channels=1)
    # decode_jpeg yields uint8; rescale to float32 in [0, 1] for display.
    img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)
    with tf.Session() as sess:
        image_data = sess.run(img_data)  # numpy array of shape (H, W, 1)
    plt.imshow(image_data.reshape(image_data.shape[0], image_data.shape[1]), cmap='gray')
    plt.show()
    
def main(argv=None):
    """Report validation accuracy, then display and classify a single JPEG image.

    Args:
        argv: unused; kept so the function can also serve as a tf.app.run() entry point.
    """
    mnist = input_data.read_data_sets("E:/Python36/my tensorflow/MNIST_data", one_hot=True)
    evaluate(mnist)  # validation accuracy of the latest checkpoint

    image_path = "E:/Python36/MNIST picture/test50.jpg"
    with tf.gfile.GFile(image_path, "rb") as f:
        image_rawdata = f.read()
    # Decode with the channel count the network expects, so the reshape below matches.
    # NOTE(review): decode_jpeg only accepts 0, 1 or 3 channels; assumes NUM_CHANNELS is 1 (MNIST).
    img_data0 = tf.image.decode_jpeg(image_rawdata, channels=CNN_MNIST_inference.NUM_CHANNELS)
    # convert_image_dtype rescales uint8 to float32 in [0, 1] and is a no-op for float32
    # input, so img_data is always bound (the old conditional left it undefined then).
    img_data = tf.image.convert_image_dtype(img_data0, dtype=tf.float32)

    # Shape the image into the single-example batch layout the network expects.
    input_x = tf.reshape(img_data, [1,
                                    CNN_MNIST_inference.IMAGE_HEIGHT,
                                    CNN_MNIST_inference.IMAGE_WIDTH,
                                    CNN_MNIST_inference.NUM_CHANNELS])
    plotImage(image_path)
    recognize(input_x)


if __name__ == "__main__":
    # main() is called directly; tf.app.run() would work too.
    main()

0 人点赞