pytorch DataLoader(3)_albumentations数据增强(分割版)

2021-07-07 18:19:43 浏览数 (1)

本文代码

系列前置文章:

pytorch DataLoader(1): opencv,skimage,PIL,Tensor转换以及transforms

pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口

翻译文章: 将Albumentations用于语义分割任务


这篇文章主要是讲怎么利用albumentations来做数据增强的,torchvision的transforms模块本身就包含了很多的数据增强功能,在这里讲解albumentations的原因是albumentations的速度比其他一些数据增强的方法普遍更快一点(主要卖点速度快),功能更齐全。

详情见官方文档·英文,可以查看github

Albumentations的主要特点:

  • 这个库是图片处理的library,处理的图片是在HWC格式下,也就是Height,Width,Channale;
  • 在相同的对图像的处理下,使用这个库函数的速度更快
  • 基于numpy和OpenCV,这个库从中取其精华;
  • 相比torch自带的,这个库函数有更多的对图像的预处理的办法
  • 对Pytorch很友好,而且这个库函数是kaggle master制作的。

一些前置知识可以参考第一篇文章pytorch DataLoader(1): opencv,skimage,PIL,Tensor转换以及transforms,这篇文章主要讲了一些opencv,skimage,PIL的格式,读取方式,相互转换等,有助于帮助大家理解本文本文的一些操作等。

pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口

NOTE: 时间紧可以直接看第二点数据增强部分

代码语言:javascript复制
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
import albumentations as A
# from albumentations.pytorch.transforms import ToTensorV2,ToTensor
from albumentations.pytorch import ToTensorV2,ToTensor

1. 读取文件路径

从保存image路径的txt文件中读取path,并保存到list中。

代码语言:javascript复制
tra_img_name_list = []
fg_list_name = 'image.txt'
with open(fg_list_name, 'r') as reader:
    path_list = reader.readlines()
    for line in path_list:
        line = line.replace('n', '').replace('\', '/')
        tra_img_name_list.append(line)
        
tra_lbl_name_list = []
for img_path in tra_img_name_list: # 获取所有mask文件地址
    img_name = img_path.split(os.sep)[-1]
    aaa = img_name.split(".")[0]
    tra_lbl_name_list.append('data/'   aaa   '.png')

2. 数据增强

其他代码跟pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口中基本相同,这篇文章主要是修改数据增强这块,使用Albumentations来做数据增强。

对image,alpha图片进行缩放,裁剪,转tensor操作,详情见注释。

使用Albumentations来做数据增强,可以直接使用opencv读取图像,记得BGR转RGB就行。PIL读取的图像也可以,只是需要转成numpy格式的。https://albumentations.ai/docs/getting_started/image_augmentation/

from PIL import Image import numpy as np Read an image with Pillow and convert it to a NumPy array pillow_image = Image.open(“image.jpg”) image = np.array(pillow_image)

Colab notebook示例

GitHub notebook示例

代码语言:javascript复制
'''
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transforms.Compose([
                            RescaleT(320),
                            RandomCrop(288),
                            ToTensor()]))
'''
class transformA(object):
    def __init__(self, scale_size=320, output_size=288):
        assert isinstance(scale_size, (int, tuple))
        assert isinstance(output_size, (int, tuple))
        
        self.scale_size = scale_size
                
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2 # tuple
            self.output_size = output_size
            
    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'],sample['label']
        
        # crop size
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        
        # aug
#         aug = A.Compose([
#             A.HorizontalFlip(p=0.5),
#             A.VerticalFlip(p=0.5),
#             A.Resize(height=self.scale_size, width=self.scale_size, interpolation=3, always_apply=False, p=1),
#             A.RandomCrop(height=new_h, width=new_w, p=1),
#             A.Normalize(
#                 mean=[0.485, 0.456, 0.406],
#                 std=[0.229, 0.224, 0.225],
#                 ),
#             ])
        # 为了显示效果好 先不做normalize
        aug = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Resize(height=self.scale_size, width=self.scale_size, interpolation=3, always_apply=False, p=1),
            A.RandomCrop(height=new_h, width=new_w, p=1),
            ])
        
        augmented = aug(image=image, mask=label)

        image = augmented['image']
        label = augmented['mask']
        
        #aug2 = A.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)
        aug2 = A.Compose([
            A.OneOf([
                A.HueSaturationValue(hue_shift_limit=0.1, sat_shift_limit= 0.3,
                                     val_shift_limit=0.3, p=0.9),
                A.RandomBrightnessContrast(brightness_limit=0.4,
                                           contrast_limit=0.3, p=0.9),
            ],p=0.9),
            A.ToGray(p=0.05),
            A.OneOf([
                A.IAAAdditiveGaussianNoise(),   # 将高斯噪声添加到输入图像
                A.GaussNoise(),    # 将高斯噪声应用于输入图像。
                ], p=0.2),
                ])
        augmented2 = aug2(image=image)
        image = augmented2['image']

        image = transforms.ToTensor()(image)
        label = transforms.ToTensor()(label)
        
        return {'imidx':torch.from_numpy(imidx),'image':image, 'label':label}               

