本篇文章在上篇TensorFlow-手写数字识别(一)的基础上进行改进,主要实现以下3点:
- 断点续训
- 测试真实图片
- 制作TFRecords格式数据集
断点续训
上次的代码每次进行模型训练时,都会重新开始进行训练,之前的训练结果都被覆盖掉了,极不方便。
在backwork.py中加入ckpt操作,可以实现断点续训功能。
代码实现
代码语言:javascript复制with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
for i in range(STEPS):
xs, ys = sess.run([img_batch, label_batch])
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 100 == 0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
注解:
tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)
该函数表示如果断点文件夹中包含有效断点状态文件,则返回该文件。checkpoint_dir
:表示存储断点文件的目录latest_filename=None
:断点文件的可选名称,默认为“checkpoint”
saver.restore(sess, ckpt.model_checkpoint_path)
该函数表示恢复当前会话,将ckpt中的值赋给w和b。sess
:表示当前会话,之前保存的结果将被加载入这个会话ckpt.model_checkpoint_path
:表示模型存储的位置,不需要提供模型的名字,它会去查看 checkpoint 文件,看看最新的是谁,叫做什么。
代码验证
代码运行效果:
代码语言:javascript复制 RESTART: G:TestProjectpythontensorflow...mnist_backward.py
After 16203 training step(s), loss on training batch is 0.155758.
After 16303 training step(s), loss on training batch is 0.173135.
After 16403 training step(s), loss on training batch is 0.159716.
可以看出,程序可以接着之前的训练数据接着训练
输入真实图片,输出预测结果
上次的代码只能使用MNIST自带数据集中的数据进行训练,这次通过编写mnist_app.py函数,实现真实图片数据的预测。
分析输入输出情况:
- 网络输入:一维数组(784 个像素点)
- 像素点:0-1 之间的浮点数(接近0越黑,接近1越白)
- 网络输出:一维数组(十个可能性概率),数组中最大的那个元素所对应的索引号就是预测的结果
定义输入图片接口函数
代码语言:javascript复制def application():
testNum = input("input the number of test pictures:")
for i in range(testNum):
testPic = raw_input("the path of test picture:")
testPicArr = pre_pic(testPic)
preValue = restore_model(testPicArr)
print "The prediction number is:", preValue
任务分两个函数完成
- testPicArr = pre_pic(testPic)对手写数字图片做预处理
- preValue = restore_model(testPicArr) 将符合神经网络输入要求的图片喂给复现的神经网络模型,输出预测值
具体代码实现:
- 图片预处理函数
#预处理函数,包括resize、转变灰度图、二值化操作
def pre_pic(picName):
img = Image.open(picName) #加载待测试图片(白底)
reIm = img.resize((28,28), Image.ANTIALIAS) #调整大小到28x28
im_arr = np.array(reIm.convert('L'))
threshold = 50 #二进制阈值
for i in range(28):
for j in range(28):
im_arr[i][j] = 255 - im_arr[i][j] #反色(黑底)
if (im_arr[i][j] < threshold): #黑底白字
im_arr[i][j] = 0
else:
im_arr[i][j] = 255
nm_arr = im_arr.reshape([1, 784]) #图片转成1行
nm_arr = nm_arr.astype(np.float32)
img_ready = np.multiply(nm_arr, 1.0/255.0) #取值范围限制在0~1之间
return img_ready
- 获取训练参数进行预测函数
def restore_model(testPicArr):
with tf.Graph().as_default() as tg:
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y = mnist_forward.forward(x, None)
preValue = tf.argmax(y, 1)
variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
preValue = sess.run(preValue, feed_dict={x:testPicArr})
return preValue
else:
print("No checkpoint file found")
return -1
注解:
1)main 函数中调用的application()
函数:输入要识别的几张图片(注意要给出待识别图片的路径和名称)。
2)代码处理过程:
- 模型的要求是黑底白字,但输入的图是白底黑字,所以需要对每个像素点的值改为 255 减去原值以得到互补的反色。
- 对图片做二值化处理(这样以滤掉噪声,另外调试中可适当调节阈值)。
- 把图片形状拉成1行784列,并把值变为浮点型(因为要求像素点是 0-1之间的浮点数)。
- 接着让现有的 RGB 图从0-255之间的数变为 0-1 之间的浮点数。
- 运行完成后返回到 main 函数。
- 计算求得输出 y,y的最大值所对应的列表索引号就是预测结果。
代码验证
1)运行 mnist_backward.py 首先对模型进行训练
代码语言:javascript复制 RESTART: G:TestProjectpythontensorflow...mnist_backward.py
After 16203 training step(s), loss on training batch is 0.155758.
After 16303 training step(s), loss on training batch is 0.173135.
After 16403 training step(s), loss on training batch is 0.159716.
2)运行 mnist_test.py 使用测试集,监测模型的准确率
代码语言:javascript复制 RESTART: G:TestProjectpythontensorflow...mnist_test.py
After 16703 training step(s), test accuracy = 0.9798
3)运行 mnist_app.py 输入1~10之间的数(表示循环验证的图片数量)
代码语言:javascript复制 RESTART: G:TestProjectpythontensorflow...mnist_app.py
input the number of test pictures:5
the path of test picture:pic .png
The prediction number is: [0]
the path of test picture:pic1.png
The prediction number is: [3]
the path of test picture:pic5.png
The prediction number is: [5]
the path of test picture:pic8.png
The prediction number is: [8]
the path of test picture:pic9.png
The prediction number is: [9]
>>>
制作数据集,实现特定应用
上次的程序使用的MNIST整理好的特定格式的数据,如果想要用自己的图片进行模型训练,就需要自己制作数据集。
数据集的制作的不仅仅是将图片整理在一起,通过转换成特定的格式,可以加速图片读取的效率。
下面将MNIST数据集转换成tfrecords格式,该方法也可以将普通图片转换为该格式。
编写数据集生成读取文件(mnist_ generateds.py)
tfrecords文件
tfrecords
:一种二进制文件,可先将图片和标签制作成该格式的文件,使用tfrecords进行数据读取会提高内存利用率tf.train.Example
:用来存储训练数据,训练数据的特征用键值对的形式表示SerializeToString( )
:把数据序列化成字符串存储
生成tfrecords文件
读取原始图片和标签文件,转换为tfrecord格式
代码语言:javascript复制def write_tfRecord(tfRecordName, image_path, label_path):
writer = tf.python_io.TFRecordWriter(tfRecordName) #新建一个writer
num_pic = 0
f = open(label_path, 'r') #打开标签文件
contents = f.readlines() #读入(格式如:2028_7.jpg 7)
f.close()
for content in contents: #遍历每张图片和对应标签
value = content.split() #拆分:图片名 对应标签
img_path = image_path value[0]
img = Image.open(img_path) #打开对应的图片文件
img_raw = img.tobytes()
labels = [0] * 10
labels[int(value[1])] = 1
#把每张图片和标签封装到 example 中
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
}))
writer.write(example.SerializeToString()) #把example进行序列化
num_pic = 1
print ("the number of picture:", num_pic)
writer.close() #关闭writer
print("write tfrecord successful")
注解:
writer = tf.python_io.TFRecordWriter( tfRecordName)
:新建一个 writerfor循环
:遍历每张图和标签writer.write(example.SerializeToString())
:把 example 进行序列化writer.close()
:关闭 writer
保存tfrecord格式文件
代码语言:javascript复制def generate_tfRecord():
isExists = os.path.exists(data_path) #检查用于存放数据集的路径是否存在
if not isExists:
os.makedirs(data_path)
print('The directory was created successfully')
else:
print('directory already exists')
write_tfRecord(tfRecord_train, image_train_path, label_train_path)
write_tfRecord(tfRecord_test, image_test_path, label_test_path)
解析 tfrecords 文件
获取tfrecords文件接口函数
代码语言:javascript复制def get_tfrecord(num, isTrain=True):
if isTrain:
tfRecord_path = tfRecord_train
else:
tfRecord_path = tfRecord_test
img, label = read_tfRecord(tfRecord_path)
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size = num,
num_threads = 2,
capacity = 1000,
min_after_dequeue = 700)
#返回的图片和标签为随机抽取的 batch_size 组
return img_batch, label_batch
注解:
代码语言:javascript复制tf.train.shuffle_batch(tensors, batch_size, capacity,
min_after_dequeue, num_threads=1, seed=None,
enqueue_many=False, shapes=None, allow_smaller_final_batch=False,
shared_name=None, name=None)
tensors
: 待乱序处理的列表中的样本(图像和标签)batch_size
: 从队列中提取的新批量大小capacity
:队列中元素的最大数量min_after_dequeue
: 出队后队列中的最小数量元素,用于确保元素的混合级别num_threads
: 排列 tensors 的线程数seed
:用于队列内的随机洗牌enqueue_many
: tensor 中的每个张量是否是一个例子shapes
: 每个示例的形状allow_smaller_final_batch
: (可选)布尔值。如果为 True,则在队列中剩余数量不足时允许最终批次更小。shared_name
:(可选)如果设置,该队列将在多个会话中以给定名称共享。name
:操作的名称(可选)
读取tfrecords文件
代码语言:javascript复制def read_tfRecord(tfRecord_path):
filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
reader = tf.TFRecordReader() # 新建一个reader
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([10], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
})
img = tf.decode_raw(features['img_raw'], tf.uint8)# 将img_raw字符串转换为8位无符号整型
img.set_shape([784])# 将形状变为一行784列
img = tf.cast(img, tf.float32) * (1./255)# 变成0到1之间的浮点数
label = tf.cast(features['label'], tf.float32) # 把标签列表变为浮点数
return img, label # 返回图片和标签(跳回到 get_tfrecord)
注解:
代码语言:javascript复制tf.train.string_input_producer( string_tensor, num_epochs=None,
shuffle=True,seed=None,capacity=32,
shared_name=None,name=None,cancel_op=None)
该函数会生成一个先入先出的队列,文件阅读器会使用它来读取数据
string_tensor
: 存储图像和标签信息的 TFRecord 文件名列表num_epochs
: 循环读取的轮数(可选)shuffle
:布尔值(可选),如果为 True,则在每轮随机打乱读取顺序seed
:随机读取时设置的种子(可选)capacity
:设置队列容量shared_name
:(可选) 如果设置,该队列将在多个会话中以给定名称共享。 所有具有此队列的设备都可以通过 shared_name 访问它。在分布式设置中使用这种方法意味着每个名称只能被访问此操作的其中一个会话看到。name
:操作的名称(可选)cancel_op
:取消队列(None)
_, serialized_example = reader.read(filename_queue)
把读出的每个样本保存在 serialized_example 中进行解序列化,标签和图片的键名应该和制作 tfrecords 的键名相同,其中标签给出几分类
tf.parse_single_example(serialized,features,name=None,example_names=None)
该函数可以将 tf.train.Example 协议内存块(protocol buffer)解析为张量。
serialized
: 一个标量字符串张量features
: 一个字典映射功能键 FixedLenFeature 或 VarLenFeature值,也就是在协议内存块中储存的name
:操作的名称(可选)example_names
: 标量字符串联的名称(可选)
反向传播文件修改图片标签获取的接口( mnist_backward .py)
利用多线程提高图片和标签的批获取效率
将批获取的操作放到线程协调器开启和关闭之间
- 开启线程协调器:
coord = tf.train.Coordinator( )
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
- 关闭线程协调器:
coord.request_stop( )
coord.join(threads)
注解:
代码语言:javascript复制tf.train.start_queue_runners( sess=None, coord=None, daemon=True,
start=True, collection=tf.GraphKeys.QUEUE_RUNNERS)
这个函数将会启动输入队列的线程,填充训练样本到队列中,以便出队操作可以从队列中拿到样本。这种情况下最好配合使用一个tf.train.Coordinator
,这样可以在发生错误的情况下正确地关闭这些线程。
sess
:用于运行队列操作的会话。默认为默认会话coord
:可选协调器,用于协调启动的线程daemon
: 守护进程,线程是否应该标记为守护进程,这意味着它们不会阻止程序退出start
:设置为 False 只创建线程,不启动它们collection
:指定图集合以获取启动队列的 GraphKey,默认为GraphKeys.QUEUE_RUNNERS
修改后的mnist_backward.py的关键部分:
代码语言:javascript复制...
import mnist_generateds#【1】
...
train_num_examples = 60000#【2】 训练集图片的个数
def backward():
...
saver = tf.train.Saver()
img_batch, label_batch = mnist_generateds.get_tfrecord(BATCH_SIZE, isTrain=True)#【3】
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
coord = tf.train.Coordinator()#【4】 开启线程协调器
threads = tf.train.start_queue_runners(sess=sess, coord=coord)#【5】
for i in range(STEPS):
xs, ys = sess.run([img_batch, label_batch])#【6】
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 100 == 0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
coord.request_stop()#【7】 关闭线程协调器
coord.join(threads)#【8】
注解:
train_num_examples=60000
在梯度下降学习率中需要计算多少轮更新一次学习率,这个值是:总样本数/batch size
- 之前:用mnist.train.num_examples表示总样本数;
- 现在:手动给出训练的总样本数,这个数是6万。
image_batch, label_batch=mnist_generateds.get_tfrecord(BATCH_SIZE,isTrain=True)
- 之前:用mnist.train.next_batch函数读出图片和标签喂给网络;
- 现在:用函数get_tfrecord替换,一次批获取batch_size张图片和标签。
- isTrain:用来区分训练阶段和测试阶段,True 表示**训练**,False表示**测试**。
xs,ys=sess.run([img_batch,label_batch])
- 之前:使用函数xs,ys=mnist.train.next_batch(BATCH_SIZE)
- 现在:在sess.run中执行图片和标签的批获取。
测试文件修改图片标签获取的接口(mnist_test.py)
修改后的mnist_test.py的关键部分:
代码语言:javascript复制#coding:utf-8
...
TEST_NUM = 10000#【1】
def test():
with tf.Graph().as_default() as g:
...
img_batch, label_batch = mnist_generateds.get_tfrecord(TEST_NUM, isTrain=False)#【2】
while True:
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
coord = tf.train.Coordinator()#【3】
threads = tf.train.start_queue_runners(sess=sess, coord=coord)#【4】
xs, ys = sess.run([img_batch, label_batch])#【5】
accuracy_score = sess.run(accuracy, feed_dict={x: xs, y_: ys})
print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
coord.request_stop()#【6】
coord.join(threads)#【7】
else:
print('No checkpoint file found')
return
time.sleep(TEST_INTERVAL_SECS)
注解:
TEST_NUM=10000
- 之前:用 mnist.test.num_examples 表示总样本数
- 现在:手动给出测试的总样本数,这个数是1万
image_batch, label_batch=mnist_generateds.get_tfrecord(TEST_NUM,isTrain=False)
- 之前:用 mnist.test.next_batch 函数读出图片和标签喂给网络
- 现在:用函数 get_tfrecord 替换读取所有测试集 1 万张图片
- isTrain:用来区分训练阶段和测试阶段,True 表示训练,False 表示测试
xs,ys=sess.run([img_batch,label_batch])
- 之前:使用函数 xs,ys=mnist.test.next_batch(BATCH_SIZE)
- 现在:在 sess.run 中执行图片和标签的批获取
代码验证
运行测试代码 mnist_test.py
代码语言:javascript复制 RESTART: G:TestProjectpythontensorflow...mnist_test.py
After 16703 training step(s), test accuracy = 0.9794
After 16703 training step(s), test accuracy = 0.9797
After 16703 training step(s), test accuracy = 0.9795
After 16703 training step(s), test accuracy = 0.9792
运行测试代码 mnist_app.py
代码语言:javascript复制 RESTART: G:TestProjectpythontensorflow...mnist_app.py
input the number of test pictures:5
the path of test picture:pic .png
The prediction number is: [0]
the path of test picture:pic1.png
The prediction number is: [3]
the path of test picture:pic5.png
The prediction number is: [5]
the path of test picture:pic8.png
The prediction number is: [8]
the path of test picture:pic9.png
The prediction number is: [9]
>>>
可以看出和之前的结果一样,代码可用。
注:以上测试图片用的是下面教程中自带的图片,测试结果100%准确,我自己用Windows画图板手写了0~9的数字,准确度只有50%左右,可能是我手写字体和MNIST库中的风格差异较大,或是目前的网络还不够好,下一篇通过搭建CNN网络继续测试。
参考:人工智能实践:Tensorflow笔记