音视频开发之旅(92)-多模态Clip论文解读与源码分析

2024-09-07 07:59:53 浏览数 (3)

目录

1. 背景和问题

2. CLIP模型结构

3. 实验效果

4. 源码分析

5. CLIP的局限性和不足

6. 资料

一. 背景和问题

在做分类 检测以及分割任务时,数据的标注非常关键, 比如可用于分类任务的ImageNet数据集共有120万张图片1000个分类,  可用于目标检测和分割任务的COCO数据集共有33万张图片80个目标类别. 传统的图像分类模型通常在标注的数据集上进行训练,但这些数据集的类别和数量相对较小,训练的模型泛化能力也受限,很难直接zero-shot迁移到下游任务.

Transformer在NLP领域大放光彩,在CV领域基于Transformer的VIT等取得了不错的效果, 但这两个领域之间的交互是一个挑战,Clip就是研究这个问题,今天我们开启多模态的学习

由于文章周涉及到不少名词,为了更好的理解,先对其进行解释:

  • Linear-probe: 用于衡量特征提取器性能的一种方法,通过冻结网络的backbone,只对最后一层Fully Connected Layer(全连接层)进行训练,可以更准确的反映预训练模型的好坏.
  • distribution gap: 不同数据集分布上存在一定的差距,导致准确率或者泛化表现差,  例如出现out of distribution(推理的数据和预训练的数据来自不同分布)的情况,这个在画质评测任务中也是经常遇到,eg: 训练数据集大部分来自用户拍摄的白天的图像,那么对于合成的纯色背景加文字或者黑夜场景 推理评测结果就不太好.
  • zero-shot learning: 零样本学习,它是指在没有直接训练数据的情况下,使模型能够识别或者预测新的/未见过的类别. 如下图经典的"斑马案例":假设模型已经能够识别马,老虎和熊猫,现在需要该模型也识别斑马,zero-shot就是不通过训练给模型见斑马的图片,而是在推理时告诉模型斑马有什么特征,模型也可以成功识别出斑马
图片图片

图片来自:Zero-shot, One-shot和Few-shot的理解

二. CLIP模型结构

CLIP(Contrastive Language-Image Pre-training)是由OpenAI在2021年发布的一种多模态训练的神经网络,采用了对比学习的思想, 对收集的4亿张图文对进行预训练. 通过图文Embedding相似度来实现分类,打破了之前固定标签的范式. 无论是在手机数据集还是模型训练,都不需要像ImageNet-1000那样做分类,直接手机文字-图像对,然后用无监督的方式进行预测相似性.

模型训练: 每一张图像都有一小句解释性文字,将文字和图片分别通过一个编码器,得到向量表示, 对角线为正样本,非对角线为负样本,然后计算余弦相似度, 整体上采用双塔模型:图像塔和文本塔。图像塔负责提取图像表征,一般为Vision Transformer, 文本塔则负责提取文本特征,使用经典Transformer架构。

模型推理: clip推理过程不依赖传统的分类层,而是直接通过图像和文本Embedding之间的相似度来实现分类

图片图片

Clip只开源了推理代码和预训练模型,论文中提供了下面的训练伪代码

图片图片

可以看出和上面的模型架构一致:

  • 首先对Image和Text分别通过图像和文本编码器进行特征提取
  • 然后把图像和文本的特征向量经过投影矩阵W_i和W_t,映射到相同维度的潜在空间,然后进行归一化,得到图像和文本的Embedding表示
  • 接着计算图像Embedding和文本Embedding的余弦相似度,并通过temperature(温度)参数进行缩放
  • 最后分别计算图像到文本和文本到图像的交叉熵,取两者均值作为最终的loss

三. 实验结果

作者在30个数据集上,对zero-shot的Clip和Linear probe的ResNet50进行对比,可以clip可以达到和ResNet50在特定的标注好的数据集上训练后的模型水平相当

图片图片

zero-shot Clip 的泛化能力

下图使用在ImnageNet数据集预训练的RestNet101和Zero-shot的Clip进行对比, 在ImageNet数据集上准确率都为76.2%,表现相当,但是迁移到其他数据集上,Zero-shot Clip明显更优,体现了其更好的泛化性和撸棒性.

图片图片

除了分类任务外,CLIP模型已经在许多视觉和语言任务中展现出很好的性能,图像分类、零样本分类、语义分割、图像生成的指导、图像问答

四. 源码解析

4.1 demo

