当分类从固定类别走向开放类别!基于MMPreTrain实现Prompt-base分类丨开源之夏中选项目分享

2023-11-16 10:47:21 浏览数 (2)

「开源之夏 (OSPP)」是中科院软件所「开源软件供应链点亮计划」指导下的系列暑期活动,旨在鼓励在校学生积极参与开源软件的开发维护,培养和发掘更多优秀的开发者,促进优秀开源软件社区的蓬勃发展,助力开源软件供应链建设。

今年 OpenMMLab 首次参与开源之夏,并对外开放了 2 个项目课题。在参与者的努力付出和导师们的辛勤指导下,所有课题都已经如期顺利结项,为 OpenMMLab 开源社区注入了新的活力。

本次我们非常特别邀请到瞿博文同学,他在开源之夏 2023 中承担的项目是基于 MMPreTrain 实现 Prompt-base 分类器,以下是来自他的经验分享。

项目基本信息

项目名称:基于 MMPreTrain 实现 Prompt-base 分类器

项目导师:马泽润

项目需求:本题目的任务是实现一个 prompt-base 的分类器,它的权重是固定的,提供简单的接口, 给出以下参数就可以对任意图片进行分类:

  • 类别名(category)
  • 图像描述(optional)
  • 图片样例(optional)

项目背景与大致流程

项目背景

传统图像分类通常遵循预训练加微调(pretrain finetune)的模式,并依赖一个预设的固定类别表。然而,随着多种视觉-语言多模态模型(Vision-Language Models, VLMs)的兴起,这种多模态方法使得模型能够无需微调,仅通过预设提示(prompts)即可直接产出卓越的分类结果。这种做法颠覆了传统的预训练模型在图像分类下游任务中的微调方法,标志着从经典微调过渡到一种新的多模态范式——在这种范式中,模型不需要在下游任务上进行额外训练,而是直接依据具体任务构建相关的文本模板(prompt),通过多模态推理来得到分类结果。

大致流程

  1. 基于 OpenAI 的 CLIP 模型,利用其强大的 zero-shot 能力,实现 Open-Vocabulary 的图像分类(主要针对单目标分类,即仅有一个输出结果)
  2. 基于 RAM(Recognize Anything Model),实现 Open-Vocabulary 的多分类任务,可以将图像中所有物体进行识别并输出(即支持多目标分类)

关键概念

Registry 机制:

MM 系列库的核心,这一机制最初由 MMEngine 库定义。该机制为模型、数据集、优化器、学习率调度器、数据预处理转换、分词器等组件提供了一个注册表,注册表实现了字符串到具体类的映射。这意味着用户可以避免复杂的 import 语句,直接通过注册表快速访问并实例化所需的类。此外,Registry 机制还简化了配置文件(Config 文件)的编写过程,使得用户配置模型和实验变得更加高效和灵活。同时,也为模块测试提供了便利,对仓库的开发者和维护者来说是一个福音。

Hook 机制:

MM 系列库的又一个核心,可以在整个 pipeline 的某个部分,如:模型的 forward 途中,定义 Hook,从而为输出模型中间层特征,特征可视化等操作提供了便利。

各种基类:

MM 系列算法库提供了一系列的基类,例如 BaseModel、BaseDataProcessor 等。这些基类不仅明确规定了派生子类必须实现的方法,而且也便于子类继承和定制化重写。通过这种设计,MM 系列算法库的一致性和模块化得到了显著提升,同时也简化了新算法的集成和开发过程。

项目实现细节

基于 MMPreTrain 实现 CLIP

Step1:将 CLIP 的 ViT 转换成 MMPreTrain 中的 VisionTransformer 的实现

需要完成以下内容:

  • 完成 ViT 的 checkpoint 中的 state_dict 的转换
  • 实现 ViT-B/16 和 ViT-L/14 两种 setting 的转换

关键函数:

代码语言:javascript复制
from collections import OrderedDict

def convert_clip(ckpt):
    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v
        if k.startswith('visual.conv1'):
            new_k = k.replace('conv1', 'patch_embed.projection')
        elif k.startswith('visual.positional_embedding'):
            new_k = k.replace('positional_embedding', 'pos_embed')
            new_v = v.unsqueeze(dim=0)
        elif k.startswith('visual.class_embedding'):
            new_k = k.replace('class_embedding', 'cls_token')
            new_v = v.unsqueeze(dim=0).unsqueeze(dim=0)
        elif k.startswith('visual.ln_pre'):
            new_k = k.replace('ln_pre', 'pre_norm')
        elif k.startswith('visual.transformer.resblocks'):
            new_k = k.replace('transformer.resblocks', 'layers')
            if 'ln_1' in k:
                new_k = new_k.replace('ln_1', 'ln1')
            elif 'ln_2' in k:
                new_k = new_k.replace('ln_2', 'ln2')
            elif 'mlp.c_fc' in k:
                new_k = new_k.replace('mlp.c_fc', 'ffn.layers.0.0')
            elif 'mlp.c_proj' in k:
                new_k = new_k.replace('mlp.c_proj', 'ffn.layers.1')
            elif 'attn.in_proj_weight' in k:
                new_k = new_k.replace('in_proj_weight', 'qkv.weight')
            elif 'attn.in_proj_bias' in k:
                new_k = new_k.replace('in_proj_bias', 'qkv.bias')
            elif 'attn.out_proj' in k:
                new_k = new_k.replace('out_proj', 'proj')
        elif k.startswith('visual.ln_post'):
            new_k = k.replace('ln_post', 'ln1')
        elif k.startswith('visual.proj'):
            new_k = k.replace('visual.proj', 'visual_proj.proj')
        else:
            new_k = k

        new_ckpt[new_k] = new_v
    return new_ckpt

