更新
本文进入热榜收到了不少关注,所以将本文的代码放在了GitHub上,jupyter的,有需要的自取。
同时也欢迎查看后续更新:
pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口
pytorch DataLoader(3)_albumentations数据增强(分割版)
前置知识
在使用pytorch进行dataload,transform之前,需要了解一些数据的知识,许多人使用不同的接口因为不熟悉犯了一些错误。在这里对一些常用的OpenCV,PIL,skimage进行了一些总结,以及pytorchvision.transorforms的一些简单使用。
代码语言:javascript复制import cv2
from PIL import Image
from skimage import io, transform, color
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
代码语言:javascript复制img_path = 'data/1803151818-00000065.jpg'
alpha_path = 'data/1803151818-00000065.png'
常用接口
1.1 OpenCV
代码语言:javascript复制# 默认彩图
img_cv2 = cv2.imread(img_path)
# 灰度图
img_cv2_gray = cv2.imread(alpha_path,0)
print(img_cv2.shape)
# (250, 250, 3) (H,W,C)
type(img_cv2)
# numpy.ndarray
1.2 PIL.Image
代码语言:javascript复制# 默认彩图
img_pil = Image.open(img_path)
# 灰度图
img_pil_gray = Image.open(alpha_path).convert('L') # 打开图片并转成灰度图
print(img_pil.size)
# (250, 250)
print(np.array(img_pil).shape) # PIL没有shape属性,需要转成 numpy.ndarray
#(250, 250, 3)
type(img_pil)
# PIL.JpegImagePlugin.JpegImageFile HWC
1.3 skimage1
代码语言:javascript复制# 默认彩图
img_skimage = io.imread(img_path)
# 灰度图
img_skimage_gray = io.imread(alpha_path,-1)
print(img_skimage.shape)
# (250, 250, 3)
type(img_skimage)
# numpy.ndarray
# imageio.core.util.Array
代码语言:javascript复制(800, 600, 3)
numpy.ndarray
1.4 小结
- OpenCV读进来的是numpy数组,是uint8类型,0-255范围,图像形状是(H,W,C),读入的顺序是BGR,这点需要注意
- PIL是有自己的数据结构的,类型是;但是可以转换成numpy数组,转换后的数组为unit8,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
- skimage读取进来的图片是numpy数组,是unit8类型,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
- matplotlib读取进来的图片是numpy数组,是unit8类型,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
名称 | type | 数据类型 | 读入图像格式 | 数据形状 | 能否通过transforms转换 |
---|---|---|---|---|---|
opencv | numpy.ndarray | uint8类型,0-255范围 | BGR | H×W×C | 否 |
PIL | PIL.Image.Image | | RGB | H×W×C | 是 |
skimage | numpy.ndarray | uint8类型,0-255范围 | RGB | H×W×C | 否 |
#cv2
# cv2 BGR-->RGB 两种方法
#img_cv2 = img_cv2[:,:,::-1]
img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
plt.subplot(1,4,1)
plt.title('cv2')
plt.imshow(img_cv2)
#PIL
plt.subplot(1,4,2)
plt.title('PIL')
plt.imshow(img_pil)
#PIL
plt.subplot(1,4,3)
plt.title('skimage')
plt.imshow(img_skimage)
#plt
img = plt.imread(img_path)
plt.subplot(1,4,4)
plt.title('plt')
plt.imshow(img_pil)
#show
plt.show()
2. 相互转换
2.1 opencv <—> pil
代码语言:javascript复制img_cv = cv2.imread(img_path)
img_pil = Image.open(img_path)
img_skimage = io.imread(img_path)
# opencv -> pil
img_pil = Image.fromarray(cv2.cvtColor(img_cv,cv2.COLOR_BGR2RGB))
# pil -> opencv
img_cv = cv2.cvtColor(np.asarray(img_pil),cv2.COLOR_RGB2BGR)
2.2 skimage <—> pil
代码语言:javascript复制# skimage -> pil
img_pil = Image.fromarray(img_skimage)
# pil -> skimage
img_pil = np.array(img_skimage)
2.3 skimage <—> opencv
代码语言:javascript复制# opencv -> skimage
img_skimage = cv2.cvtColor(img_cv,cv2.COLOR_BGR2RGB)
# skimage -> opencv
from skimage import img_as_ubyte
cv_image = img_as_ubyte(img_skimage)
3. transforms, tensor转换
为了方便进行图像数据的操作,pytorch团队提供了一个torchvision.transforms包,我们可以用transforms进行以下操作:
- PIL.Image / numpy.ndarray与Tensor的相互转化;
- 归一化;
- 对PIL.Image进行裁剪、缩放等操作。
注意1: transforms.ToTensor()
可以将 PIL.Image/numpy.ndarray 数据进转化为torch.FloatTensor,并归一化到0, 1.0,但是transforms的其他操作只能对PIL读入的数据操作,所以使用transforms.Compose()
将这些操作组合到一起的如果有其他操作则只能输入PIL数据。
transforms包含多种图像操作的函数,可以单独使用,也可以通过transforms.Compose(function1, function2,……functionN)操作。
注意2:Tensor的形状是C,H,W,而cv2,plt,PIL,skimage形状都是H,W,C
3.1 H×W×C ——> C×H×W
代码语言:javascript复制img_cv2.transpose(2,0,1).shape
# (3,250, 250)
img_skimage.transpose(2,0,1).shape
# (3,250, 250)
代码语言:javascript复制(3, 800, 600)
3.2 toTensor
- PIL.Image / numpy.ndarray --> Tensor: train 数据读取
- Tensor --> PIL.Image / numpy.ndarray: inference 数据输出。
我们可以使用 transforms.ToTensor() 将 PIL.Image/numpy.ndarray 数据进转化为torch.FloatTensor,并归一化到0, 1.0:
- 取值范围为0, 255的PIL.Image,转换成形状为C, H, W,取值范围是0, 1.0的torch.FloatTensor;
- 形状为H, W, C的numpy.ndarray,转换成形状为C, H, W,取值范围是0, 1.0的torch.FloatTensor;
- 而
transforms.ToPILImage
则是将Tensor或numpy.ndarray转化为PIL.Image。如果,我们要将Tensor转化为numpy,只需要使用 .numpy() 即可。
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
img_path = 'data/1803151818-00000065.jpg'
# transforms.ToTensor()
transform1 = transforms.Compose([
transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] and convert [H,W,C] to [C,H,W]
])
img = plt.imread(img_path)
print('plt',img.shape) #(H,W,C)
img = transform1(img)
print(img.shape) #torch.Size([C,H,W])
# 转化为numpy.ndarray并显示
img_arr = img.numpy() * 255 #use np.numpy(): convert Tensor to numpy
img_arr = img_arr.astype('uint8') #convert Float to Int
print(img_arr.shape) #[C,H,W]
img_new = np.transpose(img_arr, (1, 2, 0)) #use np.transpose() convert [C,H,W] to [H,W,C]
plt.imshow(img_new)
plt.show()
代码语言:javascript复制plt (800, 600, 3)
torch.Size([3, 800, 600])
(3, 800, 600)
代码语言:javascript复制img = cv2.imread(img_path)
#img = img[:,:,::-1] ### ValueError???
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
print('plt',img.shape) #(H,W,C)
img = transform1(img)
print(img.shape) #torch.Size([C,H,W])
# 转化为numpy.ndarray并显示
img_arr = img.numpy() * 255 #use np.numpy(): convert Tensor to numpy
img_arr = img_arr.astype('uint8') #convert Float to Int
print(img_arr.shape) #[C,H,W]
img_new = np.transpose(img_arr, (1, 2, 0)) #use np.transpose() convert [C,H,W] to [H,W,C]
plt.imshow(img_new)
plt.show()
代码语言:javascript复制plt (800, 600, 3)
torch.Size([3, 800, 600])
(3, 800, 600)
3.3 Normalize
代码语言:txt复制 c h a n n e l = c h a n n e l − m e a n s t d channel = frac{channel - mean}{std} channel=stdchannel−mean进行规范化。(是对tensor进行归一化,所以需要放在transforms.ToTensor()之后)
代码语言:javascript复制mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
# 这两组值是 ImageNet数据集大样本统计得出的
#归一化
transform2 = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
]
)
3.4 compose
代码语言:javascript复制normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# 先转PIL 再进入Compose 进行数据增强
all_transforms = transforms.Compose([
transforms.Resize(256),
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(), # 对PIL.Image图片进行操作
transforms.ToTensor(),
normalize])
代码语言:javascript复制# 或者ToTensor之后 再转PIL
transform2 = transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.RandomCrop((300,300)),
])
img = Image.open(img_path).convert('RGB')
img2 = transform2(img)
img2.show()
Reference:
数据来源:爱分割 github
https://blog.csdn.net/tsq292978891/article/details/78767326
- Image data types and what they mean ↩︎