输入一张图片, 多个文本label, 预测图片为每个label的概率.

  • 首先对图像进行resize,crop归一化等预处理到模型需要的shape:torch.Size([1, 3, 224, 224]);对text进行SimpleToken转为token,一个英文单词对应一个token(后面会有详细示例说明)
  • 然后分别对image和text进行特征提取,其中 image使用VIT作为backbone,text使用TransformerEncoder作为backbone.
  • 最后 经过softmax输出 图片为每个label的概率
代码语言:javascript复制
import numpy as npimport pytestimport torchfrom PIL import Imageimport clip
def test(model_name="ViT-B/32"):    device = "cuda" if torch.cuda.is_available() else "cpu"    model, preprocess = clip.load(model_name, device=device)
    image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)#对图片进行resize crop 转为张量 归一化处理  -->输入:image mode=RGBA size=2162x762; 输出:torch.Size([1, 3, 224, 224])    text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
    with torch.no_grad():        image_features = model.encode_image(image) #对image通过VIT进行特征提取. 输入:torch.Size([1, 3, 224, 224]) 图像的tensor数据,输出:torch.Size([1, 512])        text_features = model.encode_text(text) #对text通过Transformer进行特征提取.输入torch.Size([3, 77]) 对应["a diagram", "a dog", "a cat"]词的tokens,输出:torch.Size([3, 512])                logits_per_image, logits_per_text = model(image, text)        probs = logits_per_image.softmax(dim=-1).cpu().numpy()#经过softmax输出 图片为每个label的概率
    print("Label probs:", probs)     if __name__ == "__main__":    #clip.available_models:['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']    print(f"clip.available_models:{clip.available_models()}")    test()

4.2 文字转为token: clip.tokenize

代码语言:javascript复制
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:    """    返回给定输入字符串的tokens    """    if isinstance(texts, str):        texts = [texts]
    sot_token = _tokenizer.encoder["<|startoftext|>"]#49406    eot_token = _tokenizer.encoder["<|endoftext|>"]  #49407    all_tokens = [[sot_token]   _tokenizer.encode(text)   [eot_token] for text in texts] #加上开头和接口的token [[49406, 320, 22697, 49407], [49406, 320, 1929, 49407], [49406, 320, 2368, 49407]]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) #
    for i, tokens in enumerate(all_tokens):        if len(tokens) > context_length:#context_length:77,如果tokens的长度大于context_length,做截断处理或者抛异常            if truncate:                tokens = tokens[:context_length]                tokens[-1] = eot_token            else:                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")        result[i, :len(tokens)] = torch.tensor(tokens)#转为张量赋值给result
    return result #torch.Size([3, 77]) ,['a diagram', 'a dog', 'a cat']的tokens

['a diagram', 'a dog', 'a cat']的tokens shape为 torch.Size([3, 77]) ,具体内容如下图,其中49406是每个tokens的startToken,49407是每个tokens的endToken. 可以看出基本一个英文单词对应一个token

图片图片

4.3 图像预处理 preprocess

