音视频开发之旅(90)-Vision Transformer论文解读与源码分析

2024-09-07 07:58:31 浏览数 (3)

目录

1.背景和问题

2.Vision Transformer(VIT)模型结构

3.Patch Embedding

4.实现效果

5.代码解析

6.资料

一、背景和问题

上一篇我们学习了Transformer的原理,主要介绍了在NLP领域上的应用,那么在CV(图像视频)领域该如何使用?

最直观的想法就是把每一个像素像NLP中一个文字一样处理,理论上可行,但是这样做有什么不足吗?

Transformer的自注意力机制的计算复杂度是O(n^2),其中n是序列长度,一张720*1280的图片就需要921600个token,这将导致巨大的计算开销,使得模型的训练和推理非常缓慢。图像不同像素之间存在很多冗余信息(编码时会进行帧内压缩),是否可以采用类似编码压缩技术中的宏块方案呐(把图像分割为固定大小的16x16、8x8、4x4的的块)。

二、VIT模型结构

VIT的思路和视频编码的宏块思想类似,把图像分割为固定大小pathchs,然后通过线性变换得到patch embedding,将图像的patch embeddings送入transformer的Encoder进行特征提取,在根据不同任务添加不同的Head。ViT模型原理如下图所示:

图片图片

模型由三个模块组成:

  • Linear Projection of Flattened Patches(该网络的前处理,把图像分割为patch,然后进行Embedding)
  • Transformer Encoder(该网络的backbone,用于特征提取)
  • MLP Head(该网络的head,用于分类任务)

主要的公式如下:

图片图片
图片图片

可以看到VIT只用到了Transfomer的Encoder作为backbone进行特征提取,TransfomerEncoderLayer也是使用Multi-head Attention,不同的是LayerNormalation放在了Multi-head Attention的前面。和Transfromer的结构主要区别在于Embedding的过程,如果对于注意力机制还不太清楚,建议复习下上一篇

三、Patch Embedding

图片图片

关键点包括:

  1. 图像被分割成固定大小的patches。
  2. 每个patch通过线性投影映射到嵌入空间。
  3. 添加一个特殊的分类token。
  4. 加入位置编码以保留空间信息。

将2D图像转换为一个1D序列,使得标准Transformer架构可以直接处理图像数据,允许ViT像处理文本序列一样处理图像,充分利用了Transformer的自注意力机制来捕捉图像中的全局依赖关系。

下面我们用一个示例来说明PatchEmbedding的过程。

输入一张:256x256的rgb图像,然后把它分割为64个32x32的patchs,对patchs进行线性投影得到序列长度为64,dim为1024的Embedding,然后加上用于分类的可训练的classToken(随机初始化),最后在加上相同形状的PosEmbedding 作为TransformEncodeer的输入。

图片图片

图片来自:详解 Vision Transformer

图片图片

不同于Transfromer的PositionEmbedding(采用sin和cos固定编码),VIT中的PositionEmbedding采用了符合正态分布随机初始化,可训练的方案(bert也采用了类似方式)

论文中对学习到的positional embedding进行了可视化,发现相近的patchs的positional embedding比较相似,而且同行或同列的positional embedding也相近:

图片图片

需要注意的是:如果改变图像的输入大小,ViT不会改变patchs的大小,patchs的数量会发生变化,之前学习的pos_embed就维度对不上了,通常ViT采用插值的方式来解决这个问题,但效果不好,另外一篇论文给出了说明和解决措施 https://arxiv.org/pdf/2102.10882,有兴趣可以进一步研究下。

四、实验效果

ViT的训练策略:先在大数据集上做预训练,然后在小数据集上做迁移使用。

图片图片

如果在小数据集ImageNet上做预训练时,VIT的模型架构效果普遍低于ResNet搭建的BiT网络;当在中等数据集ImageNet-21k上做预训练时,VIT的模型架构基本位于BiT最好和最差的之间;而当在大数据集JFT-300M上做预训练时,VIT的模型架构最好的效果已经超过了BiT。

结论:VIT模型需要在大数据集上进行预训练,在大数据集上预训练的效果会比卷积神经网络的上限高

