【目标检测】YOLOv5推理加速实验:图片批量检测

2023-04-12 09:23:45 浏览数 (1)

前言

上篇博文探究了一下多进程是否能够对YOLOv5模型推理起到加速作用,本篇主要来研究一下如果将图片批量送入网络中进行检测,是否能对网络的推理起到加速作用。

YOLOv5批量检测源码解析

YOLOv5在训练过程中是可以进行分批次训练(batch_size>1),然而在默认的推理过程中,却没有预留batch_size的相关接口,仍然只是单张图一张张进行检测推理。难道批检测推理的速度不会更快吗?下面通过实验来探究。

本文所使用的版本为官方仓库的最新版本(v7.0)。

默认单图推理

首先来看看官方源码默认的推理逻辑,在detect.py文件中,数据集通过LoadImages实例化一个类。

代码语言:javascript复制
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)

LoadImages:

代码语言:javascript复制
class LoadImages:
    # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
            path = Path(path).read_text().rsplit()
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if '*' in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f'{p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images   videos
        self.nf = ni   nv  # number of files
        self.video_flag = [False] * ni   [True] * nv
        self.mode = 'image'
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' 
                            f'Supported formats are:nimages: {IMG_FORMATS}nvideos: {VID_FORMATS}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            for _ in range(self.vid_stride):
                self.cap.grab()
            ret_val, im0 = self.cap.retrieve()
            while not ret_val:
                self.count  = 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            self.frame  = 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f'video {self.count   1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        else:
            # Read image
            self.count  = 1
            im0 = cv2.imread(path)  # BGR
            assert im0 is not None, f'Image Not Found {path}'
            s = f'image {self.count}/{self.nf} {path}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s

该类实例化之后,首先是根据传入的path类型来判断是单张图片还是文件夹。

然后在循环过程中,执行__next__方法,此时开始读取文件,并进行Inference、NMS等后续操作。

代码语言:javascript复制
for path, im, im0s, vid_cap, s in dataset:
    with dt[0]:
        im = torch.from_numpy(im).to(model.device)
        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim

    # Inference
    with dt[1]:
        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
        pred = model(im, augment=augment, visualize=visualize)

    # NMS
    with dt[2]:
        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

因此,不难发现,原始的detect.py只使用了单图进行推理。

多图推理构建

基本没见到有人做过多图推理的改造探索,在官方仓库的某issue中,找到了作者提供的一种调用思路,稍作改造,代码如下:

代码语言:javascript复制
import os
import cv2
import torch
import time

s_t = time.time()
# Model
model = torch.hub.load('D:/Desktop/yolov5-master', 'custom', 'yolov5s.pt', source='local')
img_list = []
dir_path = "data/images"
for i in os.listdir(dir_path):
    img_list.append(cv2.imread(dir_path   '/'   i)[..., ::-1])

# Inference
results = model(img_list, size=1280)  # batch of images
# Results
results.print()
results.save()
print("Cost Time:", time.time() - s_t)

该方法主要调用的是hubconf.py里面的内容,上面是直接将整个文件夹中的所有图片变成一个batch,输送到网络中进行检测。

实际进行检测的代码块在yolo.py文件中的_forward_once方法。

代码语言:javascript复制
class BaseModel(nn.Module):
    def _forward_once(self, x, profile=False, visualize=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x

这里的x就是输入的Tensor,m是模型的每一层结构,这里不断将输入循环到下一层,实现了网络的批量推理。

速度比较

下面使用RTX4090单卡进行速度测试,数据集选用VisDrone的部分数据,模型选择YOLOv5s:

测试结果如下表所示:

图片数量

直接检测花费时间(s)

批量检测花费时间(s)

100

3.767014265060425

3.9564051628112793

200

5.948423385620117

6.09602165222168

注:我这里的批量检测是直接将所有的图片变成一个batch,200张图片之后,显存基本被占满,更多图像就没有进行测试。

从结果可见,批量检测并没有预期的速度提升,反而比直接单张检测更慢。估计这也是为什么官方不在detect中预留多个batch检测接口的原因。

0 人点赞