代码语言:javascript复制
def _transform(n_px):    return Compose([        Resize(n_px, interpolation=BICUBIC), #默认3*224*224        CenterCrop(n_px),        _convert_image_to_rgb,        ToTensor(),        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), #使用ImageNet的均值方差进行归一化    ])
图片图片

4.4  构造clip模型

代码语言:javascript复制
class CLIP(nn.Module):    def __init__(self,                 embed_dim: int,#512                 # vision                 image_resolution: int,#224                 vision_layers: Union[Tuple[int, int, int, int], int], #12                 vision_width: int,#768                 vision_patch_size: int,#32                 # text                 context_length: int,#77                 vocab_size: int,#49408                 transformer_width: int,#512                 transformer_heads: int,#8                 transformer_layers: int #12                 ):        super().__init__()
        self.context_length = context_length
        if isinstance(vision_layers, (tuple, list)):            vision_heads = vision_width * 32 // 64            self.visual = ModifiedResNet(                layers=vision_layers,                output_dim=embed_dim,                heads=vision_heads,                input_resolution=image_resolution,                width=vision_width            )        else:            vision_heads = vision_width // 64 #768//64=12            self.visual = VisionTransformer( #定义用于Image特征提取的Transformer                input_resolution=image_resolution, #输入图像分辨率224*224                patch_size=vision_patch_size, #每个patch的大小32*32                width=vision_width, #768,这个vision_width是什么?                layers=vision_layers, #12个layer                heads=vision_heads, #multi-headattention 8个头                output_dim=embed_dim #输出维度 512            )
        self.transformer = Transformer(            width=transformer_width,#512            layers=transformer_layers,#12            heads=transformer_heads,#8            attn_mask=self.build_attention_mask()        )
        self.vocab_size = vocab_size #词库大小 49408        self.token_embedding = nn.Embedding(vocab_size, transformer_width) #transformer_width:512        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.initialize_parameters()
    def initialize_parameters(self):        nn.init.normal_(self.token_embedding.weight, std=0.02) #将文本token的embedding权重初始为均值为0,标准差为0.02的正态分布        nn.init.normal_(self.positional_embedding, std=0.01)   #将positional_embedding权重初始为均值为0,标准差为0.01的正态分布
        if isinstance(self.visual, ModifiedResNet):            if self.visual.attnpool is not None:                std = self.visual.attnpool.c_proj.in_features ** -0.5                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:                for name, param in resnet_block.named_parameters():                    if name.endswith("bn3.weight"):                        nn.init.zeros_(param)
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)        attn_std = self.transformer.width ** -0.5        fc_std = (2 * self.transformer.width) ** -0.5        for block in self.transformer.resblocks:            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
    def build_attention_mask(self):        # lazily create causal attention mask, with full attention between the vision tokens        # pytorch uses additive attention mask; fill with -inf        mask = torch.empty(self.context_length, self.context_length)        mask.fill_(float("-inf")) #全部填充为负无穷大        mask.triu_(1)  # zero out the lower diagonal,把下三角设置为0.进行softmax时softmax(-inf)为0 起到了mask作用        return mask
    @property    def dtype(self):        return self.visual.conv1.weight.dtype
    def encode_image(self, image):        return self.visual(image.type(self.dtype)) #self.dtype:torch.float16
    def encode_text(self, text):        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model] ,输入torch.Size([3, 77]),输出 torch.Size([3, 77, 512])
        x = x   self.positional_embedding.type(self.dtype)#self.dtype:torch.float16 加上位置编码 ,输出还是torch.Size([3, 77, 512])        x = x.permute(1, 0, 2)  # NLD -> LND #输出 torch.Size([77, 3, 512])        x = self.transformer(x) #进行transormerEncoder(由多层MultiHeadAttention和MLP组成)特征提取,输出和输入shape一致.torch.Size([77, 3, 512])        x = x.permute(1, 0, 2)  # LND -> NLD 输出torch.Size([3, 77, 512])        x = self.ln_final(x).type(self.dtype) #进行layerNorm归一化
        # x.shape = [batch_size, n_ctx, transformer.width]        # take features from the eot embedding (eot_token is the highest number in each sequence)        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection #text.shape为torch.Size([3, 77]), self.text_projection为torch.Size([512, 512])
        return x
    def forward(self, image, text):        image_features = self.encode_image(image)        text_features = self.encode_text(text)
        # normalized features 特征归一化处理        image_features = image_features / image_features.norm(dim=1, keepdim=True)        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        # cosine similarity as logits        logit_scale = self.logit_scale.exp() #余弦相似度        logits_per_image = logit_scale * image_features @ text_features.t()        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]        return logits_per_image, logits_per_text

4.5 图像特征提取 VisionTransformer

代码语言:javascript复制
class VisionTransformer(nn.Module):    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):        super().__init__() #input_resolution:224; patch_size:32; width:768; layers:12; heads:12; output_dim:512        self.input_resolution = input_resolution #224        self.output_dim = output_dim #512        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) #使用CNN进行特征提取作为Embedding
        scale = width ** -0.5 #with的平分根 分之一        self.class_embedding = nn.Parameter(scale * torch.randn(width)) #随机生成一个分类embedding        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2   1, width)) #随机初始化PE        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
    def forward(self, x: torch.Tensor):        x = self.conv1(x)  # shape = [*, width, grid, grid] 输入:torch.Size([1, 3, 224, 224]),输出torch.Size([1, 768, 7, 7]) 一张224*224的图片横纵都切分为7分,每个patch的wh为224/7=32, 768为维度数量        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2] 输出torch.Size([1, 768, 49])        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width] #输出torch.Size([1, 49, 768])        x = torch.cat([self.class_embedding.to(x.dtype)   torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2   1, width] #在patchEmbdding前加入一个classToken,输出:torch.Size([1, 50, 768])        x = x   self.positional_embedding.to(x.dtype) #在PatchEmbdding后加上PositionEmbedding ,输出还是torch.Size([1, 50, 768])        x = self.ln_pre(x) #进行LayerNorm归一化
        x = x.permute(1, 0, 2)  # NLD -> LND ,输出torch.Size([50, 1, 768])        x = self.transformer(x) #进行VIT特征提取,输出和输入的shape一致, 还是torch.Size([50, 1, 768])        x = x.permute(1, 0, 2)  # LND -> NLD,输出torch.Size([1, 50, 768])
        x = self.ln_post(x[:, 0, :]) #输出torch.Size([1, 768]),保留第一维的dim
        if self.proj is not None:#self.proj.shape为torch.Size([768, 512])            x = x @ self.proj #输出torch.Size([1, 512])
        return x

4.6 文本特征提取Transformer

代码语言:javascript复制
class ResidualAttentionBlock(nn.Module):    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head) #d_model:512, n_head:8, d_head=d_model/n_head=64        self.ln_1 = LayerNorm(d_model)        self.mlp = nn.Sequential(OrderedDict([            ("c_fc", nn.Linear(d_model, d_model * 4)),            ("gelu", QuickGELU()),            ("c_proj", nn.Linear(d_model * 4, d_model))        ]))        self.ln_2 = LayerNorm(d_model)        self.attn_mask = attn_mask
    def attention(self, x: torch.Tensor):        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
    def forward(self, x: torch.Tensor):        x = x   self.attention(self.ln_1(x))        x = x   self.mlp(self.ln_2(x))        return x

