一种目标检测任务中图像-标注对增强方法

2022-12-07 10:22:00 浏览数 (2)

其实,本篇应是深度学习常用图像数据增强库albumentations系列教程(三)的,但是鉴于不如现在的题目直观,还是修改了,原来两篇见如下:

深度学习常用图像数据增强库albumentations系列教程(一)

深度学习常用图像数据增强库albumentations系列教程(二)

本篇是在前面两篇基础上,对目标检测任务中常用的包围框标注数据进行增强。

1. 目标检测任务包围框

目标检测任务中在训练之前要对图像中的目标物体进行标注,比如使用labelimg对目标物体的位置和类别进行标注,生成xml文件(数据是pascal_voc格式)。

albumentations支持四种数据格式: pascal_voc,albumentations, coco和yolo,这四种数据格式使用不同的方法表示包围框的位置。

  • • pascal_voc: 使用[x_min, y_min, x_max, y_max]描述包围框。x_min和y_min是包围框左上角的坐标,y_min和y_max是右下角的坐标,如[138, 103, 161, 471]
  • • albumentations: 使用[x_min, y_min, x_max, y_max]表示,和pascal_voc不同的是albumentations用的是归一化的值去描述,即将横纵坐标除以相应的长宽,如[138/640, 103/480, 161/640, 471/480]
  • • coco: 使用[x_min, y_min, width, height]表示包围框,如[138, 103, 23, 368]
  • • yolo: 使用[x_center, y_center, width, height],前面两个参数是规范化后的包围框的中心位置,如[((138 161)/2)/640, ((103 471)/2)/480, 23/640, 368/480]

2. 目标检测任务图像-标注对数据增强功能实现

针对训练样本量少的情况,我们常常会使用数据增强的方法增加样本量,如图像的旋转、平移、缩放、改变亮度等,针对增强后的图像常常还需要标注,标注工作量较大。尽管有些方法在训练的时候会帮你实现这些功能,我个人还是习惯将标注增强直观展示,确定标注增强的合理性。

图像-标注对增强包括如下流程:

  1. 1. 利用单张或者多张图像进行标注,生成xml文件
  2. 2. 定义增强pipeline
  3. 3. 从文件夹中遍历原始的图像文件和xml文件
  4. 4. 通过增强pipeline得到图像标注增强对用于训练

注意:不是所有的变换都支持包围框标注数据增强的,目前(20220921)支持包围框增强的变换

代码语言:javascript复制
import random
import cv2
from matplotlib import pyplot as plt
import xml.etree.ElementTree as ET
import albumentations as A
import os
import time
import glob
from tqdm import trange

BOX_COLOR = (255, 0, 0)  # Red
TEXT_COLOR = (255, 255, 255)  # White
# original pictures size:62, then total size is 62*GENERATED_PICS_SIZE
GENERATED_PICS_SIZE = 600