使用A.Resize()来进行缩放/将输入图像调整为给定的高度和宽度的时候,默认参数是A.Resize(height, width, interpolation=1, always_apply=False, p=1)

interpolation的选择可以参考Resizing transforms

因为这里是要缩小,为了避免出现波纹现象,所以最好使用区域插值**cv2.INTER_AREA**,即interpolation=3,其他参数不变。

根据官方文档应该使用ToTensorV2()但是结果图上显示图片出现了很多红点,其次不管是ToTensorV2()orToTensor(),均出现了通道错误的现象,暂时还没找出问题所在。所以最后还是使用了transforms.ToTensor()

代码语言:javascript复制
#     trans = A.Compose([
#         A.HorizontalFlip(p=0.5),
#         A.VerticalFlip(p=0.5),
#         A.OneOf([
#             A.IAAAdditiveGaussianNoise(),   # 将高斯噪声添加到输入图像
#             A.GaussNoise(),    # 将高斯噪声应用于输入图像。
#         ], p=0.2),   # 应用选定变换的概率
#         A.OneOf([
#             A.MotionBlur(p=0.2),   # 使用随机大小的内核将运动模糊应用于输入图像。
#             A.MedianBlur(blur_limit=3, p=0.1),    # 中值滤波
#             A.Blur(blur_limit=3, p=0.1),   # 使用随机大小的内核模糊输入图像。
#         ], p=0.2),
#         A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
#         # 随机应用仿射变换:平移,缩放和旋转输入
#         A.RandomBrightnessContrast(p=0.2),   # 随机明亮对比度
#     ])

3. 制作datasets

代码语言:javascript复制
class SalObjDataset(Dataset):
    def __init__(self,img_name_list,lbl_name_list,transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        image = cv2.imread(self.image_name_list[idx])
        imname = self.image_name_list[idx]
        imidx = np.array([idx])  # [idx]

        if (len(self.label_name_list) == 0): # inference: label_name_list = []/None
            label_3 = np.zeros(image.shape)
        else: # train
            label_3 = cv2.imread(self.label_name_list[idx])  ###

        # 正确读取单通道label
        label = np.zeros(label_3.shape[0:2]) # copy zeros shape | just get HW
        if (len(label_3.shape) == 3):
            label = label_3[:, :, 0]  # H*W 1 channel or 到最后再转
        elif (len(label_3.shape) == 2):
            label = label_3

        # make sure label(...,...,1)
        if (len(image.shape) == 3) and (len(label.shape) == 2):
            label = label[:,:,np.newaxis]
            ### image BGR2RGB
            #image = image[:, :, ::-1] ###
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # opencv的通道转换不要忘记
        elif (len(image.shape) == 2) and (len(label.shape == 2)): #
            image = image[:,:, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample

4. DataLoader

代码语言:javascript复制
# salobj_dataset = SalObjDataset(
#     img_name_list=tra_img_name_list,
#     lbl_name_list=tra_lbl_name_list,
#     transform=transforms.Compose([
#                             RescaleT(320),
#                             RandomCrop(288),
#                             ToTensor()]))
salobj_dataset = SalObjDataset(
    img_name_list=tra_img_name_list,
    lbl_name_list=tra_lbl_name_list,
    transform=transformA(scale_size=320, output_size=288))

salobj_dataloader = DataLoader(salobj_dataset,
                               batch_size=2,
                               shuffle=True,
                               num_workers=4,
                               drop_last=True) 

以上主要是以opencv为例子来读取数据并加载的。

关于skimage怎么读取数据加载,数据增强等,可以直接参考U2Net的代码1。上面的代码就是改写自u2net训练和dataloader的代码。

后续PIL的有机会再补充,搞清楚了各种接口之间的关系和相互转换,其本质都是一样的。

5. 测试并显示

代码语言:javascript复制
# 辅助函数,用于展示一个 batch 的数据
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = 
            sample_batched['image'], sample_batched['label']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy()   i * im_size   (i   1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy()   grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from salobj_dataloader')

for i_batch, sample_batched in enumerate(salobj_dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['label'].size())

    # observe 4th batch and stop.
    if i_batch == 1:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
代码语言:javascript复制
0 torch.Size([2, 3, 288, 288]) torch.Size([2, 1, 288, 288])
1 torch.Size([2, 3, 288, 288]) torch.Size([2, 1, 288, 288])
代码语言:javascript复制

Reference:

1 https://cloud.tencent.com/developer/article/1660972

2(https://albumentations.ai/docs/getting_started/mask_augmentation/)

3 http://aix.51cto.com/blog/70799.html

4 数据来源:爱分割 github

5(https://blog.csdn.net/qq_27039891/article/details/100795846)

6 https://zhuanlan.zhihu.com/p/371761014

7 https://zhuanlan.zhihu.com/p/107399127

8 https://www.aiuai.cn/aifarm1380.html

0 人点赞