PIL Image与tensor在PyTorch图像预处理时的转换

2021-09-18 15:23:38 浏览数 (1)

前言:在使用深度学习框架PyTorch预处理图像数据时,你可能和我一样遇到过各种各样的问题,网上虽然总能找到类似的问题,但不同文章的代码环境不同,也不一定能直接解决自己的问题。这时,就需要就自身所出bug了解问题本身涉及的大致原理,依据报错的具体位置(要完整的看完bug信息,不要只看最后报错信息而不看中间调用过程)才能更快的精准解决自己的问题

一、原理概述

PIL(Python Imaging Library)是Python中最基础的图像处理库,而使用PyTorch将原始输入图像预处理为神经网络的输入,经常需要用到三种格式PIL Image、Numpy和Tensor,其中预处理包括但不限于「图像裁剪」,「图像旋转」和「图像数据归一化」等。而对图像的多种处理在code中可以打包到一起执行,一般用transforms.Compose(transforms)将多个transform组合起来使用。如下所示

代码语言:javascript复制
from torchvision import transforms 

transform = transforms.Compose([
			   # 重置大小
			   transforms.Resize(255), 
               transforms.CenterCrop(224),  
               # 随机旋转图片
               transforms.RandomHorizontalFlip(),
               transforms.ToTensor(), 
               # 正则化(降低模型复杂度)
               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

其中,不同图像处理方法要求输入的图像格式不同,比如Resize()RandomHorizontalFlip()等方法要求输入的图像为PIL Image,而正则化操作Normalize()处理的是tensor格式的图像数据。因此,针对不同操作的数据格式要求,我们需要在不同操作之前将输入图像数据的格式化成所要求的格式,有了这些概念了解,面对可能出现的bug,我们才能游刃有余的精准处理。

二、PIL Image与tensor的转换

2.1 tensor转换为PIL Image
代码语言:javascript复制
from torchvision.transforms 
PIL_img = transforms.ToPILImage()(tensor_img) 
2.2 PIL Image转换为tensor

一般放在transforms.Compose(transforms)组合中正则化操作的前面即可

代码语言:javascript复制
transforms.ToTensor()
2.3 Numpy转换为PIL Image
代码语言:javascript复制
from PIL import Image
PIL_img = Image.fromarray(array)

三、可能遇到的问题

3.1 img should be PIL Image. Got <class ‘torch.Tensor’>
代码语言:javascript复制
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

这个问题,网上大部分博文甚至stackoverflow上说的都是transforms.Compose(transforms)组合中的顺序问题,但按照这些说法修改顺序后我仍一直未解决问题。后来了解了原理并结合自己实际bug出现的位置,才最终解决。

如下图所示,我的bug出现在红框中的句柄中,而与大多数博文不同的是,我是先对图像做灰度处理,然后再做剪裁和旋转的操作,因此transforms.Compose(transforms)组合操作在这行代码之后,自然怎么改顺序都无动于衷。所以从bug的位置可知此问题与组合操作顺序无关,但从最后的类型错误中可知此行代码传进去的observation类型期望是PIL,但实际是tensor,因此只要在此之前进行两者格式的转换即可解决bug

解决方案从

代码语言:javascript复制
transform = T.Grayscale()
img = transform(img)

变为

代码语言:javascript复制
transform = T.Grayscale()
img = T.ToPILImage()(img)
img = transform(img)
3.1 tensor should be a torch tensor. Got <class ‘PIL.Image.Image’>.
代码语言:javascript复制
TypeError: tensor should be a torch tensor. Got <class 'PIL.Image.Image'>.

肯定是需要tensor的图像操作传入的是PIL,因此在合适的位置前将PIL转换为tensor即可

解决方法从

代码语言:javascript复制
transform = transforms.Compose([
			   transforms.Resize(255), 
               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

代码语言:javascript复制
transform = transforms.Compose([
			   transforms.Resize(255), 
               transforms.ToTensor(), 
               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

参考文献

[1] PIL.Image和np.ndarray图片与Tensor之间的转换 [2] PyTorch载入图片后ToTensor解读(含PIL和OpenCV读取图片对比) [3] pytorch如何显示数据图像及标签TypeError: img should be PIL Image. Got <class ‘numpy.ndarray‘>

0 人点赞