def visualize_bbox(img, bbox, class_name, color=BOX_COLOR, thickness=2):
    """Visualizes a single bounding box on the image"""
    # x_min, y_min, w, h = bbox
    # x_min, x_max, y_min, y_max = int(x_min), int(x_min   w), int(y_min), int(
    #     y_min   h)
    x_min, y_min, x_max, y_max = bbox
    print(x_min, y_min, x_max, y_max)

    cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)),
                  color=color, thickness=thickness)

    ((text_width, text_height), _) = cv2.getTextSize(class_name,
                                                     cv2.FONT_HERSHEY_SIMPLEX,
                                                     0.35, 1)
    cv2.rectangle(img, (int(x_min), int(y_min) - int(1.3 * text_height)),
                  (int(x_min)   text_width, int(y_min)), BOX_COLOR, -1)
    cv2.putText(
        img,
        text=class_name,
        org=(int(x_min), int(y_min) - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.35,
        color=TEXT_COLOR,
        lineType=cv2.LINE_AA,
    )
    return img


def visualize(image, bboxes, category_ids, category_id_to_name):
    img = image.copy()
    for bbox, category_id in zip(bboxes, category_ids):
        class_name = category_id_to_name[category_id]
        img = visualize_bbox(img, bbox, class_name)
    plt.axis('off')
    plt.imshow(img)
    plt.show()


def saveNewAnnotation(new_xml_path, new_jpg_path, xml_path, bboxes, cur_dir):
    in_file = open(os.path.join(xml_path), encoding='utf-8')
    new_file = in_file
    tree = ET.parse(new_file)
    root = tree.getroot()
    root[0].text = "annotation_out"
    root[1].text = new_jpg_path
    root[2].text = cur_dir   '\annotation_out\'   new_jpg_path

    idx = 0
    for obj in root.iter('object'):
        obj[4][0].text = str(round(bboxes[idx][0]))
        obj[4][1].text = str(round(bboxes[idx][1]))
        obj[4][2].text = str(round(bboxes[idx][2]))
        obj[4][3].text = str(round(bboxes[idx][3]))
        idx  = 1
    tree.write(new_xml_path, 'UTF-8')


def getAnnotation(xml_path):
    '''
    :param xml_path:
    :return: bboxes, category_ids
    '''

    in_file = open(os.path.join(xml_path), encoding='utf-8')
    try:
        tree = ET.parse(in_file)
    except:
        return [], []
    root = tree.getroot()

    bboxes = []
    category_ids = []

    for obj in root.iter('object'):
        cls = obj.find('name').text

        xmlbox = obj.find('bndbox')
        bbox = [int(float(xmlbox.find('xmin').text)),
                int(float(xmlbox.find('ymin').text)),
                int(float(xmlbox.find('xmax').text)),
                int(float(xmlbox.find('ymax').text))]
        bboxes.append(bbox)
        category_ids.append(cls)
    return bboxes, category_ids


def main(cur_dir):
    PICS_PATH = 'annotation_ori'  # 存放图片的文件夹路径
    paths = glob.glob(os.path.join(PICS_PATH, '*.jpg'))
    for i in trange(len(paths)):
        jpg_path = paths[i]
        xml_path = jpg_path.split('.')[0]   ".xml"

        # print(jpg_path.split('.'))
        new_jpg_path_prefix = 'annotation_out\'

        for i in range(GENERATED_PICS_SIZE):
            image = cv2.imread(jpg_path)
            # print(width, ", ", height)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            new_jpg_path = jpg_path.split('.')[0].split('\')[-1]   str(
                i   1).zfill(4)   ".jpg"
            new_xml_path = new_jpg_path_prefix   
                           jpg_path.split('.')[0].split('\')[-1]   str(
                i   1).zfill(4)   ".xml"
            bboxes, category_ids = getAnnotation(xml_path=xml_path)
            if len(bboxes) == 0 & len(category_ids) == 0:
                continue
            category_id_to_name = {}
            for i in range(len(category_ids)):
                category_id_to_name[category_ids[i]] = category_ids[i]
            # 变换操作
            # 水平反转,高斯模糊,gamma变换,亮度变化,
            transform = A.Compose(
                [
                    A.HorizontalFlip(p=0.5),
                    A.Rotate(limit=2, p=0.3),
                    A.ShiftScaleRotate(shift_limit=0.0625,scale_limit=0, rotate_limit=0,p=0.3),
                    A.GaussianBlur(blur_limit=1, p=0.5),
                    A.ColorJitter(brightness=0.05, contrast=0.05,
                                  saturation=0.02,
                                  hue=0.02, always_apply=False, p=1)
                ],
                bbox_params=A.BboxParams(format='pascal_voc',
                                         label_fields=['category_ids']),
            )
            transformed = transform(image=image, bboxes=bboxes,
                                    category_ids=category_ids)
            image = transformed['image']
            bboxes = transformed['bboxes']
            category_ids = transformed['category_ids']
            # print(bboxes)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.imencode('.jpg', image)[1].tofile("annotation_out\"   new_jpg_path)
            # visualize(image, bboxes, category_ids, category_id_to_name)
            saveNewAnnotation(new_xml_path, new_jpg_path, xml_path, bboxes, cur_dir)
        time.sleep(1)

if __name__ == '__main__':
    import os
    cur_dir = os.path.dirname(os.path.abspath(__file__))  # 上级目录
    print(cur_dir)
    main(cur_dir)

通过上述代码,我们会生成大量基于原始图像-标注对的衍生图像-标注对。

建议先生成少量图像-标注对,然后使用labelimg查看下生成的图像-标注对是否正确,确定无问题后,再生成大量图片-标注对

0 人点赞