YOLO-World Source Code Analysis (Part 5)

2024-03-09 08:50:41

.\YOLO-World\yolo_world\datasets\transformers\mm_transforms.py

# Import required libraries
import json
import random
from typing import Tuple

import numpy as np
from mmyolo.registry import TRANSFORMS

# Register RandomLoadText with the TRANSFORMS registry
@TRANSFORMS.register_module()
class RandomLoadText:

    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 num_neg_samples: Tuple[int, int] = (80, 80),
                 max_num_samples: int = 80,
                 padding_to_max: bool = False,
                 padding_value: str = '') -> None:
        # Initialize the attributes of RandomLoadText
        self.prompt_format = prompt_format
        self.num_neg_samples = num_neg_samples
        self.max_num_samples = max_num_samples
        self.padding_to_max = padding_to_max
        self.padding_value = padding_value
        # If text_path is given, load the class texts from that JSON file
        if text_path is not None:
            with open(text_path, 'r') as f:
                self.class_texts = json.load(f)

# Register LoadText with the TRANSFORMS registry
@TRANSFORMS.register_module()
class LoadText:

    def __init__(self,
                 text_path: str = None,
                 prompt_format: str = '{}',
                 multi_prompt_flag: str = '/') -> None:
        # Initialize the attributes of LoadText
        self.prompt_format = prompt_format
        self.multi_prompt_flag = multi_prompt_flag
        # If text_path is given, load the class texts from that JSON file
        if text_path is not None:
            with open(text_path, 'r') as f:
                self.class_texts = json.load(f)

    # __call__ processes the results dict from the data pipeline
    def __call__(self, results: dict) -> dict:
        # Texts must come either from the results dict or from self.class_texts
        assert 'texts' in results or hasattr(self, 'class_texts'), (
            'No texts found in results.')
        # Prefer 'texts' from the results dict; fall back to self.class_texts
        class_texts = results.get(
            'texts',
            getattr(self, 'class_texts', None))

        texts = []
        # Take the first caption of each class and apply the prompt template
        for idx, cls_caps in enumerate(class_texts):
            assert len(cls_caps) > 0
            sel_cls_cap = cls_caps[0]
            sel_cls_cap = self.prompt_format.format(sel_cls_cap)
            texts.append(sel_cls_cap)

        # Store the formatted texts back into the results dict
        results['texts'] = texts

        return results
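
As a quick sanity check, here is a minimal usage sketch of LoadText (the class names and prompt below are made up for illustration): each per-class caption list collapses to its first caption, formatted with prompt_format.

load_text = LoadText(prompt_format='a photo of {}')
results = {'texts': [['person'], ['bicycle', 'bike'], ['car']]}
results = load_text(results)
print(results['texts'])
# -> ['a photo of person', 'a photo of bicycle', 'a photo of car']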

.\YOLO-World\yolo_world\datasets\transformers\__init__.py

# Copyright (c) Tencent Inc. All rights reserved.
# Import RandomLoadText and LoadText from the local mm_transforms module
from .mm_transforms import RandomLoadText, LoadText
# Import the multi-modal mixed-image transforms from the local
# mm_mix_img_transforms module
from .mm_mix_img_transforms import (MultiModalMosaic, MultiModalMosaic9,
                                    YOLOv5MultiModalMixUp,
                                    YOLOXMultiModalMixUp)

# Names exported by this package
__all__ = ['RandomLoadText', 'LoadText', 'MultiModalMosaic',
           'MultiModalMosaic9', 'YOLOv5MultiModalMixUp',
           'YOLOXMultiModalMixUp']

.\YOLO-World\yolo_world\datasets\utils.py

# Import required libraries and modules
from typing import Sequence
import torch
from mmengine.dataset import COLLATE_FUNCTIONS

# Register the custom collate function
@COLLATE_FUNCTIONS.register_module()
def yolow_collate(data_batch: Sequence,
                  use_ms_training: bool = False) -> dict:
    """Rewrite collate_fn to get faster training speed.

    Args:
       data_batch (Sequence): Batch of data.
       use_ms_training (bool): Whether to use multi-scale training.
    """
    # Buffers for images, box/label rows and masks
    batch_imgs = []
    batch_bboxes_labels = []
    batch_masks = []

    # Iterate over the samples in the batch
    for i in range(len(data_batch)):
        datasamples = data_batch[i]['data_samples']
        inputs = data_batch[i]['inputs']
        batch_imgs.append(inputs)

        # Fetch the ground-truth bounding boxes and labels
        gt_bboxes = datasamples.gt_instances.bboxes.tensor
        gt_labels = datasamples.gt_instances.labels

        # If the sample carries masks, convert them to tensors and collect them
        if 'masks' in datasamples.gt_instances:
            masks = datasamples.gt_instances.masks.to_tensor(
                dtype=torch.bool, device=gt_bboxes.device)
            batch_masks.append(masks)

        # Prepend the batch index, then concatenate [batch_idx, label, bbox]
        batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
        bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes),
                                  dim=1)
        batch_bboxes_labels.append(bboxes_labels)

    # Assemble the collated results dict
    collated_results = {
        'data_samples': {
            'bboxes_labels': torch.cat(batch_bboxes_labels, 0)
        }
    }

    # Add the masks to the results if any were collected
    if len(batch_masks) > 0:
        collated_results['data_samples']['masks'] = torch.cat(batch_masks, 0)

    # For multi-scale training keep the images as a list; otherwise stack them
    if use_ms_training:
        collated_results['inputs'] = batch_imgs
    else:
        collated_results['inputs'] = torch.stack(batch_imgs, 0)

    # Forward the text prompts to the collated results if present
    if hasattr(data_batch[0]['data_samples'], 'texts'):
        batch_texts = [meta['data_samples'].texts for meta in data_batch]
        collated_results['data_samples']['texts'] = batch_texts
    # Check whether the data samples carry an 'is_detection' flag
    if hasattr(data_batch[0]['data_samples'], 'is_detection'):
        # Collect the per-sample 'is_detection' values
        batch_detection = [meta['data_samples'].is_detection
                           for meta in data_batch]
        # Convert them to a tensor and store them in collated_results
        collated_results['data_samples']['is_detection'] = torch.tensor(
            batch_detection)

    # Return the collated results
    return collated_results
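
The key detail of yolow_collate is the layout of bboxes_labels: every row is [batch_idx, class_label, x1, y1, x2, y2], so downstream losses can recover per-image instances without padding. A rough sketch of the output for a two-image batch (shapes are illustrative, assuming 640x640 inputs with 3 and 5 ground-truth boxes):

# collated_results = {
#     'inputs': Tensor of shape (2, 3, 640, 640),   # stacked images
#     'data_samples': {
#         'bboxes_labels': Tensor of shape (8, 6),  # 3 + 5 rows
#         # each row: [batch_idx, label, x1, y1, x2, y2]
#         'texts': [texts_of_image_0, texts_of_image_1],  # if present
#     }
# }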

.\YOLO-World\yolo_world\datasets\yolov5_lvis.py

# Import required modules
from mmdet.datasets import LVISV1Dataset

# Import the custom dataset base class
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS

# Register the YOLOv5 LVIS dataset, which inherits from both
# BatchShapePolicyDataset and LVISV1Dataset
@DATASETS.register_module()
class YOLOv5LVISV1Dataset(BatchShapePolicyDataset, LVISV1Dataset):
    """Dataset for YOLOv5 LVIS Dataset.

    We only add `BatchShapePolicy` function compared with LVISV1Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    # Nothing extra beyond the combined base classes
    pass

.\YOLO-World\yolo_world\datasets\yolov5_mixed_grounding.py

# Import required modules
import os.path as osp
from typing import List, Union

# Import helpers from mmengine / mmdet / mmyolo
from mmengine.fileio import get_local_path, join_path
from mmengine.utils import is_abs
from mmdet.datasets.coco import CocoDataset
from mmyolo.registry import DATASETS
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset

# Register YOLOv5MixedGroundingDataset with the DATASETS registry
@DATASETS.register_module()
class YOLOv5MixedGroundingDataset(BatchShapePolicyDataset, CocoDataset):
    """Mixed grounding dataset."""

    # Dataset meta information
    METAINFO = {
        'classes': ('object',),
        'palette': [(220, 20, 60)]}

    # Load the list of data samples
    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``

        Returns:
            List[dict]: A list of annotation.
        """  # noqa: E501
        # Resolve the annotation file to a local path
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            # Load the annotations at the local path with the COCO API
            self.coco = self.COCOAPI(local_path)

        # Get the list of image IDs
        img_ids = self.coco.get_img_ids()
        data_list = []
        total_ann_ids = []
        for img_id in img_ids:
            # Load the raw image info
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id

            # Collect the annotation IDs belonging to this image
            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            # Parse the raw image and annotation info into the unified format
            parsed_data_info = self.parse_data_info({
                'raw_ann_info':
                raw_ann_info,
                'raw_img_info':
                raw_img_info
            })
            data_list.append(parsed_data_info)
        # Check that the annotation IDs are unique
        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{self.ann_file}' are not unique!"

        # The COCO API object is no longer needed
        del self.coco
        # Return the parsed data list
        return data_list

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        Returns:
            List[dict]: Filtered results.
        """
        # In test mode, return the full data list unchanged
        if self.test_mode:
            return self.data_list

        # Without a filter config, return the full data list unchanged
        if self.filter_cfg is None:
            return self.data_list

        # Read the filter options: drop empty-GT images, minimum image size
        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
        min_size = self.filter_cfg.get('min_size', 0)

        # Collect the IDs of the images that appear in the data list
        ids_with_ann = set(data_info['img_id'] for data_info in self.data_list)

        valid_data_infos = []
        # Keep only the samples that pass the filters
        for i, data_info in enumerate(self.data_list):
            img_id = data_info['img_id']
            width = int(data_info['width'])
            height = int(data_info['height'])
            # Skip images without annotations when filter_empty_gt is set
            if filter_empty_gt and img_id not in ids_with_ann:
                continue
            # Keep the sample if its shorter side is at least min_size
            if min(width, height) >= min_size:
                valid_data_infos.append(data_info)

        # Return the filtered data list
        return valid_data_infos

    # Join self.data_root with self.data_prefix and self.ann_file
    def _join_prefix(self):
        """Join ``self.data_root`` with ``self.data_prefix`` and
        ``self.ann_file``.
        """
        # If ann_file is a relative path and data_root is set, join them
        if self.ann_file and not is_abs(self.ann_file) and self.data_root:
            self.ann_file = join_path(self.data_root, self.ann_file)
        # Likewise, join relative data_prefix paths with data_root
        for data_key, prefix in self.data_prefix.items():
            if isinstance(prefix, (list, tuple)):
                abs_prefix = []
                for p in prefix:
                    if not is_abs(p) and self.data_root:
                        abs_prefix.append(join_path(self.data_root, p))
                    else:
                        abs_prefix.append(p)
                self.data_prefix[data_key] = abs_prefix
            elif isinstance(prefix, str):
                if not is_abs(prefix) and self.data_root:
                    self.data_prefix[data_key] = join_path(
                        self.data_root, prefix)
                else:
                    self.data_prefix[data_key] = prefix
            else:
                raise TypeError('prefix should be a string, tuple or list, '
                                f'but got {type(prefix)}')
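
Note that _join_prefix only rewrites relative paths; absolute paths pass through untouched. A quick sketch of the intended behavior (all paths below are made up):

# data_root = 'data/mixed_grounding/'
# ann_file  = 'annotations/train.json'             (relative)
#   -> 'data/mixed_grounding/annotations/train.json'
# data_prefix = dict(img='images/')                (relative)
#   -> dict(img='data/mixed_grounding/images/')
# data_prefix = dict(img='/mnt/datasets/images/')  (absolute)
#   -> unchanged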

.\YOLO-World\yolo_world\datasets\yolov5_obj365v1.py

# Import required modules
from mmdet.datasets import Objects365V1Dataset

# Import the custom dataset base class
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS

# Register YOLOv5Objects365V1Dataset with the DATASETS registry
@DATASETS.register_module()
class YOLOv5Objects365V1Dataset(BatchShapePolicyDataset, Objects365V1Dataset):
    """Dataset for YOLOv5 Objects365 v1 Dataset.

    We only add `BatchShapePolicy` function compared with Objects365V1Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    pass

.\YOLO-World\yolo_world\datasets\yolov5_obj365v2.py

# Import the Objects365V2Dataset class
from mmdet.datasets import Objects365V2Dataset

# Import the BatchShapePolicyDataset class
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
# Import the DATASETS registry
from mmyolo.registry import DATASETS

# Register YOLOv5Objects365V2Dataset with the DATASETS registry
@DATASETS.register_module()
class YOLOv5Objects365V2Dataset(BatchShapePolicyDataset, Objects365V2Dataset):
    """Dataset for YOLOv5 Objects365 v2 Dataset.

    We only add `BatchShapePolicy` function compared with Objects365V2Dataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    # Nothing extra beyond the combined base classes
    pass

.\YOLO-World\yolo_world\datasets\yolov5_v3det.py

# Import required modules and functions
import copy
import json
import os.path as osp
from typing import List

from mmengine.fileio import get_local_path

from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets import CocoDataset

from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from mmyolo.registry import DATASETS

# Images ignored due to known issues in V3Det
v3det_ignore_list = [
    'a00013820/26_275_28143226914_ff3a247c53_c.jpg',
    'n03815615/12_1489_32968099046_be38fa580e_c.jpg',
    'n04550184/19_1480_2504784164_ffa3db8844_c.jpg',
    'a00008703/2_363_3576131784_dfac6fc6ce_c.jpg',
    'n02814533/28_2216_30224383848_a90697f1b3_c.jpg',
    'n12026476/29_186_15091304754_5c219872f7_c.jpg',
    'n01956764/12_2004_50133201066_72e0d9fea5_c.jpg',
    'n03785016/14_2642_518053131_d07abcb5da_c.jpg',
    'a00011156/33_250_4548479728_9ce5246596_c.jpg',
    'a00009461/19_152_2792869324_db95bebc84_c.jpg',
]

# Register the V3DetDataset class
@DATASETS.register_module()
class V3DetDataset(CocoDataset):
    """V3Det dataset for detection."""

    METAINFO = {'classes': 'classes', 'palette': None}

    COCOAPI = COCO
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True

# Register YOLOv5V3DetDataset, inheriting from BatchShapePolicyDataset
# and V3DetDataset
@DATASETS.register_module()
class YOLOv5V3DetDataset(BatchShapePolicyDataset, V3DetDataset):
    """Dataset for YOLOv5 V3Det Dataset.

    We only add `BatchShapePolicy` function compared with V3DetDataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    pass

.\YOLO-World\yolo_world\datasets\__init__.py

# Import the dataset classes defined in this package
from .mm_dataset import (
    MultiModalDataset, MultiModalMixedDataset)
from .yolov5_obj365v1 import YOLOv5Objects365V1Dataset
from .yolov5_obj365v2 import YOLOv5Objects365V2Dataset
from .yolov5_mixed_grounding import YOLOv5MixedGroundingDataset
from .utils import yolow_collate
from .transformers import *  # NOQA
from .yolov5_v3det import YOLOv5V3DetDataset
from .yolov5_lvis import YOLOv5LVISV1Dataset

# Names exported by this package
__all__ = [
    'MultiModalDataset', 'YOLOv5Objects365V1Dataset',
    'YOLOv5Objects365V2Dataset', 'YOLOv5MixedGroundingDataset',
    'YOLOv5V3DetDataset', 'yolow_collate',
    'YOLOv5LVISV1Dataset', 'MultiModalMixedDataset',
]
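
Putting the pieces of this package together, a training dataloader config might look roughly like the sketch below. This is only an illustration: the paths are hypothetical, and the exact MultiModalDataset arguments and pipeline should be taken from the repo's config files.

train_dataloader = dict(
    collate_fn=dict(type='yolow_collate'),
    dataset=dict(
        type='MultiModalDataset',
        dataset=dict(
            type='YOLOv5LVISV1Dataset',
            data_root='data/lvis/',              # hypothetical path
            ann_file='annotations/train.json'),  # hypothetical path
        class_text_path='data/texts/lvis_class_texts.json',  # hypothetical
        pipeline=[
            # ... image transforms ...
            dict(type='RandomLoadText',
                 num_neg_samples=(80, 80),
                 max_num_samples=80),
        ]))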

.\YOLO-World\yolo_world\engine\optimizers\yolow_v5_optim_constructor.py

# Copyright (c) Tencent Inc. All rights reserved.
import logging
from typing import List, Optional, Union

import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.optim import OptimWrapper, DefaultOptimWrapperConstructor
from mmengine.utils.dl_utils import mmcv_full_available
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm

from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                             OPTIMIZERS)

# Register the optimizer wrapper constructor
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class YOLOWv5OptimizerConstructor(DefaultOptimWrapperConstructor):
    """YOLO World v5 constructor for optimizers."""

    # The constructor takes the optim wrapper config and the param-wise config
    def __init__(self,
                 optim_wrapper_cfg: dict,
                 paramwise_cfg: Optional[dict] = None) -> None:
        # Initialize the parent constructor
        super().__init__(optim_wrapper_cfg, paramwise_cfg)
        # Pop 'base_total_batch_size' from paramwise_cfg (default: 64)
        self.base_total_batch_size = self.paramwise_cfg.pop(
            'base_total_batch_size', 64)

    # Build an optimizer wrapper for the given model
    def __call__(self, model: nn.Module) -> OptimWrapper:
        # Unwrap the model if it is wrapped (e.g. by DDP)
        if hasattr(model, 'module'):
            model = model.module

        # Copy the optimizer wrapper config
        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        # The default wrapper type is 'OptimWrapper'
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        # Copy the optimizer config
        optimizer_cfg = self.optimizer_cfg.copy()

        # Follow the original yolov5 implementation
        if 'batch_size_per_gpu' in optimizer_cfg:
            # Pop 'batch_size_per_gpu' from the optimizer config
            batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
            # Total batch size across all GPUs
            total_batch_size = get_world_size() * batch_size_per_gpu
            # Number of gradient accumulation steps
            accumulate = max(
                round(self.base_total_batch_size / total_batch_size), 1)
            # Scale factor relative to the base total batch size
            scale_factor = total_batch_size * \
                accumulate / self.base_total_batch_size

            # If the scale factor is not 1
            if scale_factor != 1:
                # Read the weight decay from the optimizer config
                weight_decay = optimizer_cfg.get('weight_decay', 0)
                # Scale the weight decay accordingly
                weight_decay *= scale_factor
                optimizer_cfg['weight_decay'] = weight_decay
                # Log the scaled weight decay
                print_log(f'Scaled weight_decay to {weight_decay}', 'current')

        # If no paramwise option is specified, use the global settings
        if not self.paramwise_cfg:
            # Optimize all model parameters with the global settings
            optimizer_cfg['params'] = model.parameters()
            # Build the optimizer
            optimizer = OPTIMIZERS.build(optimizer_cfg)
        else:
            # Set the per-parameter learning rate and weight decay recursively
            params: List = []
            self.add_params(params, model)
            optimizer_cfg['params'] = params
            optimizer = OPTIMIZERS.build(optimizer_cfg)
        # Build the optimizer wrapper
        optim_wrapper = OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        # Return the optimizer wrapper
        return optim_wrapper
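
The batch-size compensation here is plain arithmetic: with 8 GPUs and batch_size_per_gpu=16 the total batch size is 128, accumulate = max(round(64 / 128), 1) = 1, and scale_factor = 128 * 1 / 64 = 2, so weight_decay doubles; with 1 GPU and batch size 16, accumulate = 4 and scale_factor = 16 * 4 / 64 = 1, leaving weight_decay unchanged. A standalone sketch of the same rule:

def scaled_weight_decay(weight_decay, batch_size_per_gpu, world_size,
                        base_total_batch_size=64):
    # Reproduces the scaling logic of YOLOWv5OptimizerConstructor
    total_batch_size = world_size * batch_size_per_gpu
    accumulate = max(round(base_total_batch_size / total_batch_size), 1)
    scale_factor = total_batch_size * accumulate / base_total_batch_size
    return weight_decay * scale_factor

print(scaled_weight_decay(0.05, 16, 8))  # 0.1  (effective batch 128)
print(scaled_weight_decay(0.05, 16, 1))  # 0.05 (accumulate=4 compensates)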

.\YOLO-World\yolo_world\engine\optimizers\__init__.py

# Copyright (c) Tencent Inc. All rights reserved.
# Import the YOLOWv5 optimizer constructor
from .yolow_v5_optim_constructor import YOLOWv5OptimizerConstructor

# Export YOLOWv5OptimizerConstructor for external use
__all__ = ['YOLOWv5OptimizerConstructor']

.\YOLO-World\yolo_world\engine\__init__.py

# Copyright (c) Tencent Inc. All rights reserved.
# Re-export all optimizer modules
from .optimizers import *  # noqa

.\YOLO-World\yolo_world\models\backbones\mm_backbone.py

# Import required libraries
import itertools
from typing import List, Sequence, Tuple
import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmengine.model import BaseModule
from mmyolo.registry import MODELS
from mmdet.utils import OptMultiConfig, ConfigType
from transformers import (AutoTokenizer, AutoModel, CLIPTextConfig)
from transformers import CLIPTextModelWithProjection as CLIPTP

# Register the model class with the MODELS registry
@MODELS.register_module()
class HuggingVisionBackbone(BaseModule):
    # Initialization
    def __init__(self,
                 model_name: str,
                 out_indices: Sequence[int] = (0, 1, 2, 3),
                 norm_eval: bool = True,
                 frozen_modules: Sequence[str] = (),
                 init_cfg: OptMultiConfig = None) -> None:

        # Initialize the parent module
        super().__init__(init_cfg=init_cfg)

        # Store the attributes and load the Hugging Face vision model
        self.norm_eval = norm_eval
        self.out_indices = out_indices
        self.frozen_modules = frozen_modules
        self.model = AutoModel.from_pretrained(model_name)

        # Freeze the requested modules
        self._freeze_modules()

    # Forward pass
    def forward(self, image: Tensor) -> Tuple[Tensor]:
        # Encode the image and request all hidden states
        encoded_dict = self.model(pixel_values=image,
                                  output_hidden_states=True)
        hidden_states = encoded_dict.hidden_states
        img_feats = encoded_dict.get('reshaped_hidden_states', hidden_states)
        # Keep only the feature maps at the requested indices
        img_feats = [img_feats[i] for i in self.out_indices]
        return tuple(img_feats)

    # Freeze the parameters of the requested modules
    def _freeze_modules(self):
        for name, module in self.model.named_modules():
            for frozen_name in self.frozen_modules:
                if name.startswith(frozen_name):
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False
                    break

    # Switch between training and evaluation modes
    def train(self, mode=True):
        # Set the mode via the parent implementation
        super().train(mode)
        # Keep the frozen modules frozen
        self._freeze_modules()
        # With norm_eval enabled in training mode
        if mode and self.norm_eval:
            # Walk over all submodules
            for m in self.modules():
                # Keep every BatchNorm layer in evaluation mode
                if isinstance(m, _BatchNorm):
                    m.eval()


# Register HuggingCLIPLanguageBackbone with the MODELS registry
@MODELS.register_module()
class HuggingCLIPLanguageBackbone(BaseModule):
    # The init takes the model name, frozen modules, dropout, etc.
    def __init__(self,
                 model_name: str,
                 frozen_modules: Sequence[str] = (),
                 dropout: float = 0.0,
                 training_use_cache: bool = False,
                 init_cfg: OptMultiConfig = None) -> None:
        # Initialize the parent module
        super().__init__(init_cfg=init_cfg)

        # Store the frozen-module and cache settings
        self.frozen_modules = frozen_modules
        self.training_use_cache = training_use_cache
        # Build the tokenizer from the model name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Build a CLIPTextConfig with the given attention dropout
        clip_config = CLIPTextConfig.from_pretrained(
            model_name, attention_dropout=dropout)
        # Load the CLIP text model with projection using that config
        self.model = CLIPTP.from_pretrained(model_name, config=clip_config)
        # Freeze the requested modules
        self._freeze_modules()

    # Compute the text features once and cache them
    def forward_cache(self, text: List[List[str]]) -> Tensor:
        # On first use, run forward_text and cache the result
        if not hasattr(self, "cache"):
            self.cache = self.forward_text(text)
        return self.cache

    # During training recompute the features; at inference reuse the cache
    def forward(self, text: List[List[str]]) -> Tensor:
        # In training mode, recompute the text features
        if self.training:
            return self.forward_text(text)
        # Otherwise use the cached features
        else:
            return self.forward_cache(text)

    # Tokenize the texts (the result is cached on the instance)
    def forward_tokenizer(self, texts):
        # Tokenize only once
        if not hasattr(self, 'text'):
            # Flatten the nested text list
            text = list(itertools.chain(*texts))
            # Tokenize the texts and return PyTorch tensors
            text = self.tokenizer(text=text, return_tensors='pt', padding=True)
            # Cache the tokenized texts on the instance
            self.text = text.to(device=self.model.device)
        return self.text

    # Encode the text prompts and return the text feature tensor
    def forward_text(self, text: List[List[str]]) -> Tensor:
        # Number of prompts per image in the batch
        num_per_batch = [len(t) for t in text]
        # All images in a batch must carry the same number of prompts
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences not equal in batch')
        # Flatten the nested text list
        text = list(itertools.chain(*text))
        # Tokenize the texts
        text = self.tokenizer(text=text, return_tensors='pt', padding=True)
        # Move the tokenized batch to the model's device
        text = text.to(device=self.model.device)
        # Run the CLIP text encoder
        txt_outputs = self.model(**text)
        # Take the projected text embeddings
        txt_feats = txt_outputs.text_embeds
        # L2-normalize the text features
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        # Reshape to (batch, num_prompts, embed_dim)
        txt_feats = txt_feats.reshape(-1, num_per_batch[0],
                                      txt_feats.shape[-1])
        return txt_feats

    # Freeze the requested modules
    def _freeze_modules(self):

        if len(self.frozen_modules) == 0:
            # Nothing to freeze
            return
        if self.frozen_modules[0] == "all":
            # Freeze everything: set all modules to eval and stop gradients
            self.model.eval()
            for _, module in self.model.named_modules():
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
            return
        # Otherwise freeze only the modules whose names match
        for name, module in self.model.named_modules():
            for frozen_name in self.frozen_modules:
                if name.startswith(frozen_name):
                    module.eval()
                    for param in module.parameters():
                        param.requires_grad = False
                    break

    # Set the training mode and re-freeze the requested modules
    def train(self, mode=True):
        super().train(mode)
        self._freeze_modules()


# Register PseudoLanguageBackbone with the MODELS registry
@MODELS.register_module()
class PseudoLanguageBackbone(BaseModule):
    """Pseudo Language Backbone
    Args:
        text_embed_path (str): path to the text embedding file
    """
    # The init takes the text embedding paths and the init config
    def __init__(self,
                 text_embed_path: str = "",
                 test_embed_path: str = None,
                 init_cfg: OptMultiConfig = None):
        # Initialize the parent module
        super().__init__(init_cfg)
        # Load the text embeddings, stored as a {text: embedding} dict
        self.text_embed = torch.load(text_embed_path, map_location='cpu')
        # Fall back to the training embeddings if no test embeddings are given
        if test_embed_path is None:
            self.test_embed = self.text_embed
        else:
            self.test_embed = torch.load(test_embed_path)
        # Register a dummy buffer to track the module's device
        self.register_buffer("buff", torch.zeros([
            1,
        ]))

    # Cache the forward result
    def forward_cache(self, text: List[List[str]]) -> Tensor:
        if not hasattr(self, "cache"):
            self.cache = self.forward_text(text)
        return self.cache

    # Forward pass
    def forward(self, text: List[List[str]]) -> Tensor:
        if self.training:
            return self.forward_text(text)
        else:
            return self.forward_cache(text)

    # Look up the embeddings for the text prompts
    def forward_text(self, text: List[List[str]]) -> Tensor:
        # Number of prompts per image in the batch
        num_per_batch = [len(t) for t in text]
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences not equal in batch')
        # Flatten the nested text list
        text = list(itertools.chain(*text))
        # Pick the embedding dict based on the training state
        if self.training:
            text_embed_dict = self.text_embed
        else:
            text_embed_dict = self.test_embed
        # Look up the embedding for each text
        text_embeds = torch.stack(
            [text_embed_dict[x.split("/")[0]] for x in text])
        # Detach from the graph and cast to float
        text_embeds = text_embeds.to(
            self.buff.device).requires_grad_(False).float()
        # Reshape to (batch, num_prompts, embed_dim)
        text_embeds = text_embeds.reshape(-1, num_per_batch[0],
                                          text_embeds.shape[-1])
        return text_embeds


# Register MultiModalYOLOBackbone with the MODELS registry
@MODELS.register_module()
class MultiModalYOLOBackbone(BaseModule):
    # The init takes the image model, text model, frozen stages and init config
    def __init__(self,
                 image_model: ConfigType,
                 text_model: ConfigType,
                 frozen_stages: int = -1,
                 init_cfg: OptMultiConfig = None) -> None:

        # Initialize the parent module
        super().__init__(init_cfg)

        # Build the image and text models from their configs
        self.image_model = MODELS.build(image_model)
        self.text_model = MODELS.build(text_model)
        self.frozen_stages = frozen_stages
        # Freeze the requested stages
        self._freeze_stages()

    # Freeze the parameters of the requested stages
    def _freeze_stages(self):
        """Freeze the parameters of the specified stage so that they are no
        longer updated."""
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages + 1):
                # Fetch the layer of the given stage from the image model
                m = getattr(self.image_model, self.image_model.layers[i])
                # Set the layer to evaluation mode
                m.eval()
                # Freeze its parameters
                for param in m.parameters():
                    param.requires_grad = False

    # Switch to training mode while keeping the normalization layers frozen
    def train(self, mode: bool = True):
        """Convert the model into training mode while keep normalization layer
        frozen."""
        # Set the mode via the parent implementation
        super().train(mode)
        # Keep the frozen stages frozen
        self._freeze_stages()

    # The forward takes images and nested texts and returns both feature sets
    def forward(self, image: Tensor,
                text: List[List[str]]) -> Tuple[Tuple[Tensor], Tensor]:
        # Image features from the image model
        img_feats = self.image_model(image)
        # Text features from the text model
        txt_feats = self.text_model(text)
        # Return both
        return img_feats, txt_feats
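
To make the shapes concrete, here is a minimal sketch of exercising the language backbone on its own (the checkpoint name is an assumption; any CLIP text checkpoint with a projection head should behave the same): the nested text list is [batch][num_prompts], and the output is (batch, num_prompts, embed_dim).

texts = [['person', 'dog', 'red car'],   # prompts for image 0
         ['person', 'dog', 'red car']]   # prompts for image 1
text_model = HuggingCLIPLanguageBackbone(
    model_name='openai/clip-vit-base-patch32')  # assumed checkpoint
txt_feats = text_model(texts)
print(txt_feats.shape)  # torch.Size([2, 3, 512]) for this checkpoint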

.\YOLO-World\yolo_world\models\backbones\__init__.py

# Copyright (c) Tencent Inc. All rights reserved.
# YOLO multi-modal backbone (vision + language)
# Vision part: YOLOv8 CSPDarknet
# Language part: CLIP text encoder (a 12-layer transformer)
# Import the multi-modal backbone modules
from .mm_backbone import (
    MultiModalYOLOBackbone,
    HuggingVisionBackbone,
    HuggingCLIPLanguageBackbone,
    PseudoLanguageBackbone)

# Names exported by this package
__all__ = [
    'MultiModalYOLOBackbone',
    'HuggingVisionBackbone',
    'HuggingCLIPLanguageBackbone',
    'PseudoLanguageBackbone'
]

.\YOLO-World\yolo_world\models\data_preprocessors\data_preprocessor.py

# Copyright (c) Tencent Inc. All rights reserved.
# Import required libraries and modules
from typing import Optional, Union
import torch
from mmdet.models.data_preprocessors import DetDataPreprocessor
from mmengine.structures import BaseDataElement
from mmyolo.registry import MODELS

# Type alias for data that can be cast to the target device
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str, None]

# Register YOLOWDetDataPreprocessor with the MODELS registry
@MODELS.register_module()
class YOLOWDetDataPreprocessor(DetDataPreprocessor):
    """Rewrite collate_fn to get faster training speed.

    Note: It must be used together with `mmyolo.datasets.utils.yolow_collate`
    """

    # The init forwards its arguments to the parent class and enables
    # non-blocking device transfer by default
    def __init__(self, *args, non_blocking: Optional[bool] = True, **kwargs):
        # Initialize the parent preprocessor
        super().__init__(*args, non_blocking=non_blocking, **kwargs)

    # Normalization, padding and bgr-to-rgb conversion, as in
    # DetDataPreprocessor
    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding and bgr2rgb conversion based on
        ``DetDataPreprocessorr``.

        Args:
            data (dict): Data sampled from dataloader.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: Data in the same format as the model input.
        """
        # 如果不是训练状态,则直接调用父类的forward方法
        if not training:
            return super().forward(data, training)

        # 对数据进行类型转换
        data = self.cast_data(data)
        inputs, data_samples = data['inputs'], data['data_samples']
        assert isinstance(data['data_samples'], dict)

        # TODO: 支持多尺度训练
        # 如果启用通道转换且输入数据通道数为3,则进行通道转换
        if self._channel_conversion and inputs.shape[1] == 3:
            inputs = inputs[:, [2, 1, 0], ...]
        # 如果启用归一化,则对输入数据进行归一化处理
        if self._enable_normalize:
            inputs = (inputs - self.mean) / self.std

        # 如果存在批量增强操作,则逐个应用
        if self.batch_augments is not None:
            for batch_aug in self.batch_augments:
                inputs, data_samples = batch_aug(inputs, data_samples)

        # 生成图像元信息列表
        img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
        data_samples_output = {
            'bboxes_labels': data_samples['bboxes_labels'],
            'texts': data_samples['texts'],
            'img_metas': img_metas
        }
        # 如果数据样本中包含'masks',则添加到输出中
        if 'masks' in data_samples:
            data_samples_output['masks'] = data_samples['masks']
        # 如果数据样本中包含'is_detection',则添加到输出中
        if 'is_detection' in data_samples:
            data_samples_output['is_detection'] = data_samples['is_detection']

        # 返回处理后的数据
        return {'inputs': inputs, 'data_samples': data_samples_output}
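
The training-time forward above boils down to channel flipping plus per-channel standardization. A self-contained sketch of that core (the mean/std values are the common ImageNet statistics, not necessarily this repo's config):

import torch

inputs = torch.randint(0, 256, (2, 3, 640, 640)).float()  # BGR image batch
mean = torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)  # RGB order
std = torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)

inputs = inputs[:, [2, 1, 0], ...]  # BGR -> RGB, as in forward()
inputs = (inputs - mean) / std      # per-channel standardization
img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)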

.\YOLO-World\yolo_world\models\data_preprocessors\__init__.py

# Copyright (c) Tencent Inc. All rights reserved.
# Import the YOLOWDetDataPreprocessor class
from .data_preprocessor import YOLOWDetDataPreprocessor

# Export YOLOWDetDataPreprocessor for external use
__all__ = ['YOLOWDetDataPreprocessor']

.\YOLO-World\yolo_world\models\dense_heads\yolo_world_head.py

# Import required libraries and modules
import math
import copy
from typing import List, Optional, Tuple, Union, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.model import BaseModule
from torch import Tensor

from mmengine.dist import get_dist_info
from mmengine.structures import InstanceData
from mmdet.structures import SampleList
from mmdet.utils import OptConfigType, InstanceList, OptInstanceList
from mmdet.models.utils import (
    multi_apply,
    unpack_gt_instances,
    filter_scores_and_topk)
from mmyolo.registry import MODELS
from mmyolo.models.dense_heads import YOLOv8HeadModule, YOLOv8Head
from mmyolo.models.utils import gt_instances_preprocess
from mmcv.cnn.bricks import build_norm_layer

# Register the model class with the MODELS registry
@MODELS.register_module()
class ContrastiveHead(BaseModule):
    """Contrastive Head for YOLO-World
    compute the region-text scores according to the
    similarity between image and text features
    Args:
        embed_dims (int): embed dim of text and image features
    """
    def __init__(self,
                 embed_dims: int,
                 init_cfg: OptConfigType = None) -> None:

        super().__init__(init_cfg=init_cfg)

        # Learnable bias, initialized to 0
        self.bias = nn.Parameter(torch.zeros([]))
        # Learnable logit scale, initialized to log(1/0.07) as in CLIP
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, x: Tensor, w: Tensor) -> Tensor:
        """Forward function of contrastive learning."""
        # L2-normalize the image features along the channel dim
        x = F.normalize(x, dim=1, p=2)
        # L2-normalize the text features
        w = F.normalize(w, dim=-1, p=2)
        # Region-text similarity: (b,c,h,w) x (b,k,c) -> (b,k,h,w)
        x = torch.einsum('bchw,bkc->bkhw', x, w)
        # Scale by exp(logit_scale) and add the bias
        x = x * self.logit_scale.exp() + self.bias
        return x


@MODELS.register_module()
class BNContrastiveHead(BaseModule):
    """ Batch Norm Contrastive Head for YOLO-World
    using batch norm instead of l2-normalization
    Args:
        embed_dims (int): embed dim of text and image features
        norm_cfg (dict): normalization params
    """
    # Initialization
    def __init__(self,
                 embed_dims: int,
                 norm_cfg: ConfigDict,
                 init_cfg: OptConfigType = None) -> None:
        # Initialize the parent module
        super().__init__(init_cfg=init_cfg)
        # Build the normalization layer from norm_cfg
        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        # Learnable bias, initialized to 0
        self.bias = nn.Parameter(torch.zeros([]))
        # logit_scale initialized to -1.0 for stability
        self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

    # Forward pass
    def forward(self, x: Tensor, w: Tensor) -> Tensor:
        """Forward function of contrastive learning."""
        # Normalize the image features with the norm layer
        x = self.norm(x)
        # L2-normalize the text features
        w = F.normalize(w, dim=-1, p=2)
        # Region-text similarity: (b,c,h,w) x (b,k,c) -> (b,k,h,w)
        x = torch.einsum('bchw,bkc->bkhw', x, w)
        # Scale by exp(logit_scale) and add the bias
        x = x * self.logit_scale.exp() + self.bias
        # Return the score maps
        return x


# Register the YOLO-World head module with the MODELS registry
@MODELS.register_module()
class YOLOWorldHeadModule(YOLOv8HeadModule):
    """Head Module for YOLO-World

    Args:
        embed_dims (int): embed dim for text features and image features
        use_bn_head (bool): use batch normalization head
    """

    def __init__(self,
                 *args,
                 embed_dims: int,
                 use_bn_head: bool = False,
                 **kwargs) -> None:
        # Store the head-module attributes
        self.embed_dims = embed_dims
        self.use_bn_head = use_bn_head
        # Initialize the parent module
        super().__init__(*args, **kwargs)

    def init_weights(self, prior_prob=0.01):
        """Initialize the weight and bias of PPYOLOE head."""
        # Initialize the weights via the parent implementation
        super().init_weights()
        # Initialize each classification branch and contrastive head
        for cls_pred, cls_contrast, stride in zip(self.cls_preds,
                                                  self.cls_contrasts,
                                                  self.featmap_strides):
            cls_pred[-1].bias.data[:] = 0.0  # reset the bias
            # If the contrastive head has a bias attribute
            if hasattr(cls_contrast, 'bias'):
                # Initialize it with a class- and stride-dependent prior
                nn.init.constant_(
                    cls_contrast.bias.data,
                    math.log(5 / self.num_classes / (640 / stride)**2))

    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """Forward features from the upstream network."""
        # One feature map per level is expected
        assert len(img_feats) == self.num_levels
        # Replicate the text features for every level
        txt_feats = [txt_feats for _ in range(self.num_levels)]
        # Run the per-level forward in parallel with multi_apply
        return multi_apply(self.forward_single, img_feats, txt_feats,
                           self.cls_preds, self.reg_preds, self.cls_contrasts)

    def forward_single(self, img_feat: Tensor, txt_feat: Tensor,
                       cls_pred: nn.ModuleList, reg_pred: nn.ModuleList,
                       cls_contrast: nn.ModuleList) -> Tuple:
        """Forward feature of a single scale level."""
        # Shape of the input feature map
        b, _, h, w = img_feat.shape
        # Classification embedding from the cls branch
        cls_embed = cls_pred(img_feat)
        # Region-text logits from the contrastive head
        cls_logit = cls_contrast(cls_embed, txt_feat)
        # Box distribution predictions from the reg branch
        bbox_dist_preds = reg_pred(img_feat)
        # With distribution focal loss (reg_max > 1)
        if self.reg_max > 1:
            # Reshape the distribution predictions
            bbox_dist_preds = bbox_dist_preds.reshape(
                [-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)

            # TODO: the get_flops script cannot handle the matmul here;
            # this needs to be fixed later
            # Softmax over the bins, then project to the expected box offsets
            bbox_preds = bbox_dist_preds.softmax(3).matmul(
                self.proj.view([-1, 1])).squeeze(-1)
            # Reshape the box predictions back to the feature-map layout
            bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
        else:
            bbox_preds = bbox_dist_preds
        # Training returns cls logits, box preds and box distributions
        if self.training:
            return cls_logit, bbox_preds, bbox_dist_preds
        # Inference returns cls logits and box preds
        else:
            return cls_logit, bbox_preds


# Register the YOLO-World head with the MODELS registry
@MODELS.register_module()
class YOLOWorldHead(YOLOv8Head):
    """YOLO-World head, inheriting from YOLOv8Head."""

    def __init__(self, world_size=-1, *args, **kwargs) -> None:
        # Store world_size and initialize the parent head
        super().__init__(*args, **kwargs)
        self.world_size = world_size

    def loss(self, img_feats: Tuple[Tensor], txt_feats: Tensor,
             batch_data_samples: Union[list, dict]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network."""

        outs = self(img_feats, txt_feats)
        # Fast version
        loss_inputs = outs + (batch_data_samples['bboxes_labels'],
                              batch_data_samples['img_metas'])
        losses = self.loss_by_feat(*loss_inputs)

        return losses

    def loss_and_predict(
        self,
        img_feats: Tuple[Tensor],
        txt_feats: Tensor,
        batch_data_samples: SampleList,
        proposal_cfg: Optional[ConfigDict] = None
    ) -> Tuple[dict, InstanceList]:
        """执行头部的前向传播,然后从特征和数据样本中计算损失和预测。"""
        outputs = unpack_gt_instances(batch_data_samples)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = outputs

        outs = self(img_feats, txt_feats)

        loss_inputs = outs + (batch_gt_instances, batch_img_metas,
                              batch_gt_instances_ignore)
        losses = self.loss_by_feat(*loss_inputs)

        predictions = self.predict_by_feat(*outs,
                                           batch_img_metas=batch_img_metas,
                                           cfg=proposal_cfg)
        return losses, predictions

    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """从上游网络前向传递特征。"""
        return self.head_module(img_feats, txt_feats)

    # Run the head forward and post-process the raw outputs into predictions
    def predict(self,
                img_feats: Tuple[Tensor],
                txt_feats: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.
        """
        # Collect the image meta information from the data samples
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        # Forward the image and text features through the head
        outs = self(img_feats, txt_feats)
        # Post-process the raw outputs into final predictions
        predictions = self.predict_by_feat(*outs,
                                           batch_img_metas=batch_img_metas,
                                           rescale=rescale)
        # Return the predictions
        return predictions

    # Test-time augmentation entry point
    def aug_test(self,
                 aug_batch_feats,
                 aug_batch_img_metas,
                 rescale=False,
                 with_ori_nms=False,
                 **kwargs):
        """Test function with test time augmentation."""
        # Not implemented yet
        raise NotImplementedError('aug_test is not implemented yet.')
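
The einsum at the heart of both contrastive heads is worth unpacking: for every batch element it takes the dot product between each spatial image embedding (a c-vector at each h, w position) and each of the k text embeddings, producing a (b, k, h, w) score map, one channel per prompt. A standalone shape check (dimensions are illustrative):

import torch
import torch.nn.functional as F

b, c, h, w, k = 2, 256, 80, 80, 3
x = F.normalize(torch.randn(b, c, h, w), dim=1, p=2)  # image embeddings
t = F.normalize(torch.randn(b, k, c), dim=-1, p=2)    # text embeddings
scores = torch.einsum('bchw,bkc->bkhw', x, t)
print(scores.shape)  # torch.Size([2, 3, 80, 80]), one map per prompt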

.\YOLO-World\yolo_world\models\dense_heads\yolo_world_seg_head.py

# Copyright (c) Tencent Inc. All rights reserved.
# Import the math library
import math
# Import typing helpers
from typing import List, Optional, Tuple, Union, Sequence

# Import PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm

# Import mmcv / mmengine / mmdet / mmyolo modules
from mmcv.cnn import ConvModule
from mmengine.config import ConfigDict
from mmengine.dist import get_dist_info
from mmengine.structures import InstanceData
from mmdet.structures import SampleList
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
                         OptMultiConfig, InstanceList)
from mmdet.models.utils import multi_apply, unpack_gt_instances
from mmyolo.models.dense_heads import YOLOv8HeadModule
from mmyolo.models.utils import gt_instances_preprocess
from mmyolo.registry import MODELS, TASK_UTILS
from mmyolo.models.dense_heads.yolov5_ins_head import (
    ProtoModule, YOLOv5InsHead
)

# Import the contrastive heads defined alongside the detection head
from .yolo_world_head import ContrastiveHead, BNContrastiveHead

# Register YOLOWorldSegHeadModule with the MODELS registry
@MODELS.register_module()
class YOLOWorldSegHeadModule(YOLOv8HeadModule):
    # Initialization
    def __init__(self,
                 *args,
                 embed_dims: int,
                 proto_channels: int,
                 mask_channels: int,
                 freeze_bbox: bool = False,
                 use_bn_head: bool = False,
                 **kwargs) -> None:
        # Store the module attributes
        self.freeze_bbox = freeze_bbox
        self.embed_dims = embed_dims
        self.proto_channels = proto_channels
        self.mask_channels = mask_channels
        self.use_bn_head = use_bn_head
        # Initialize the parent module
        super().__init__(*args, **kwargs)

    def init_weights(self, prior_prob=0.01):
        """Initialize the weight and bias of PPYOLOE head."""
        # Initialize the weights via the parent implementation
        super().init_weights()
        # Initialize each classification branch and contrastive head
        for cls_pred, cls_contrast, stride in zip(self.cls_preds,
                                                  self.cls_contrasts,
                                                  self.featmap_strides):
            cls_pred[-1].bias.data[:] = 0.0  # reset the bias
            # If the contrastive head has a bias, set it to a prior value
            if hasattr(cls_contrast, 'bias'):
                nn.init.constant_(
                    cls_contrast.bias.data,
                    math.log(5 / self.num_classes / (640 / stride)**2))

    def head_norm_eval(self):
        # Put all BatchNorm layers in the cls branches into eval mode
        for m in self.cls_preds:
            for q in m.modules():
                if isinstance(q, _BatchNorm):
                    q.eval()

        # Put all BatchNorm layers in the reg branches into eval mode
        for m in self.reg_preds:
            for q in m.modules():
                if isinstance(q, _BatchNorm):
                    q.eval()

    def train(self, mode: bool = True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        # Set the mode via the parent implementation
        super().train(mode)
        # With frozen bbox branches, keep their norm layers in eval mode
        if self.freeze_bbox:
            self.head_norm_eval()

    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """Forward features from the upstream network."""
        # One feature map per level is expected
        assert len(img_feats) == self.num_levels
        # Replicate the text features for every level
        txt_feats = [txt_feats for _ in range(self.num_levels)]
        # Mask prototypes from the highest-resolution feature map
        mask_protos = self.proto_pred(img_feats[0])
        # Per-level forward: cls logits, box preds, box distributions
        # and mask coefficients
        cls_logit, bbox_preds, bbox_dist_preds, coeff_preds = multi_apply(
            self.forward_single, img_feats, txt_feats, self.cls_preds,
            self.reg_preds, self.cls_contrasts, self.seg_preds)
        # Training returns all predictions plus the mask prototypes
        if self.training:
            return cls_logit, bbox_preds, bbox_dist_preds, coeff_preds, mask_protos
        # Inference omits the box distributions
        else:
            return cls_logit, bbox_preds, None, coeff_preds, mask_protos

    def forward_single(self, img_feat: Tensor, txt_feat: Tensor,
                       cls_pred: nn.ModuleList, reg_pred: nn.ModuleList,
                       cls_contrast: nn.ModuleList,
                       seg_pred: nn.ModuleList) -> Tuple:
        """Forward feature of a single scale level."""
        # Shape of the input feature map
        b, _, h, w = img_feat.shape
        # Classification embedding from the cls branch
        cls_embed = cls_pred(img_feat)
        # Region-text logits from the contrastive head
        cls_logit = cls_contrast(cls_embed, txt_feat)
        # Box distribution predictions from the reg branch
        bbox_dist_preds = reg_pred(img_feat)
        # Mask coefficients from the seg branch
        coeff_pred = seg_pred(img_feat)
        # With distribution focal loss (reg_max > 1)
        if self.reg_max > 1:
            # Reshape the distribution predictions
            bbox_dist_preds = bbox_dist_preds.reshape(
                [-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)

            # TODO: the get_flops script cannot handle the matmul here;
            # this needs to be fixed later
            # Softmax over the bins, then project to the expected box offsets
            bbox_preds = bbox_dist_preds.softmax(3).matmul(
                self.proj.view([-1, 1])).squeeze(-1)
            bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
        else:
            bbox_preds = bbox_dist_preds
        # Training mode
        if self.training:
            return cls_logit, bbox_preds, bbox_dist_preds, coeff_pred
        else:
            return cls_logit, bbox_preds, None, coeff_pred


# Register YOLOWorldSegHead with the MODELS registry
@MODELS.register_module()
class YOLOWorldSegHead(YOLOv5InsHead):
    """YOLO-World segmentation head."""

    # Special initialization to handle algorithm-specific setup
    def special_init(self):
        """Since YOLO series algorithms will inherit from YOLOv5Head, but
        different algorithms have special initialization process.

        The special_init function is designed to deal with this situation.
        """
        # Build the assigner when a train config is given
        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
            # Cache commonly used values to reduce computation
            self.featmap_sizes_train = None
            self.num_level_priors = None
            self.flatten_priors_train = None
            self.stride_tensor = None

    # Loss of the detection head on the features of the upstream network
    def loss(self, img_feats: Tuple[Tensor], txt_feats: Tensor,
             batch_data_samples: Union[list, dict]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network."""

        # Run the head forward
        outs = self(img_feats, txt_feats)
        # Fast version
        loss_inputs = outs + (batch_data_samples['bboxes_labels'],
                              batch_data_samples['masks'],
                              batch_data_samples['img_metas'])
        # Compute the losses
        losses = self.loss_by_feat(*loss_inputs)

        return losses

    # Compute losses and predictions in one pass
    def loss_and_predict(
        self,
        img_feats: Tuple[Tensor],
        txt_feats: Tensor,
        batch_data_samples: SampleList,
        proposal_cfg: Optional[ConfigDict] = None
    ) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples.
        """
        # Unpack the batched ground-truth instances
        outputs = unpack_gt_instances(batch_data_samples)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = outputs

        # Run the head forward
        outs = self(img_feats, txt_feats)

        # Assemble the loss inputs
        loss_inputs = outs + (batch_gt_instances, batch_img_metas,
                              batch_gt_instances_ignore)
        # Compute the losses
        losses = self.loss_by_feat(*loss_inputs)

        # Post-process the raw outputs into predictions
        predictions = self.predict_by_feat(*outs,
                                           batch_img_metas=batch_img_metas,
                                           cfg=proposal_cfg)
        return losses, predictions

    def forward(self, img_feats: Tuple[Tensor],
                txt_feats: Tensor) -> Tuple[List]:
        """Forward features from the upstream network."""
        # Forward the features through the head module
        return self.head_module(img_feats, txt_feats)

    def predict(self,
                img_feats: Tuple[Tensor],
                txt_feats: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.
        """
        # Collect the image meta information from the data samples
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        # Run the head forward
        outs = self(img_feats, txt_feats)
        # Post-process the raw outputs into predictions
        predictions = self.predict_by_feat(*outs,
                                           batch_img_metas=batch_img_metas,
                                           rescale=rescale)
        return predictions

    # Test-time augmentation entry point
    def aug_test(self,
                 aug_batch_feats,
                 aug_batch_img_metas,
                 rescale=False,
                 with_ori_nms=False,
                 **kwargs):
        """Test function with test time augmentation."""
        # Not implemented yet
        raise NotImplementedError('aug_test is not implemented yet.')

.\YOLO-World\yolo_world\models\dense_heads\__init__.py

# Import YOLOWorldHead and YOLOWorldHeadModule
from .yolo_world_head import YOLOWorldHead, YOLOWorldHeadModule
# Import YOLOWorldSegHead and YOLOWorldSegHeadModule
from .yolo_world_seg_head import YOLOWorldSegHead, YOLOWorldSegHeadModule

# Names exported by this package
__all__ = [
    'YOLOWorldHead', 'YOLOWorldHeadModule', 'YOLOWorldSegHead',
    'YOLOWorldSegHeadModule'
]

.\YOLO-World\yolo_world\models\detectors\yolo_world.py

# Import required modules and classes
from typing import List, Tuple, Union
from torch import Tensor
from mmdet.structures import OptSampleList, SampleList
from mmyolo.models.detectors import YOLODetector
from mmyolo.registry import MODELS

# Register YOLOWorldDetector with the MODELS registry
@MODELS.register_module()
class YOLOWorldDetector(YOLODetector):
    """Implementation of YOLOW Series"""
    # Initialization
    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 **kwargs) -> None:
        # Store the detector attributes
        self.mm_neck = mm_neck
        self.num_train_classes = num_train_classes
        self.num_test_classes = num_test_classes
        # Initialize the parent detector
        super().__init__(*args, **kwargs)

    # Compute losses from a batch of inputs and data samples
    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        # Use the training class count for the bbox head
        self.bbox_head.num_classes = self.num_train_classes
        # Extract the image and text features
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        # Compute the losses
        losses = self.bbox_head.loss(img_feats, txt_feats, batch_data_samples)
        # Return the losses
        return losses

    # Predict from a batch of inputs and data samples, with post-processing
    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.
        """

        # Extract the image and text features
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)

        # The head's class count equals the number of text prompts
        self.bbox_head.num_classes = txt_feats[0].shape[0]

        # Predict with the image features, text features and data samples
        results_list = self.bbox_head.predict(img_feats,
                                              txt_feats,
                                              batch_data_samples,
                                              rescale=rescale)

        # Attach the predictions to the data samples
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)

        # Return the updated data samples
        return batch_data_samples

    # Network forward: backbone, neck and head, without any post-processing
    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.
        """

        # Extract the image and text features
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)

        # Forward through the bbox head
        results = self.bbox_head.forward(img_feats, txt_feats)

        # Return the raw outputs
        return results

    # Extract image and text features from the batch inputs and data samples
    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract features."""
        # When the data samples are a dict (fast collate), read the 'texts' key
        if isinstance(batch_data_samples, dict):
            texts = batch_data_samples['texts']
        # When the data samples are a list, read texts from each data sample
        elif isinstance(batch_data_samples, list):
            texts = [data_sample.texts for data_sample in batch_data_samples]
        # Otherwise raise a TypeError
        else:
            raise TypeError('batch_data_samples should be dict or list.')

        # The backbone encodes both the images and the texts
        img_feats, txt_feats = self.backbone(batch_inputs, texts)
        # With a neck
        if self.with_neck:
            # A multi-modal neck consumes both image and text features
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                # An image-only neck consumes only the image features
                img_feats = self.neck(img_feats)
        # Return both feature sets
        return img_feats, txt_feats
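
End to end, the open-vocabulary behavior comes from extract_feat: whatever texts travel with the data samples define the classes, and predict() resizes the head accordingly. A rough inference-style sketch (detector, batch_inputs and batch_data_samples are assumed to be already built; this mirrors what predict() does internally):

# Each data sample carries .texts, e.g. ['person', 'hard hat', 'forklift']
img_feats, txt_feats = detector.extract_feat(batch_inputs,
                                             batch_data_samples)
# txt_feats: (B, num_prompts, embed_dim); the class count follows it
detector.bbox_head.num_classes = txt_feats[0].shape[0]
results_list = detector.bbox_head.predict(img_feats, txt_feats,
                                          batch_data_samples, rescale=True)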
