最近,pytorch官网发布了一个消息,TorchVision正不断地增加新的接口:
- • 不仅将变换的API用在图像分类上,还用在物体识别、实例分割、语义分割及视频分类领域。
- • 可以从TorchVision的API中直接使用SoTA数据增强方法,如MixUp、CutMix,Large Scale Jitter和SimpleCopyPaste
新的接口目前是测试阶段
前面我写了篇文章《一种目标检测任务中图像-标注对增强方法》,可以去看一下,和TorchVision中的新增功能有些类似。
现有变换的限制
目前的TorchVision V1仅仅支持单张图片,仅能用于分类任务:
代码语言:javascript复制from torchvision import transforms
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs = trans(imgs)
上述方法不支持需要使用标签的物体检测、分割或分类变换(如MixUp & CutMix)。这一限制对任何非分类的计算机视觉任务都不利,因为人们无法使用变换API来进行必要的增强。从历史上看,这使得使用TorchVision来训练高精确度的模型变得很困难,因此我们的模型Zoo比SoTA滞后了几个点。
为了规避这一限制,TorchVision在其参考脚本中提供了自定义的实现方式,展示了如何在每个任务中进行增强处理。尽管这种做法使我们能够训练出高精度的分类、物体检测和分割模型,但这是一种笨拙的方法,使这些变换无法从TorchVision二进制中导入。
新的变换API
Transforms V2 API支持视频、边界框、标签和分割掩码,这意味着它为许多计算机视觉任务提供了本地支持。新的解决方案是一种直接的替换,如
代码语言:javascript复制from torchvision.prototype import transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)
新的转换类可以接收任意数量的输入,而不强制执行特定的顺序或结构。
代码语言:javascript复制# Already supported:
trans(imgs) # Image Classification
trans(videos) # Video Tasks
trans(imgs_or_videos, labels) # MixUp/CutMix-style Transforms
trans(imgs, bboxes, labels) # Object Detection
trans(imgs, bboxes, masks, labels) # Instance Segmentation
trans(imgs, masks) # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels}) # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints) # Keypoint Detection
trans(stereo_images, disparities, masks) # Depth Perception
trans(image1, image2, optical_flows, masks) # Optical Flow
变换类确保对所有输入应用相同的随机变换,以确保结果一致。
功能性API已经更新,以支持所有输入的所有必要的信号处理内核(调整大小、裁剪、仿生变换、填充等)。
代码语言:javascript复制from torchvision.prototype.transforms import functional as F
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, resize=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, resize=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, resize=[224, 224], interpolation=BILINEAR)
# Video kernel
F.resize_video(video, resize=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, resize=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, resize=[224, 224], spatial_size=[256, 256])
该API使用张量子类来包装输入,附加有用的元数据,并分配给正确的内核。一旦Datasets V2的工作完成,即利用TorchData的数据管道,手动包装输入就没有必要了。目前,用户可以通过以下方式手动包装输入。
代码语言:javascript复制from torchvision.prototype import features
imgs = features.Image(images, color_space=ColorSpace.RGB)
vids = features.Video(videos, color_space=ColorSpace.RGB)
masks = features.Mask(target["masks"])
bboxes = features.BoundingBox(target["boxes"], format=BoundingBoxFormat.XYXY, spatial_size=imgs.spatial_size)
labels = features.Label(target["labels"], categories=["dog", "cat"])
除了新的API,我们现在还为SoTA研究中使用的几种数据增强提供了可导入的实现,如MixUp、CutMix、Large Scale Jitter、SimpleCopyPaste、AutoAugmentation方法和几种新的几何、颜色和类型转换。
该API继续支持图像的PIL和张量后端,单一或批量输入,并保持功能API的JIT脚本性。它允许推迟图像从uint8到float的转换,这可以带来性能上的好处。它目前在TorchVision的原型区可用,可以从夜间构建中导入。新的API已经过验证,达到了与以前的实现相同的精度。
目前的限制
尽管功能API(内核)仍然是可编写JIT脚本和fully-BC的,但转换类虽然提供了相同的接口,却不能编写脚本。这是因为它们使用了张量子类,并接收任意数量的输入,而这是JIT所不支持的。我们目前正在努力减少新API的调度开销,并提高现有内核的速度。
一个端到端的例子
下面是一个使用以下图像的新API的例子。它同时适用于PIL图像和Tensors。
代码语言:javascript复制import PIL
from torchvision import io, utils
from torchvision.prototype import features, transforms as T
from torchvision.prototype.transforms import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = features.Image(io.read_image(path), color_space=features.ColorSpace.RGB)
# img = PIL.Image.open(path)
bboxes = features.BoundingBox(
[[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
[148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
[422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
[435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
[469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
[452, 39, 463, 63], [424, 38, 429, 50]],
format=features.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
)
labels = features.Label([59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74])
# Defining and applying Transforms V2
trans = T.Compose(
[
T.ColorJitter(contrast=0.5),
T.RandomRotation(30),
T.CenterCrop(480),
]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()