本文代码
系列前置文章:
pytorch DataLoader(1): opencv,skimage,PIL,Tensor转换以及transforms
pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口
翻译文章: 将Albumentations用于语义分割任务
这篇文章主要是讲怎么利用albumentations来做数据增强的,torchvision的transforms模块本身就包含了很多的数据增强功能,在这里讲解albumentations的原因是albumentations的速度比其他一些数据增强的方法普遍更快一点(主要卖点速度快),功能更齐全。
详情见官方文档·英文,可以查看github
Albumentations的主要特点:
- 这个库是图片处理的library,处理的图片是在HWC格式下,也就是Height,Width,Channale;
- 在相同的对图像的处理下,使用这个库函数的速度更快;
- 基于numpy和OpenCV,这个库从中取其精华;
- 相比torch自带的,这个库函数有更多的对图像的预处理的办法
- 对Pytorch很友好,而且这个库函数是kaggle master制作的。
一些前置知识可以参考第一篇文章pytorch DataLoader(1): opencv,skimage,PIL,Tensor转换以及transforms,这篇文章主要讲了一些opencv,skimage,PIL的格式,读取方式,相互转换等,有助于帮助大家理解本文本文的一些操作等。
pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口
NOTE: 时间紧可以直接看第二点数据增强部分
代码语言:javascript复制import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
import albumentations as A
# from albumentations.pytorch.transforms import ToTensorV2,ToTensor
from albumentations.pytorch import ToTensorV2,ToTensor
1. 读取文件路径
从保存image路径的txt文件中读取path,并保存到list中。
代码语言:javascript复制tra_img_name_list = []
fg_list_name = 'image.txt'
with open(fg_list_name, 'r') as reader:
path_list = reader.readlines()
for line in path_list:
line = line.replace('n', '').replace('\', '/')
tra_img_name_list.append(line)
tra_lbl_name_list = []
for img_path in tra_img_name_list: # 获取所有mask文件地址
img_name = img_path.split(os.sep)[-1]
aaa = img_name.split(".")[0]
tra_lbl_name_list.append('data/' aaa '.png')
2. 数据增强
其他代码跟pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口中基本相同,这篇文章主要是修改数据增强这块,使用Albumentations来做数据增强。
对image,alpha图片进行缩放,裁剪,转tensor操作,详情见注释。
使用Albumentations来做数据增强,可以直接使用opencv读取图像,记得BGR转RGB就行。PIL读取的图像也可以,只是需要转成numpy格式的。https://albumentations.ai/docs/getting_started/image_augmentation/
from PIL import Image import numpy as np Read an image with Pillow and convert it to a NumPy array pillow_image = Image.open(“image.jpg”) image = np.array(pillow_image)
Colab notebook示例
GitHub notebook示例
代码语言:javascript复制'''
salobj_dataset = SalObjDataset(
img_name_list=tra_img_name_list,
lbl_name_list=tra_lbl_name_list,
transform=transforms.Compose([
RescaleT(320),
RandomCrop(288),
ToTensor()]))
'''
class transformA(object):
def __init__(self, scale_size=320, output_size=288):
assert isinstance(scale_size, (int, tuple))
assert isinstance(output_size, (int, tuple))
self.scale_size = scale_size
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2 # tuple
self.output_size = output_size
def __call__(self, sample):
imidx, image, label = sample['imidx'], sample['image'],sample['label']
# crop size
h, w = image.shape[:2]
new_h, new_w = self.output_size
# aug
# aug = A.Compose([
# A.HorizontalFlip(p=0.5),
# A.VerticalFlip(p=0.5),
# A.Resize(height=self.scale_size, width=self.scale_size, interpolation=3, always_apply=False, p=1),
# A.RandomCrop(height=new_h, width=new_w, p=1),
# A.Normalize(
# mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225],
# ),
# ])
# 为了显示效果好 先不做normalize
aug = A.Compose([
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.Resize(height=self.scale_size, width=self.scale_size, interpolation=3, always_apply=False, p=1),
A.RandomCrop(height=new_h, width=new_w, p=1),
])
augmented = aug(image=image, mask=label)
image = augmented['image']
label = augmented['mask']
#aug2 = A.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)
aug2 = A.Compose([
A.OneOf([
A.HueSaturationValue(hue_shift_limit=0.1, sat_shift_limit= 0.3,
val_shift_limit=0.3, p=0.9),
A.RandomBrightnessContrast(brightness_limit=0.4,
contrast_limit=0.3, p=0.9),
],p=0.9),
A.ToGray(p=0.05),
A.OneOf([
A.IAAAdditiveGaussianNoise(), # 将高斯噪声添加到输入图像
A.GaussNoise(), # 将高斯噪声应用于输入图像。
], p=0.2),
])
augmented2 = aug2(image=image)
image = augmented2['image']
image = transforms.ToTensor()(image)
label = transforms.ToTensor()(label)
return {'imidx':torch.from_numpy(imidx),'image':image, 'label':label}
使用A.Resize()
来进行缩放/将输入图像调整为给定的高度和宽度的时候,默认参数是A.Resize(height, width, interpolation=1, always_apply=False, p=1)
interpolation的选择可以参考Resizing transforms
因为这里是要缩小,为了避免出现波纹现象,所以最好使用区域插值**cv2.INTER_AREA
**,即interpolation=3,其他参数不变。
根据官方文档应该使用ToTensorV2()
但是结果图上显示图片出现了很多红点,其次不管是ToTensorV2()
orToTensor()
,均出现了通道错误的现象,暂时还没找出问题所在。所以最后还是使用了transforms.ToTensor()
# trans = A.Compose([
# A.HorizontalFlip(p=0.5),
# A.VerticalFlip(p=0.5),
# A.OneOf([
# A.IAAAdditiveGaussianNoise(), # 将高斯噪声添加到输入图像
# A.GaussNoise(), # 将高斯噪声应用于输入图像。
# ], p=0.2), # 应用选定变换的概率
# A.OneOf([
# A.MotionBlur(p=0.2), # 使用随机大小的内核将运动模糊应用于输入图像。
# A.MedianBlur(blur_limit=3, p=0.1), # 中值滤波
# A.Blur(blur_limit=3, p=0.1), # 使用随机大小的内核模糊输入图像。
# ], p=0.2),
# A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
# # 随机应用仿射变换:平移,缩放和旋转输入
# A.RandomBrightnessContrast(p=0.2), # 随机明亮对比度
# ])
3. 制作datasets
代码语言:javascript复制class SalObjDataset(Dataset):
def __init__(self,img_name_list,lbl_name_list,transform=None):
self.image_name_list = img_name_list
self.label_name_list = lbl_name_list
self.transform = transform
def __len__(self):
return len(self.image_name_list)
def __getitem__(self, idx):
image = cv2.imread(self.image_name_list[idx])
imname = self.image_name_list[idx]
imidx = np.array([idx]) # [idx]
if (len(self.label_name_list) == 0): # inference: label_name_list = []/None
label_3 = np.zeros(image.shape)
else: # train
label_3 = cv2.imread(self.label_name_list[idx]) ###
# 正确读取单通道label
label = np.zeros(label_3.shape[0:2]) # copy zeros shape | just get HW
if (len(label_3.shape) == 3):
label = label_3[:, :, 0] # H*W 1 channel or 到最后再转
elif (len(label_3.shape) == 2):
label = label_3
# make sure label(...,...,1)
if (len(image.shape) == 3) and (len(label.shape) == 2):
label = label[:,:,np.newaxis]
### image BGR2RGB
#image = image[:, :, ::-1] ###
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # opencv的通道转换不要忘记
elif (len(image.shape) == 2) and (len(label.shape == 2)): #
image = image[:,:, np.newaxis]
label = label[:, :, np.newaxis]
sample = {'imidx': imidx, 'image': image, 'label': label}
if self.transform:
sample = self.transform(sample)
return sample
4. DataLoader
代码语言:javascript复制# salobj_dataset = SalObjDataset(
# img_name_list=tra_img_name_list,
# lbl_name_list=tra_lbl_name_list,
# transform=transforms.Compose([
# RescaleT(320),
# RandomCrop(288),
# ToTensor()]))
salobj_dataset = SalObjDataset(
img_name_list=tra_img_name_list,
lbl_name_list=tra_lbl_name_list,
transform=transformA(scale_size=320, output_size=288))
salobj_dataloader = DataLoader(salobj_dataset,
batch_size=2,
shuffle=True,
num_workers=4,
drop_last=True)
以上主要是以opencv为例子来读取数据并加载的。
关于skimage怎么读取数据加载,数据增强等,可以直接参考U2Net的代码1。上面的代码就是改写自u2net训练和dataloader的代码。
后续PIL的有机会再补充,搞清楚了各种接口之间的关系和相互转换,其本质都是一样的。
5. 测试并显示
代码语言:javascript复制# 辅助函数,用于展示一个 batch 的数据
def show_landmarks_batch(sample_batched):
"""Show image with landmarks for a batch of samples."""
images_batch, landmarks_batch =
sample_batched['image'], sample_batched['label']
batch_size = len(images_batch)
im_size = images_batch.size(2)
grid_border_size = 2
grid = utils.make_grid(images_batch)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
for i in range(batch_size):
plt.scatter(landmarks_batch[i, :, 0].numpy() i * im_size (i 1) * grid_border_size,
landmarks_batch[i, :, 1].numpy() grid_border_size,
s=10, marker='.', c='r')
plt.title('Batch from salobj_dataloader')
for i_batch, sample_batched in enumerate(salobj_dataloader):
print(i_batch, sample_batched['image'].size(),
sample_batched['label'].size())
# observe 4th batch and stop.
if i_batch == 1:
plt.figure()
show_landmarks_batch(sample_batched)
plt.axis('off')
plt.ioff()
plt.show()
break
代码语言:javascript复制0 torch.Size([2, 3, 288, 288]) torch.Size([2, 1, 288, 288])
1 torch.Size([2, 3, 288, 288]) torch.Size([2, 1, 288, 288])
代码语言:javascript复制
Reference:
1 https://cloud.tencent.com/developer/article/1660972
2(https://albumentations.ai/docs/getting_started/mask_augmentation/)
3 http://aix.51cto.com/blog/70799.html
4 数据来源:爱分割 github
5(https://blog.csdn.net/qq_27039891/article/details/100795846)
6 https://zhuanlan.zhihu.com/p/371761014
7 https://zhuanlan.zhihu.com/p/107399127
8 https://www.aiuai.cn/aifarm1380.html