前言:在使用深度学习框架PyTorch预处理图像数据时,你可能和我一样遇到过各种各样的问题,网上虽然总能找到类似的问题,但不同文章的代码环境不同,也不一定能直接解决自己的问题。这时,就需要就自身所出bug了解问题本身涉及的大致原理,依据报错的具体位置(要完整的看完bug信息,不要只看最后报错信息而不看中间调用过程)才能更快的精准解决自己的问题
一、原理概述
PIL(Python Imaging Library)是Python中最基础的图像处理库,而使用PyTorch将原始输入图像预处理为神经网络的输入,经常需要用到三种格式PIL Image、Numpy和Tensor,其中预处理包括但不限于「图像裁剪」,「图像旋转」和「图像数据归一化」等。而对图像的多种处理在code中可以打包到一起执行,一般用transforms.Compose(transforms)
将多个transform组合起来使用。如下所示
from torchvision import transforms
transform = transforms.Compose([
# 重置大小
transforms.Resize(255),
transforms.CenterCrop(224),
# 随机旋转图片
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# 正则化(降低模型复杂度)
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
其中,不同图像处理方法要求输入的图像格式不同,比如Resize()
和RandomHorizontalFlip()
等方法要求输入的图像为PIL Image
,而正则化操作Normalize()
处理的是tensor
格式的图像数据。因此,针对不同操作的数据格式要求,我们需要在不同操作之前将输入图像数据的格式化成所要求的格式,有了这些概念了解,面对可能出现的bug,我们才能游刃有余的精准处理。
二、PIL Image与tensor的转换
2.1 tensor转换为PIL Image
代码语言:javascript复制from torchvision.transforms
PIL_img = transforms.ToPILImage()(tensor_img)
2.2 PIL Image转换为tensor
一般放在transforms.Compose(transforms)
组合中正则化操作的前面即可
transforms.ToTensor()
2.3 Numpy转换为PIL Image
代码语言:javascript复制from PIL import Image
PIL_img = Image.fromarray(array)
三、可能遇到的问题
3.1 img should be PIL Image. Got <class ‘torch.Tensor’>
代码语言:javascript复制TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
这个问题,网上大部分博文甚至stackoverflow上说的都是transforms.Compose(transforms)
组合中的顺序问题,但按照这些说法修改顺序后我仍一直未解决问题。后来了解了原理并结合自己实际bug出现的位置,才最终解决。
如下图所示,我的bug出现在红框中的句柄中,而与大多数博文不同的是,我是先对图像做灰度处理,然后再做剪裁和旋转的操作,因此transforms.Compose(transforms)
组合操作在这行代码之后,自然怎么改顺序都无动于衷。所以从bug的位置可知此问题与组合操作顺序无关,但从最后的类型错误中可知此行代码传进去的observation类型期望是PIL,但实际是tensor,因此只要在此之前进行两者格式的转换即可解决bug
解决方案从
代码语言:javascript复制transform = T.Grayscale()
img = transform(img)
变为
代码语言:javascript复制transform = T.Grayscale()
img = T.ToPILImage()(img)
img = transform(img)
3.1 tensor should be a torch tensor. Got <class ‘PIL.Image.Image’>.
代码语言:javascript复制TypeError: tensor should be a torch tensor. Got <class 'PIL.Image.Image'>.
肯定是需要tensor
的图像操作传入的是PIL
,因此在合适的位置前将PIL
转换为tensor
即可
解决方法从
代码语言:javascript复制transform = transforms.Compose([
transforms.Resize(255),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
到
代码语言:javascript复制transform = transforms.Compose([
transforms.Resize(255),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
参考文献
[1] PIL.Image和np.ndarray图片与Tensor之间的转换 [2] PyTorch载入图片后ToTensor解读(含PIL和OpenCV读取图片对比) [3] pytorch如何显示数据图像及标签TypeError: img should be PIL Image. Got <class ‘numpy.ndarray‘>