class Transformer(nn.Module):    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):        super().__init__()        self.width = width #768        self.layers = layers #12        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) #定义12层AttentBlock
    def forward(self, x: torch.Tensor):        return self.resblocks(x)

五. CLIP的局限性和不足

1. 虽然在很多数据集上zero-shot Clip和ResNet50表现相当,但是对应的任务上ResNet50表现并不是最优,Clip与那些SOTA的相比还是有不少差距,如果按照大模型 大数据训练成本和效果的范式进行预估,至少是现有Clip训练成本的1000倍

2. 在一些细分类数据集(eg:医疗)clip的准确率低于Resnet50

3. 在一些抽象的复杂的任务上,clip泛化比较差,eg:区分视频中某一帧是否异常

4. 如果推理数据和训练数据相差甚远(out of distribution),clip泛化也比较差,eg:在手写数字的数据集

5. 虽然clip可以做zero-shot,但是还是从给动的图-文对中进行相似度计算来选择,相比而言,生成式会更加灵活

六. 资料

1.论文:https://arxiv.org/pdf/2103.00020

2.源码:https://github.com/openai/CLIP

3.李沐-CLIP 论文逐段精读 https://www.bilibili.com/video/BV1SL4y1s7LQ

4.多模态模型学习1——CLIP对比学习 语言-图像预训练模型https://blog.csdn.net/weixin_44791964/article/details/129941386

5.多模态表征—CLIP及中文版Chinese-CLIP:理论讲解、代码微调与论文阅读 https://blog.csdn.net/weixin_44362044/article/details/136262247

6.openai多模态大模型:clip详解及实战 https://blog.csdn.net/lsb2002/article/details/132275132

7.深度学习系列37:CLIP模型https://blog.csdn.net/kittyzc/article/details/125167223

8.【代码实践】使用CLIP做一些多模态的事情https://blog.csdn.net/me_yundou/article/details/123236173

9.两个小时浅析CLIP模型,内含原理 代码复现 https://www.bilibili.com/video/BV1K1421U7jc/?vd_source=03a763fa6cf49b01f658f32592f5a6f3

10.一文读懂CLIP图文多模态模型 https://blog.csdn.net/weixin_47228643/article/details/136690837

11.多模态经典之作CLIP https://juejin.cn/post/7264503343996747830

12.李沐论文精读系列四:CLIP和改进工作串讲(LSeg、GroupViT、VLiD、 GLIPv1、 GLIPv2、CLIPasso)https://blog.csdn.net/qq_56591814/article/details/127421979

13.AI绘画原理解析:从CLIP、BLIP到DALLE、DALLE 2、DALLE 3、Stable Diffusion https://blog.csdn.net/v_JULY_v/article/details/131205615

14.图片来自:Zero-shot, One-shot和Few-shot的理解 https://blog.csdn.net/wzk4869/article/details/129419127

感谢你的阅读

接下来我们继续学习输出AI相关内容,欢迎关注公众号“音视频开发之旅”,一起学习成长。

欢迎交流

0 人点赞