pytorch DataLoader(1): opencv,skimage,PIL,Tensor转换以及transforms

2021-07-07 18:19:03 浏览数 (1)

更新

本文进入热榜收到了不少关注,所以将本文的代码放在了GitHub上,jupyter的,有需要的自取。

同时也欢迎查看后续更新:

pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口

pytorch DataLoader(3)_albumentations数据增强(分割版)

前置知识

在使用pytorch进行dataload,transform之前,需要了解一些数据的知识,许多人使用不同的接口因为不熟悉犯了一些错误。在这里对一些常用的OpenCV,PIL,skimage进行了一些总结,以及pytorchvision.transorforms的一些简单使用。

代码语言:javascript复制
import cv2
from PIL import Image
from skimage import io, transform, color
import matplotlib.pyplot as plt
import numpy as np

from torchvision import transforms
代码语言:javascript复制
img_path = 'data/1803151818-00000065.jpg'
alpha_path = 'data/1803151818-00000065.png'

常用接口

1.1 OpenCV
代码语言:javascript复制
# 默认彩图
img_cv2 = cv2.imread(img_path)

# 灰度图
img_cv2_gray = cv2.imread(alpha_path,0)

print(img_cv2.shape)
# (250, 250, 3)  (H,W,C)

type(img_cv2)
# numpy.ndarray
1.2 PIL.Image
代码语言:javascript复制
# 默认彩图
img_pil = Image.open(img_path)

# 灰度图
img_pil_gray = Image.open(alpha_path).convert('L') # 打开图片并转成灰度图

print(img_pil.size) 
# (250, 250)

print(np.array(img_pil).shape) # PIL没有shape属性,需要转成 numpy.ndarray
#(250, 250, 3)

type(img_pil)
# PIL.JpegImagePlugin.JpegImageFile  HWC
1.3 skimage1
代码语言:javascript复制
# 默认彩图
img_skimage = io.imread(img_path)

# 灰度图
img_skimage_gray = io.imread(alpha_path,-1)

print(img_skimage.shape)
# (250, 250, 3)

type(img_skimage)
# numpy.ndarray
# imageio.core.util.Array
代码语言:javascript复制
(800, 600, 3)

numpy.ndarray
1.4 小结
  • OpenCV读进来的是numpy数组,是uint8类型,0-255范围,图像形状是(H,W,C),读入的顺序是BGR,这点需要注意
  • PIL是有自己的数据结构的,类型是;但是可以转换成numpy数组,转换后的数组为unit8,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
  • skimage读取进来的图片是numpy数组,是unit8类型,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
  • matplotlib读取进来的图片是numpy数组,是unit8类型,0-255范围,图像形状是(H,W,C),读入的顺序是RGB

名称

type

数据类型

读入图像格式

数据形状

能否通过transforms转换

opencv

numpy.ndarray

uint8类型,0-255范围

BGR

H×W×C

PIL

PIL.Image.Image

RGB

H×W×C

skimage

numpy.ndarray

uint8类型,0-255范围

RGB

H×W×C

代码语言:javascript复制
#cv2
# cv2 BGR-->RGB 两种方法
#img_cv2 = img_cv2[:,:,::-1]
img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)

plt.subplot(1,4,1)
plt.title('cv2')
plt.imshow(img_cv2)
 
#PIL
plt.subplot(1,4,2)
plt.title('PIL')
plt.imshow(img_pil)

#PIL
plt.subplot(1,4,3)
plt.title('skimage')
plt.imshow(img_skimage)

#plt
img = plt.imread(img_path)
plt.subplot(1,4,4)
plt.title('plt')
plt.imshow(img_pil)

#show
plt.show()

2. 相互转换

2.1 opencv <—> pil
代码语言:javascript复制
img_cv = cv2.imread(img_path)
img_pil = Image.open(img_path)
img_skimage = io.imread(img_path)

# opencv -> pil
img_pil = Image.fromarray(cv2.cvtColor(img_cv,cv2.COLOR_BGR2RGB))

# pil -> opencv
img_cv = cv2.cvtColor(np.asarray(img_pil),cv2.COLOR_RGB2BGR)
2.2 skimage <—> pil
代码语言:javascript复制
# skimage -> pil
img_pil = Image.fromarray(img_skimage)

# pil -> skimage
img_pil = np.array(img_skimage)
2.3 skimage <—> opencv
代码语言:javascript复制
# opencv -> skimage
img_skimage = cv2.cvtColor(img_cv,cv2.COLOR_BGR2RGB)

# skimage -> opencv
from skimage import img_as_ubyte
cv_image = img_as_ubyte(img_skimage)

3. transforms, tensor转换

为了方便进行图像数据的操作,pytorch团队提供了一个torchvision.transforms包,我们可以用transforms进行以下操作:

  • PIL.Image / numpy.ndarray与Tensor的相互转化;
  • 归一化;
  • 对PIL.Image进行裁剪、缩放等操作。