如此即可将 OpenAI 的 Vision Transformer 的权重转换到 MMPreTrain 内置实现的 Vision Transformer 的格式,方便我们在 MMPreTrain 框架下也可以加载 OpenAI 的 Vision Transformer 权重。

Step2:实现一个 CLIP 基类

其中需要完成以下功能:

  1. 模型结构组件的定义
  2. 实现图像处理、文本的处理以及 BBPE(Byte-level Byte Pair Encoding) 分词
  3. 实现图像特征的提取,以及文本特征的提取

核心代码:

代码语言:javascript复制
class CLIP(BaseModel):
    def __init__(self,
                 vision_backbone: dict,
                 projection: dict,
                 text_backbone: dict,
                 tokenizer: dict,
                 vocab_size: int,
                 transformer_width: int,
                 proj_dim: int,
                 context_length: int = 77,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        # 定义模型组件,包括图像、文本编码器,对齐所用的projection层、分词器tokenizer、
        # 对输出的logits进行scale的一个可训练常数logit_scale等
    def forward(
        self,
        images: torch.Tensor,
        data_samples: Optional[list] = None,
        mode: str = 'predict',
        **kwargs,
    ):
        # 仅支持推理,不支持训练
        if mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor:
        """The function to extract image latent features."""

    def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor:
        """The function to extract text latent features."""

    def extract_feat(
            self, images: torch.Tensor,
            texts: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor]]:

    def compute_similarity(self, images, texts):
        """Extract images and texts features and compute cosine similarity."""

    @abstractmethod
    def predict(self,
                images: torch.Tensor,
                data_samples: DataSample = None) -> DataSample:
        raise NotImplementedError

    def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
        """Returns the tokenized representation of given input string(s)
        Args:
            texts (Union[str, List[str]]): An input string or a list of input
                strings to tokenize
            context_length (int): The context length to use. Defaults to 52.
        Returns:
            torch.Tensor: Resulting tokens.
        """

可以看到,predict 方法暂未实现,需要在其子类中进行实现。

Step3:实现一个 CLIPZeroShot 类

它继承自 CLIP 基类,并实现额外的 zero-shot 推理功能,即任意给定一个 category,可以在这个 category 下实现 open-vocabulary 的分类。

具体而言,需要重写 CLIP 基类没有定义的 predict 方法,大致代码如下:

代码语言:javascript复制
@MODELS.register_module()
class CLIPZeroShot(CLIP):
    def predict(self,
                images: torch.Tensor,
                data_samples: DataSample = None) -> DataSample:
        if self.text_prototype_embeds is None:
            self.prepare_text_prototype(device=images.device)

        image_features = self.extract_image_feat(images=images)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_image = image_features @ self.text_prototype_embeds.to(
            image_features.device) * self.logit_scale.exp()

        pred_scores = F.softmax(logits_per_image, dim=1)
        pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()

        out_data_samples = []
        if data_samples is None:
            data_samples = [None for _ in range(pred_scores.size(0))]

        for data_sample, score, label in zip(data_samples, pred_scores,
                                             pred_labels):
            if data_sample is None:
                data_sample = DataSample()

            data_sample.set_pred_score(score).set_pred_label(label)
            out_data_samples.append(data_sample)
        return out_data_samples

    def prepare_text_prototype(self, device) -> None:
        """The function to prepare text prototypes with prompt."""
        class_embeddings = []
        for classname in track_on_main_process(self.prototype,
                                               'Prepare text prototype...'):
            # format with class
            texts = [prompt(classname) for prompt in self.prompt]
            tokenized_texts = self.tokenize(texts)
            class_features = self.extract_text_feat(tokenized_texts.to(device))
            class_features /= class_features.norm(dim=-1, keepdim=True)
            class_feature = class_features.mean(dim=0)
            class_feature /= class_feature.norm()
            class_embeddings.append(class_feature)
        self.text_prototype_embeds = torch.stack(
            class_embeddings, dim=1).to(device)

简单来说,即:CLIPZeroShot 类继承自 CLIP 基类,并在 predict 方法中实现了图像的 open-vocabulary 分类。

基于 MMPreTrain 实现 RAM

RAM 中需要使用 CLIP 模型的文本编码器提取文本特征,所以 RAM 的实现是基于上述的 MMPreTrain 中 CLIP 实现的。

Step1:将 RAM 的 SwinTranformer 转换成 MMPreTrain 中的实现