例如下图先在有3亿张图像的JFT大数据集上预训练,然后在ImageNet上进行微调,准确率达到88.55%

图片图片

ViT 还可根据 Attention Map 来可视化,得知模型具体关注图像的哪个部分,

图片图片

五、代码解析

源码地址:https://github.com/lucidrains/vit-pytorch

图片图片

图片来自:Vision Transformer详解

3.1、调用

代码语言:javascript复制
import torchfrom vit_pytorch import ViT
def test():    #VIT的具体实现在vit.py中    v = ViT(        #原始图像尺寸        image_size = 256,        #切割的每个图像块的尺寸        patch_size = 32,        #类别数量        num_classes = 1000,        #Transformer隐变量维度大小        dim = 1024,        #Transformer Encoder层的个数        depth = 6,        #Multi-Head Attention 头的个数        heads = 16,        #mlp层 hid层的维度        mlp_dim = 2048,        dropout = 0.1,        emb_dropout = 0.1    )
    img = torch.randn(1, 3, 256, 256)
    preds = v(img)

3.2、Attention和FFN的实现

代码语言:javascript复制
# helpers#确保t为元组def pair(t):    return t if isinstance(t, tuple) else (t, t)
# classes#前馈网络class FeedForward(nn.Module):    def __init__(self, dim, hidden_dim, dropout = 0.):        super().__init__()        self.net = nn.Sequential(            nn.LayerNorm(dim),            nn.Linear(dim, hidden_dim),            nn.GELU(),            nn.Dropout(dropout),            nn.Linear(hidden_dim, dim),            nn.Dropout(dropout)        )
    def forward(self, x):        return self.net(x)
#VIT中的self-Attention实现,这里也是多头注意力机制class Attention(nn.Module):    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):        super().__init__()        inner_dim = dim_head *  heads #多头的个数heads:16 * 每个头的维度:64 =1024        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads        self.scale = dim_head ** -0.5 # dim_head =64, scale=1/8
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)        self.dropout = nn.Dropout(dropout)        #to_qkv线性变化,将输入映射到一个三维空间,以便在多头注意力机制中生成QKV 输入特征维度为dim (1024),输出维度为inner_dim*3 (1024*3)        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) #dim:1024,inner_dim:1024
        self.to_out = nn.Sequential(            nn.Linear(inner_dim, dim),            nn.Dropout(dropout)        ) if project_out else nn.Identity()
    def forward(self, x):        x = self.norm(x)        #将输入数据x映射到三维空间,x.shape为[1,65,1024],to_qkv经过线性变换后输出维度为[1,65,1024*3]; chunk(3,-1)将最后一个维度分割为3个子张量,生成qkv元组        qkv = self.to_qkv(x).chunk(3, dim = -1)        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) #进行形状转换,生成[batchsize,heads,squcelen,dim] 值为[1,16,65,64]        #经典的attention计算, 把q和K的转置相乘除以缩放系数,得到相似性系数        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale        #沿最后一维度进行softmax归一化        attn = self.attend(dots)        attn = self.dropout(attn)        #attn[1, 16, 65, 65]点乘V [1, 16, 65, 64]输出[1, 16, 65, 64]        out = torch.matmul(attn, v)        out = rearrange(out, 'b h n d -> b n (h d)') #对多头进行concate,得到[1, 65, 1024]        return self.to_out(out)

3.3、Transfromer Encoder层的实现

代码语言:javascript复制
#VIT中Transfromer的实现,用到了Transformer的Encoder层. 和原始的Transfromer稍微有些差异,主要是layernormalization的位置class Transformer(nn.Module):    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):#dim:1024,depth:6;heads:16;dim_head:64;mlp_dim:2048;dropout:0.1        super().__init__()        self.norm = nn.LayerNorm(dim)        self.layers = nn.ModuleList([])        for _ in range(depth):            self.layers.append(nn.ModuleList([                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),                FeedForward(dim, mlp_dim, dropout = dropout)            ]))
    def forward(self, x):        for attn, ff in self.layers:            x = attn(x)   x #Attention进行残差            x = ff(x)   x #MLP进行残差
        return self.norm(x)