注意1: transforms.ToTensor() 可以将 PIL.Image/numpy.ndarray 数据进转化为torch.FloatTensor,并归一化到0, 1.0,但是transforms的其他操作只能对PIL读入的数据操作,所以使用transforms.Compose()将这些操作组合到一起的如果有其他操作则只能输入PIL数据。

transforms包含多种图像操作的函数,可以单独使用,也可以通过transforms.Compose(function1, function2,……functionN)操作。

注意2:Tensor的形状是C,H,W,而cv2,plt,PIL,skimage形状都是H,W,C

3.1 H×W×C ——> C×H×W
代码语言:javascript复制
img_cv2.transpose(2,0,1).shape
# (3,250, 250)

img_skimage.transpose(2,0,1).shape
# (3,250, 250)
代码语言:javascript复制
(3, 800, 600)
3.2 toTensor
  • PIL.Image / numpy.ndarray --> Tensor: train 数据读取
  • Tensor --> PIL.Image / numpy.ndarray: inference 数据输出。

我们可以使用 transforms.ToTensor() 将 PIL.Image/numpy.ndarray 数据进转化为torch.FloatTensor,并归一化到0, 1.0:

  • 取值范围为0, 255的PIL.Image,转换成形状为C, H, W,取值范围是0, 1.0的torch.FloatTensor;
  • 形状为H, W, C的numpy.ndarray,转换成形状为C, H, W,取值范围是0, 1.0的torch.FloatTensor;
  • transforms.ToPILImage则是将Tensor或numpy.ndarray转化为PIL.Image。如果,我们要将Tensor转化为numpy,只需要使用 .numpy() 即可。
代码语言:javascript复制
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
 
img_path = 'data/1803151818-00000065.jpg'
 
# transforms.ToTensor()
transform1 = transforms.Compose([
    transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0] and convert [H,W,C] to [C,H,W]
])
img = plt.imread(img_path)
print('plt',img.shape)       #(H,W,C)
img = transform1(img)  
print(img.shape)         #torch.Size([C,H,W])
# 转化为numpy.ndarray并显示
img_arr = img.numpy() * 255  #use np.numpy(): convert Tensor to numpy
img_arr = img_arr.astype('uint8')  #convert Float to Int
print(img_arr.shape)                #[C,H,W]
img_new = np.transpose(img_arr, (1, 2, 0))   #use np.transpose() convert [C,H,W] to [H,W,C]
plt.imshow(img_new)
plt.show()
代码语言:javascript复制
plt (800, 600, 3)
torch.Size([3, 800, 600])
(3, 800, 600)
代码语言:javascript复制
img = cv2.imread(img_path)
#img = img[:,:,::-1] ### ValueError???
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

print('plt',img.shape)       #(H,W,C)
img = transform1(img)  
print(img.shape)         #torch.Size([C,H,W])
# 转化为numpy.ndarray并显示
img_arr = img.numpy() * 255  #use np.numpy(): convert Tensor to numpy
img_arr = img_arr.astype('uint8')  #convert Float to Int
print(img_arr.shape)                #[C,H,W]
img_new = np.transpose(img_arr, (1, 2, 0))   #use np.transpose() convert [C,H,W] to [H,W,C]
plt.imshow(img_new)
plt.show()
代码语言:javascript复制
plt (800, 600, 3)
torch.Size([3, 800, 600])
(3, 800, 600)
3.3 Normalize
代码语言:txt复制
                               c                         h                         a                         n                         n                         e                         l                         =                                              c                               h                               a                               n                               n                               e                               l                               −                               m                               e                               a                               n                                                 s                               t                               d                                                 channel = frac{channel - mean}{std}                  channel=stdchannel−mean​进行规范化。(是对tensor进行归一化,所以需要放在transforms.ToTensor()之后)
代码语言:javascript复制
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
# 这两组值是 ImageNet数据集大样本统计得出的

#归一化 
transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
    ]
)
3.4 compose
代码语言:javascript复制
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# 先转PIL 再进入Compose 进行数据增强
all_transforms = transforms.Compose([
                    transforms.Resize(256),
                    transforms.RandomSizedCrop(224),
                    transforms.RandomHorizontalFlip(), # 对PIL.Image图片进行操作
                    transforms.ToTensor(),
                    normalize])
代码语言:javascript复制
# 或者ToTensor之后 再转PIL
transform2 = transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.RandomCrop((300,300)),
])
img = Image.open(img_path).convert('RGB')
 
img2 = transform2(img)
 
img2.show()

Reference:

数据来源:爱分割 github

https://blog.csdn.net/tsq292978891/article/details/78767326


  1. Image data types and what they mean ↩︎

0 人点赞