需要完成的功能:

  • SwinTransformer 的 checkpoint 中的 state_dict 的转换

在此过程中,我也遇到了一个困扰我很久的问题,即:

MMPetrain 中采用最新版本的 swin-transformer 实现,其中 PatchMerging 模块采用 nn.Unfold 实现,而其他 SwinTransformer 实现大多采用 Slice 再 Concat 的实现方式,所以在对应的 state_dict 的权重的通道顺序上也需要进行转换。

最初,我一直在硬磕这个问题,死磕了很久才得以解决,在和导师交流后,他很快地就在 MMSegmentation 库的 Issue 和 PR 中找到了几乎一模一样的问题和解决方案,这就是开源社区的好处,如果我早点意识到,也就会减少很多重复的工作量了~

关键函数代码:

代码语言:javascript复制
from collections import OrderedDict
def convert_swin(ckpt):
    new_ckpt = OrderedDict()
    convert_mapping = dict()

    def correct_unfold_reduction_order(x):
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1,
                                            2).reshape(out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    for k, v in ckpt.items():
        if 'attn_mask' in k:
            continue
        if k.startswith('head'):
            continue
        elif k.startswith('layers'):
            new_v = v
            if 'attn.' in k:
                new_k = k.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in k:
                if 'mlp.fc1.' in k:
                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in k:
                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_k = k.replace('mlp.', 'ffn.')
            elif 'downsample' in k:
                new_k = k
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)
            else:
                new_k = k
            new_k = new_k.replace('layers', 'stages', 1)
        elif k.startswith('patch_embed'):
            new_v = v
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        elif k.startswith('norm'):
            new_v = v
            new_k = k.replace('norm', 'norm3')
        else:
            new_v = v
            new_k = k

        new_ckpt[new_k] = new_v
        convert_mapping[k] = new_k

    return new_ckpt, convert_mapping

Step2:实现 RAM 基类,并基于此实现默认词表的 RAMNormal 类和支持用户自定义词表的 RAMOpenset 类

此处的程序设计理念和 CLIPZeroShot 与 CLIP 基类类似,即 RAM 基类实现一些基本的模型推理和特征提取,而子类的 RAMNormal 和 RAMOpenset 更改其 predict 方法,以完成个性化的设计,大致的伪代码框架如下:

代码语言:javascript复制
class RAM(BaseModel):
    """The implementation of `RAM <https://arxiv.org/abs/2306.03514>`_."""

    def __init__(self,
                 tokenizer: dict,
                 vision_backbone: dict,
                 tag_encoder: dict,
                 tagging_head: dict,
                 text_decoder: dict,
                 device: str = 'cpu',
                 vision_width: int = 1536,
                 prompt='a picture of ',
                 threshold=0.68,
                 delete_tag_index=[],
                 tag_list='./data/ram_tag_list.pickle',
                 tag_list_chinese='./data/ram_tag_list_chinese.pickle',
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        # 定义各组件

    def load_tag_list(self, tag_list_file):
        # 从文件中得到词表
    def get_label_embed(self):
        # 得到词表中每个词的嵌入特征
    def extract_visual_feature(self, images):
        # 提取视觉特征
    def image2tag(self, label_embed, image_embeds, image_atts):
        # image2tag推理
    def forward(
        self,
        images: torch.Tensor,
        data_samples: Optional[list] = None,
        mode: str = 'predict',
        **kwargs,
    ):
        if mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')
    @abstractmethod
    def predict(self,
                images: torch.Tensor,
                data_samples: DataSample = None) -> DataSample:
        raise NotImplementedError

@MODELS.register_module()
class RAMNormal(RAM):
    def tag_process(self, logits):
        # 处理词表
    def predict(self,
                    images: torch.Tensor,
                    data_samples: DataSample = None) -> DataSample:
        # 定义直接加载词表情况下的predict行为

@MODELS.register_module()
class RAMOpenset(RAMNormal):  # 继承RAMNormal类
    def set_openset(self,
                    categories: List[str] = None,
                    clip_ckpt: str = '',
                    threshold: float = 0.68):
        # openset的相关设置和embedding提取
    def tag_process(self, logits):
        # 重写tag_process函数

Step3:基于 gradio 实现一个 webui,能够让用户更便捷的使用 RAM

构建一 个WebUI,可以让用户更加方便地使用 RAM,测试其性能,并且近乎实时地看到输出结果,体感极强!

项目结果呈现

CLIP 在 CIFAR100 和 ImageNet1k 上的

zero-shot 性能对齐

如下表展示的数据所示,基于 MMPreTrain 实现的 CLIP 模型在 CIFAR100 和 ImageNet1k 这两个数据集上的 zero-shot 分类性能,可以与 OpenAI 的 CLIP 模型相媲美。

RAM 的 Gradio WebUI demo 展示

加载预设词表(Normal 模式):

使用自定义词表(Openset 模式,暂未支持中文输出):

PR 链接

CLIP:

https://github.com/open-mmlab/mmpretrain/pull/1737

RAM:

https://github.com/open-mmlab/mmpretrain/pull/1802

0 人点赞