知识之窗
PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。
2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:1、具有强大的GPU加速的张量计算(如NumPy)。2、包含自动求导系统的深度神经网络。
PyTorch的前身是Torch,其底层和Torch框架一样,但是使用Python重新写了很多内容,不仅更加灵活,支持动态图,而且提供了Python接口。它是由Torch7团队开发,是一个以Python优先的深度学习框架,不仅能够实现强大的GPU加速,同时还支持动态神经网络。
PyTorch既可以看作加入了GPU支持的numpy,同时也可以看成一个拥有自动求导功能的强大的深度神经网络。除了Facebook外,它已经被Twitter、CMU和Salesforce等机构采用。
回顾
在上周的文章中, 我们学习了整合所有的代码(数据预处理,网络模型,训练代码),然后进行了实际的训练,我们必须知道,神经网络的训练结果小除了知道模型的好坏以及有效性以外,我们还需要考虑将训练好的模型进行实际的测试,也需要后期需要用来部署为应用也说不定,当然不会直接部署,还需呀考虑优化,压缩,剪枝等问题。
一、模型预测
实现步骤:
1.在训练过程中保存模型
2.编写测试代码(数据处理,模型调用,数据测试)
4.输出模型结果,映射为真实标签
1.训练过程中保存模型
代码语言:javascript复制#在训练之前添加
# 产生一个saver来存储训练好的模型
saver = tf.train.Saver()
在每训练一个batch后,开始整个验证集的测试(现在一般训练一个epoch后,才进行验证),验证集测试后,如果大于上一次的测试准确率并且大于80%才考虑保存模型,即最终保存最好的模型。
代码语言:javascript复制 if avg_test_acc > pre_test_acc and avg_test_acc > 0.80:
checkpoint_path = os.path.join(logs_checkpoint,
'model.ckpt')
saver.save(sess,
2.测试代码
1.数据预处理:
这个地方与训练的时候一样
代码语言:javascript复制# 获取一张图片
def get_one_image(img_dir):
# 输入参数:train,训练图片的路径
# 返回参数:image,从训练图片中随机抽取一张图片
#print("train", train)
#n = len(train)
#ind = np.random.randint(0, n)
#img_dir = train[ind] # 随机选择测试的图片
# img_dir = train
img = Image.open(img_dir)
#plt.imshow(img)
#imag = img.resize([150, 150]) # 由于图片在预处理阶段以及resize,因此该命令可略
imge = tf.image.resize_images(img, (150, 150))
image = tf.reshape(imge, [1, 150, 150, 3])
#image = np.array(imge)
image = image/255
image = tf.cast(image, tf.float32)
return image
2.模型调用
其实就是回复保存模型的参数后导入到现在的网络中,进行测试。
现在的网络只进行前向传播,不进行反向传播。
代码语言:javascript复制saver = tf.train.Saver()
with tf.Session() as sess:
img_array = sess.run(image_array)
print("Reading checkpoints...")
ckpt = tf.train.get_checkpoint_state(logs_train_dir)
if ckpt and ckpt.model_checkpoint_path:
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
saver.restore(sess, ckpt.model_checkpoint_path)
print('Loading success, global_step is %s' % global_step)
else:
3.数据测试
代码语言:javascript复制# 测试图片
def evaluate_one_image(image_array):
global graph
graph = tf.get_default_graph()
with graph.as_default():
BATCH_SIZE = 1
N_CLASSES = 2
#image = tf.cast(image_array, tf.float32)
x = tf.placeholder(tf.float32, shape=[1,150, 150, 3])
logit = model.inference(x, BATCH_SIZE, N_CLASSES,1)
logit = tf.nn.softmax(logit)
4.输出结果:
代码语言:javascript复制prediction = sess.run(logit,feed_dict={x: img_array})
max_index = np.argmax(prediction)
# print(max_index)
# 标签映射可以选择字典或者列表
label_dict = {0: 'cat', 1: 'dog'}
label_list = ['cat','dog']
print("模型的输出为{},对应的真实标签为:{}".format(max_index,label_list[max_index]))
全部的测试代码:
实际预测展示
可以看到我们读取的是测试中的dog的图片,随后网络的预测标签是1,当初给dog的标签为1,即映射实际标签为dog,预测正确。
结语
本次分享结束了,算是图像分类项目的一个完整流程的项目,从数据处理到网络搭建,到训练,到调用模型做预测,我们都进行了分享,同时对代码细节进行了注释,相信聪敏的你一定可以看懂,如有疑惑请随时后台哦。
虽然本次项目结束,但我相信,其中或多或少有些地方大家不太理解,不管数据处理还是网络的搭建等等都或许不是那么简单,没关系,下次,小编会针对本次项目中的漏洞进行一个总结,算是图像分类项目的总结篇吧,同时也欢迎各位老铁,多多提问,以促使我们一起进步。
周末愉快,我们下期再见!
编辑:玥怡居士|审核:小圈圈居士