报错信息分析
今天用训练好的CNN模型,对cifar10测试数据集的测试,报如下错误:
代码语言:javascript复制Traceback (most recent call last):
File "F:SoftwarePython_ProjectClassification-cifar10cifar10_test.py", line 92, in <module>
evaluate()
File "F:SoftwarePython_ProjectClassification-cifar10cifar10_test.py", line 76, in evaluate
classification_result = sess.run(logits,feed_dict)
File "C:Python36libsite-packagestensorflowpythonclientsession.py", line 887, in run
run_metadata_ptr)
File "C:Python36libsite-packagestensorflowpythonclientsession.py", line 1086, in _run
str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (32, 32, 3) for Tensor 'x:0', which has shape '(?, 32, 32, 3)'
[Finished in 20.6s with exit code 1]
通过报错信息,我们可以分析出feed_dict的值与定义的输入数据张量x格式不匹配,feed_dict的维度3维的,shape是(32,32,3),而x的维度是4维的,shape是(None,32,32,3)。所以,导致出现了维度不匹配的问题。
解决办法
既然维度不匹配,那我们就通过程序让它匹配,加入以下代码:
代码语言:javascript复制image = tf.reshape(image_data, [1, 32, 32, 3])
#输出要经过np.sum函数,才能得到类别编号
#file_label = np.sum(file_label)
问题总结
其实,之前做图像识别的比赛和项目是没出现过这种问题的,因为之前的测试数据集都比较小,我可以直接把测试图像加载成np.ndarray的数据类型,但是这里cifar10测试数据集有300000张图片,也就是说如果一次性全部读取,最后得到的image_data的shape将会是(300000,32,32,3),无疑太大了,我笔记本直接报错ran out of memory,所以我就选择一张张图片读取并预测,但是读出来的shape是3维的,没有设置输入batch,所以出现了这个问题。