3.4、ViT的实现

代码语言:javascript复制
#入口Module,这里的posEmbedding没有使用固定编码,而是像bert一样可训练的. 把image切分成多个patch,展平进行to_patch_embedding处理class ViT(nn.Module):    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):        super().__init__()        image_height, image_width = pair(image_size)        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'        # num_patches =(256//32)*(256//32)=64;  patch_dim:3*32*32=3072; dim=1024        num_patches = (image_height // patch_height) * (image_width // patch_width)        patch_dim = channels * patch_height * patch_width        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'        #使用einops的Rearrange优雅地处理张量维度        self.to_patch_embedding = nn.Sequential(            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),#这里(h p1) (w p2)就相当于h与w变为原来的1/p1,1/p2            nn.LayerNorm(patch_dim),            nn.Linear(patch_dim, dim),#patch_dim3072,dim 1024 线性变换            nn.LayerNorm(dim),        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches   1, dim)) # 创建一个形状为 (1, 65, 1024) 的随机张量,VIT中PE和Transformer中positionEmbedding的定义不同,这里是一个可以训练的模块        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))#创建一个随机的张量(1,1,1024)的cls_token        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)
    def forward(self, img):        x = self.to_patch_embedding(img)        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)        x = torch.cat((cls_tokens, x), dim=1)        x  = self.pos_embedding[:, :(n   1)]        x = self.dropout(x)        #输入和输出的形状都是 torch.Size([1, 65, 1024])        x = self.transformer(x)         #这里的pool为cls分类,所以沿dim=1,取第1个数据        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]        #这里的to_latent目前就是一个恒等变换层nn.Identity(),即输入和输出每个任何变化,可以去掉,这里起到占位的作用        x = self.to_latent(x)        return self.mlp_head(x)

六、资料

1.论文VIT:https://arxiv.org/pdf/2010.11929

2.源码:https://github.com/lucidrains/vit-pytorch

3.timm/models/vision_transformer.py: https://github.com/huggingface/pytorch-image-4.models/blob/main/timm/models/vision_transformer.py

5.ViT论文逐段精读【论文精读】https://www.bilibili.com/video/BV15P4y137jb

6.Vision Transformer(vit)网络详解 https://www.bilibili.com/video/BV1Jh411Y7WQ

7.李宏毅-Transformer 

https://www.bilibili.com/video/av56239558

8.详解VisionTransformer

https://blog.csdn.net/qq_39478403/article/details/118704747

9.Vision Transformer详解  https://blog.csdn.net/qq_37541097/article/details/118242600

10.ViT代码超详细解读 https://blog.csdn.net/weixin_43334693/article/details/131836233

11.ViT PyTorch代码全解析(附图解)

https://blog.csdn.net/weixin_44966641/article/details/118733341

12.Vision Transformer(VIT)代码分析 https://blog.csdn.net/qq_38683460/article/details/127346916

13.ViT:视觉Transformer backbone网络ViT论文与代码详解 https://mp.weixin.qq.com/s/Nok5UQ2nzex94GXyrltiBg

14.可视化VIT中的注意力 https://mp.weixin.qq.com/s/O-56hxVa6Fgiz2YpjXTodQ

15."未来"的经典之作 ViT:transformer is all you need! https://www.cvmart.net/community/detail/4461

16.搞懂 Vision Transformer 原理和代码 https://mp.weixin.qq.com/s/ozUHHGMqIC0-FRWoNGhVYQ

17.3W字长文带你轻松入门视觉transformer https://zhuanlan.zhihu.com/p/308301901

18.Vision Transformer, LLM, Diffusion Model 超详细解读 (原理分析 代码解读) https://zhuanlan.zhihu.com/p/348593638

19.einops.repeat, rearrange, reduce优雅地处理张量维度 https://blog.csdn.net/qq_37297763/article/details/120348764

感谢你的阅读

接下来我们继续学习输出AI相关内容,欢迎关注公众号“音视频开发之旅”,一起学习成长。

欢迎交流

0 人点赞