深度学习实战篇之 ( 十) -- TensorFlow学习之路(七)

2022-06-01 20:11:16 浏览数 (1)

知识之窗

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。

2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:1、具有强大的GPU加速的张量计算(如NumPy)。2、包含自动求导系统的深度神经网络。

PyTorch的前身是Torch,其底层和Torch框架一样,但是使用Python重新写了很多内容,不仅更加灵活,支持动态图,而且提供了Python接口。它是由Torch7团队开发,是一个以Python优先的深度学习框架,不仅能够实现强大的GPU加速,同时还支持动态神经网络。

PyTorch既可以看作加入了GPU支持的numpy,同时也可以看成一个拥有自动求导功能的强大的深度神经网络。除了Facebook外,它已经被Twitter、CMU和Salesforce等机构采用。

回顾

在上周的文章中, 我们学习了整合所有的代码(数据预处理,网络模型,训练代码),然后进行了实际的训练,我们必须知道,神经网络的训练结果小除了知道模型的好坏以及有效性以外,我们还需要考虑将训练好的模型进行实际的测试,也需要后期需要用来部署为应用也说不定,当然不会直接部署,还需呀考虑优化,压缩,剪枝等问题。

一、模型预测

实现步骤:

1.在训练过程中保存模型

2.编写测试代码(数据处理,模型调用,数据测试)

4.输出模型结果,映射为真实标签

1.训练过程中保存模型

代码语言:javascript复制
#在训练之前添加
# 产生一个saver来存储训练好的模型
saver = tf.train.Saver()

在每训练一个batch后,开始整个验证集的测试(现在一般训练一个epoch后,才进行验证),验证集测试后,如果大于上一次的测试准确率并且大于80%才考虑保存模型,即最终保存最好的模型。

代码语言:javascript复制
 if avg_test_acc > pre_test_acc and avg_test_acc > 0.80:
checkpoint_path = os.path.join(logs_checkpoint,
 'model.ckpt')
saver.save(sess,

2.测试代码

1.数据预处理:

这个地方与训练的时候一样

代码语言:javascript复制
# 获取一张图片
def get_one_image(img_dir):
    # 输入参数:train,训练图片的路径
    # 返回参数:image,从训练图片中随机抽取一张图片
    #print("train", train)
    #n = len(train)
    #ind = np.random.randint(0, n)
    #img_dir = train[ind]  # 随机选择测试的图片
    # img_dir = train

    img = Image.open(img_dir)
    #plt.imshow(img)
    #imag = img.resize([150, 150])  # 由于图片在预处理阶段以及resize,因此该命令可略
    imge = tf.image.resize_images(img, (150, 150))
    image = tf.reshape(imge, [1, 150, 150, 3])
    #image = np.array(imge)

    image = image/255
    image = tf.cast(image, tf.float32)

    return image

2.模型调用

其实就是回复保存模型的参数后导入到现在的网络中,进行测试。

现在的网络只进行前向传播,不进行反向传播。

代码语言:javascript复制
saver = tf.train.Saver()

with tf.Session() as sess:
img_array = sess.run(image_array)

print("Reading checkpoints...")
ckpt = tf.train.get_checkpoint_state(logs_train_dir)
if ckpt and ckpt.model_checkpoint_path:
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('Loading success, global_step is %s' % global_step)
else:

3.数据测试

代码语言:javascript复制
# 测试图片
def evaluate_one_image(image_array):
global graph
graph = tf.get_default_graph()
with graph.as_default():
  BATCH_SIZE = 1
  N_CLASSES = 2
  #image = tf.cast(image_array, tf.float32)

  x = tf.placeholder(tf.float32, shape=[1,150, 150, 3])

  logit = model.inference(x, BATCH_SIZE, N_CLASSES,1)

  logit = tf.nn.softmax(logit)

4.输出结果:

代码语言:javascript复制
prediction = sess.run(logit,feed_dict={x: img_array})
max_index = np.argmax(prediction)
# print(max_index)
# 标签映射可以选择字典或者列表
label_dict = {0: 'cat', 1: 'dog'}
label_list = ['cat','dog']
print("模型的输出为{},对应的真实标签为:{}".format(max_index,label_list[max_index]))

全部的测试代码:

实际预测展示

可以看到我们读取的是测试中的dog的图片,随后网络的预测标签是1,当初给dog的标签为1,即映射实际标签为dog,预测正确。

结语

本次分享结束了,算是图像分类项目的一个完整流程的项目,从数据处理到网络搭建,到训练,到调用模型做预测,我们都进行了分享,同时对代码细节进行了注释,相信聪敏的你一定可以看懂,如有疑惑请随时后台哦。

虽然本次项目结束,但我相信,其中或多或少有些地方大家不太理解,不管数据处理还是网络的搭建等等都或许不是那么简单,没关系,下次,小编会针对本次项目中的漏洞进行一个总结,算是图像分类项目的总结篇吧,同时也欢迎各位老铁,多多提问,以促使我们一起进步。

周末愉快,我们下期再见!

编辑:玥怡居士|审核:小圈圈居士

0 人点赞