是骡子是马拉出来溜溜就知道,一个模型好还是坏,放在全新的测试集上去测试下就知道了,根据模型测试的结果我们才能衡量模型的泛化性、稳定性等指标如何,从而方便我们根据测试的反馈去进行调参优化模型。
这里我是根据kaggle比赛来写的模型测试代码,所以可能跟实际的工程项目有所差别,注意区分。
这里的模型测试程序,是我参加dogs-vs-cats-redux-kernels-edition比赛而编写的,其他Kaggle比赛有所区别,但大致逻辑和流程没有差别。
模型测试及输出结果程序实现
下面的程序中,我只是加载了模型中每一个变量即权重参数的取值,没有加载模型中定义好的变量,对输入和输出我都重新定义了,其实是可以通过以下代码直接返回训练好的模型中设置的输入输出变量的:
代码语言:javascript复制# 加载模型
saver = tf.train.import_meta_graph('F:/Software/Python_Project/Classification-cat-dog/logs/model.ckpt-20000.meta')
graph = tf.get_default_graph()
# 返回训练模型中设置的输入张量
x = graph.get_tensor_by_name("x:0")
# 返回训练模型中设置的输出张量
logits = graph.get_tensor_by_name("logits_eval:0")
但是,因为我之前迭代训练模型的程序中,并不是通过设置placeholder占位符x输入到神经网络中去的,所以如果直接返回训练好的模型中设置的输入输出变量,我感觉会出现点问题,所以就没有那样编写程序。
写到这里,我真的觉得TensorFlow的坑真的很多,就算彻底掌握python,但是如果没有深入研究过TensorFlow的话,还是容易掉坑,但是在工业界TensorFlow是使用最广泛的机器学习框架,我们还是有必要去深入学习和掌握这个框架,只能说告诫初学者(虽然我也是初学者),如果学了一段时间TensorFlow还是遇到各种问题或者没有掌握的话,可以去试试Keras或者Pytorch,毕竟它们上手真的更简单。
代码如下:
代码语言:javascript复制# 评估模型
# coding:utf-8
# filename:catdog_test.py
# Environment:windows10,python3,numpy,TensorFlow1.9,glob,skimage,numpy,
# Function:负责测试猫狗识别网络模型,并将识别结果输出到csv文件中
from PIL import Image
import matplotlib.pyplot as plt
import os
import numpy as np
import tensorflow as tf
import input_data
import model
from skimage import io,transform
import pandas as pd
import csv
# ---------------------------配置神经网络超参数-------------------------------------------
N_CLASSES = 2 # 输出类别数
IMG_W = 227 # 图像宽度
IMG_H = 227 # 图像高度
IMG_C = 3 # 图像通道
BATCH_SIZE = 1 # 批次大小
# ---------------------------读取测试集数据------------------------------------------------
# 获取指定目录下的文件名
def get_file(dataset_dir):
photo_filenames = []
i = 0
for i,filename in enumerate(os.listdir(dataset_dir)):
# 获取文件路径
path = os.path.join(dataset_dir,filename)
photo_filenames.append(path) # list
i = 1
# 返回(图片)文件数量及列表形式的文件名(包含路径)
return i,photo_filenames
# 根据外部路径读取一张照片
def get_one_image(path):
img = io.imread(path)
img = transform.resize(img, (IMG_W,IMG_H))
# 返回图像数据(3*D)
return np.array(img) # ndarray
# --------------------------测试模型,循环输出测试结果--------------------------------------
def run_testing():
file_names = []
file_labels = []
list = []
# 云服务器训练对应路径地址
# train_dir = '/data/Dogs-Cats-Redux-Kernels-Edition/train/'
# logs_train_dir = '/data/Dogs-Cats-Redux-Kernels-Edition/logs/'
# 本地电脑训练对应路径地址
train_dir = "F:/Software/Python_Project/Classification-cat-dog/test/"
logs_train_dir = "F:/Software/Python_Project/Classification-cat-dog/logs/"
_,test_files = get_file(train_dir)
for i,file in enumerate(test_files):
i = 1
file_name = file.split('/')[-1].split(sep='.')[0]
image_array = get_one_image(file)
with tf.Graph().as_default():
image = tf.cast(image_array, tf.float32)
image = tf.reshape(image, [1, 208, 208, 3])
# file_names.append(file_name)
logit = model.inference(image, BATCH_SIZE, N_CLASSES)
# 对神经网络输出进行softmax回归操作,使得输出变成一个概率分布
logit = tf.nn.softmax(logit)
x = tf.placeholder(tf.float32, shape=[208, 208, 3])
saver = tf.train.Saver()
# 初始化会话并开始预测过程
with tf.Session() as sess:
print("Reading checkpoints...")
# 找到指定目录中最新模型的文件名
ckpt = tf.train.get_checkpoint_state(logs_train_dir)
if ckpt and ckpt.model_checkpoint_path: # 模型文件名和模型路径存在,则进行下一步
# 通过文件名得到模型保存时迭代的轮数
global_step = ckpt.model_checkpoint_path.split("/")[-1].split("-")[-1]
# 加载模型
saver.restore(sess, ckpt.model_checkpoint_path)
print("Loading success, global_step is %s" % global_step)
else:
print("No checkpoint file found")
prediction = sess.run(logit, feed_dict={x: image_array})
max_index = np.argmax(prediction)
a = [file_name,max_index]
list.append(a)
print(prediction.shape)
if max_index == 0:
print(file_name,"file is a cat with possibility %.6f" % prediction[:, 0])
else:
print(file_name,"file is a dog with possibility %.6f" % prediction[:, 1])
print('finished recognition %d file' % i)
# 将测试结果写入sample_submission.csv
test = pd.DataFrame(data=list) # 数据有2列
# print(test)
test.to_csv('sample_submission.csv')
print('Write to csv file finished')
print(list) # 打印测试结果,list数据类型
# -------------------------------程序从这里开始运行---------------------------------------
if __name__ == "__main__":
run_testing()
输出结果
输出一个csv文件,如下图所示:
注意,我这里程序得到的结果是乱序的,也没有id及label这一行,可通过Excel文件手动添加和排序。