Vision Transformer 必读系列之图像分类综述(二): Attention-based

2022-02-28 13:46:07 浏览数 (2)

号外号外!awesome-vit 上新啦,

欢迎大家 Star Star Star ~

https://github.com/open-mmlab/awesome-vit

前言

Vision Transformer 必读系列之图像分类综述(一):概述 一文中,我们对 Vision Transformer 在图像分类中的发展进行了概述性总结。本文则对其中涉及的 Attention-based 部分进行详细说明。下一篇文章则会对概述中涉及的其他部分进行说明。

ViT 进展汇总思维导图如下图所示:

注意:文中涉及到的思维导图,可以通过 https://github.com/open-mmlab/awesome-vit下载。

1. Transformer

论文题目:Attention is All You Need

论文地址:https://arxiv.org/abs/1706.03762

Transformer 结构是 Google 在 2017 年为解决机器翻译任务(例如英文翻译为中文)而提出,从题目中可以看出主要是靠 Attention 注意力机制,其最大特点是抛弃了传统的 CNN 和 RNN,整个网络结构完全是由 Attention 机制组成。为此需要先解释何为注意力机制,然后再分析模型结构。

1.1 Attention 注意力机制

人生来就有注意力机制,看任何画面,我们会自动聚焦到特定位置特定物体上。此处的 Attention 机制也是同一个含义,对于需要的任何模态,不管是图像、文本、点云还是其他,我们都希望网络通过训练能够自动聚焦到有意义的位置,例如图像分类和检测任务,网络通过训练能够自动聚焦到待分类物体和待检测物体上。

注意力机制不是啥新鲜概念,视觉算法中早已广泛应用,典型的如 SENet。

利用 Squeeze-and-Excitation 模块计算注意力权重概率分布,然后作用于特征图上实现对每个通道重加权功能。

可以举一个更简单的例子,假设有一个训练好的分类网络,输入一张图片,训练好的分类网络权重 W 和图片 X 进行注意力计算,从 X 中提取能够有助于分类的特征,该特征最终可以作为类别分类依据。W 和 X 都是矩阵,要想利用 W 矩阵来达到重加权 X 目的,等价于计算 W 和 X 的相似度(点乘),然后将该相似度变换为权重概率分布,再次作用于 X 上就可以

以一个简单猫狗二分类例子说明。网络最终输出是 2x1 的向量,第一个数大则表示猫类别,否则为狗类别,假设网络已经训练好了,其 W 为 shape 为 2x1 的向量,值为 [[0.1, 0.5]],X 表示输入图片 shape 也是 2x1,其值为 [[0.1, 0.8]],可以看出其类别是狗,采用如下的计算步骤即可正确分类:

  • W 和 X 的转置相乘,即计算 W 中每个值和 X 中每个值的相似度,得到 2x2 矩阵,值为 [[0.01,0.08], [0.05,0.4]]。
  • 第二个维度进行 Softmax,将其转化为概率权重图即为 [[0.4825, 0.5175], [0.4134, 0.5866]]。
  • 将上述概率权重乘以 X,得到 shape 为 2x1 输出,值为 [[0.4622, 0.5106]]。
  • 此时由于第二个值大,所以正确分类为狗。

X 是含有狗的图片矩阵,能够正确分类的前提是训练好的 W 矩阵中第二个数大于第一个数。可以简单理解上述过程是计算 W 和 X 的相似度,如果两个向量相似(都是第二个比第一个数大),那么就分类为狗,否则就分类为猫

代码语言:javascript复制


import torch

W = torch.tensor([[0.1, 0.5]]).view([1, -1, 1])
X = torch.tensor([[0.1, 0.8]]).view([1, -1, 1])
# 1 计算两个向量相似度
attn_output_weights = torch.bmm(W, X.transpose(1, 2))
# 2 转换为概率分布
attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
# 3 注意力加权
cls = torch.bmm(attn_output_weights, X)[0]
print(cls)

上述计算过程可以用如下公式表示:

对应到上面例子,Q 就是训练好的 W 矩阵,K 是图片输入,V 和 K 相等,其通用解释为利用 Q 查询矩阵和 K 矩阵进行相似度计算,然后转换为概率分布,此时概率值大的位置表示两者相似度大的部分,然后将概率分布乘上 V 值矩阵,从而用注意力权重分布加权了 V 矩阵,也就改变了 V 矩阵本身的分布。如果注意力机制训练的很好,那么提取的 V 应该就是我们想要的信息。分母 d_k 的平方根是为了避免梯度消失,当向量值非常大的时候,Softmax 函数会将几乎全部的概率分布都分配给了最大值对应的位置,也就是说所谓的锐化,通过除以分母可以有效避免梯度消失问题,稳定训练过程。

上述公式就是论文中提出的最重要的 Scaled Dot-Product Attention 计算公式,先利用点乘计算 QK 矩阵的相似度,除以分母 d_k 平方根进行 Scaled 操作,然后 Softmax 操作将其转换为概率乘以 V 实现 Attention 功能。

需要注意:为了让上面公式不会报错,其 Shape 关系必须为 Q - (N, M),K - (P, M),V - (P, G), 一般来说 K 和 V Shape 相同,但是 Q Shape 不一定和 K 相同。通过灵活地改变这些维度就可以控制注意力层的计算复杂度,后续大部分改进算法都有利用这一点

1.2 Transformer 结构分析

Transformer 是为了解决机器翻译任务而提出。机器翻译是一个历史悠久的问题,可以理解为序列转序列问题,也就是我们常说的 seq2seq 结构,解决这类问题一般是采用 encoder-decoder 结构,Transformer 也沿用了这种结构。翻译任务一个常规的解决方案如下所示:

对应到 Transformer 中的一个更具体的结构为:

主要包括编码器和解码器组件,编码器包括自注意力模块(QKV 来自同一个输入)和前向网络,解码器和编码器类似,只不过内部多了编码器和解码器交互的交叉注意力模块。

通常来说,标准的 Transformer 包括 6 个编码器和 6 个解码器串行。

  1. 编码器内部接收源翻译输入序列,通过自注意力模块提取必备特征,通过前向网络对特征进行进一步抽象。
  2. 解码器端输入包括两个部分,一个是目标翻译序列经过自注意力模块提取的特征,一个是编码器提取的全局特征,这两个输入特征向量会进行交叉注意力计算,抽取有利于目标序列分类的特征,然后通过前向网络对特征进行进一步抽象。
  3. 堆叠多个编码器和解码器,下一个编解码器接收来自上一个编解码的输出,构成串行结构不断抽取,最后利用解码器输出进行分类即可。

Transformer 完整结构如下所示:

编码器基本组件包括:源句子词嵌入模块 Input Embedding、位置编码模块 Positional Encoding、多头自注意力模块 Muti-Head Attention、前向网络模块 Feed Forward 以及必要的 Norm、Dropout 和残差模块。

解码器基本组件类似包括:目标句子词嵌入模块 Output Embedding、位置编码模块 Positional Encoding、带 mask 的自注意力模块 Masked Muti-Head Attention、交叉互注意力模块 Muti-Head Attention、前向网络模块 Feed Forward 、分类头模块 Linear Softmax 以及必要的 Norm、Dropout 和残差模块。

由于本文重点是分析视觉方面的 Transformer,故没有必要对机器翻译过程进行深入解析,读者只需要理解每个模块的作用即可,而且视觉分类 Transformer 任务和 NLP 机器翻译任务不一样,实际上也不需要解码器模块,相比 NLP 任务会简单很多。

1.2.1 编码器基本组件

(1) 源句子词嵌入模块 Input Embedding

机器翻译是句子输入,句子输出,每个句子由单词构成,将句子编码成程序可以理解的向量过程就叫做词嵌入过程,也就是常说的 Word2Vec,对应到图像中称为 Token 化过程即如何将图像转换为更具语义的 Token,Token 概念会在 ViT 中详细描述。

(2) 多头自注意力模块 Muti-Head Attention

在 1.1 小节已经详细说明了注意力计算过程。左边是最简单的 Scaled Dot-Product Attention,单纯看上图你可以发现没有任何可学习参数,那么其存在的意义是啥?实际上可学习参数在 QKV 映射矩阵中,在自注意力模块中会对输入的向量分别乘上可学习映射矩阵 W_Q、W_K 和 W_V,得到真正的 Q、K 和 V 输入,然后再进行 Scaled Dot-Product Attention 计算。

为了增加注意力特征提取的丰富性,不会陷入某种局部特性中,一般会在注意力层基础上(单头注意力层)引入多个投影头,将 QKV 特征维度平均切分为多个部分(一般分成 8 部分),每个部分单独进行自注意力计算,计算结果进行拼接 。在特征维度平均切分,然后单独投影、计算,最后拼接可以迫使提取的注意力特征更加丰富。也就是上面的多头注意力模块 Multi-Head Attention。

(3) 前向网络模块 Feed Forward

前向网络模块主要是目的是对特征进行变换,其单独作用于每个序列(只对最后一个特征维度进行变换)。由于没有结构图,故直接贴相关代码,包括两个 Position-wise FC 层、激活层、Dropout层和 LayerNorm 层。

代码语言:javascript复制


class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # 两个 fc 层,对最后的维度进行变换
        self.w_1 = nn.Linear(d_in, d_hid) # position-wise
        self.w_2 = nn.Linear(d_hid, d_in) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x  = residual

        x = self.layer_norm(x)

        return x

(4) Norm、Dropout 和残差模块

在每个注意力层后面和前向网络后都会接入 Dropout 、残差模块和 Layer Norm 模块。这些必要的措施对整个算法的性能提升非常关键。至于为啥用 Layer Norm 而不是 Batch Norm ,原因是机器翻译任务输入的句子不一定是同样长的,组成 Batch 训练时候会存在大量 Padding 操作,如果在 Batch 这个维度进行 Norm 会出现大量无效统计,导致 Norm 值不稳定,而 Layer Norm 是对每个序列单独计算,不考虑 Batch 影响,这样比较符合不定长序列任务学习任务。当然如果换成图像分类任务,则可以考虑使用 BN 层,后续有算法是直接采用 BN 的。

(5) 位置编码 Positional Encoding

考虑一个分类任务,输入一段句子判断是疑问句还是非疑问句?现在有两条语句分别是:

  • 不准在地铁上吃东西
  • 在地铁上吃东西准不

自注意力层的计算不会考虑字符间的顺序,因为每个字符都是单独和全局向量算相似度,也就是说上面两个句子输入进行注意力计算,输出的向量值是相同的,只不过相对位置有变化。如果我们对输出向量求和后值大于 0 还是小于 0 作为分类依据,那么上面两个句子输出相加值是完全相同的,那就始终无法区分到底是疑问句还是非疑问句,这就是我们常说的 Transformer 具有位置不变性。要解决这个问题,只需要让模型知道输入语句是有先后顺序的,位置编码可以解决这个问题。

加入位置信息的方式非常多,最简单的可以是直接将输入序列中的每个词按照绝对坐标 0,1,2 编码成相同长度的向量,然后和词向量相加即可。作者实际上提出了两种方式:

  • 网络自动学习,直接全 0 初始化向量,然后和词向量相加,通过网络学习来学习位置信息。
  • 自己定义规则,规则自己定,只要能够区分输入词顺序即可,常用的是 sincos 编码。

实际训练选择哪一种位置编码方式发现效果一致,但是不管哪一种位置编码方式都应该充分考虑在测试时候序列不定长问题,可能会出现测试时候非常长的训练没有见过的长度序列,后面会详细说明

1.2.2 解码器基本组件

其大部分组件都和编码器相同,唯一不同的是自注意力模块带有 mask,还额外引入了一个交叉注意力模块以及分类头模块。

(1) 带 mask 的自注意力模块

注意这个模块的输入是目标序列转化为词向量后进行自注意力计算。机器翻译是一个 seq2seq 任务,其真正预测是:最开始输入开始解码 token 代表解码开始,解码出第一个词后,将前面已经解码出的词再次输入到解码器中,按照顺序一个词一个词解码,最后输出解码结束 token,表示翻译结束。

也就是当解码时,在解码当前词的时候实际上不知道下一个词是啥,但在训练时,是将整个目标序列一起输入,然而注意力计算是全局的,每个目标单词都会和整个目标句子计算自注意力,这种训练和测试阶段的不一致性无法直接用于预测。为此我们需要在训练过程中计算当前词自注意力时候手动屏蔽掉后面的词,让模型不知道后面词。具体实现就是输入一个 mask 来覆盖掉后面的词。

由于这种特性是只存在于 NLP 领域,图片中不存在,故不再进行更深入分析。

(2) 交叉注意力

交叉注意力模块和自注意力模块相同,只不过其 QKV 来源不同,Q 来自解码器,KV 来自编码器,交叉注意力模块会利用 Q 来提取编码器提取的特征 KV,然后进行分类。

(3) 分类头

分类头就是普通的线性映射,转换输出维度为分类个数,然后采用 CE Loss 进行训练即可。

1.3 总结

Transformer 结构内部存在多个组件,但是最核心的还是注意力模块,在原始论文中作者也引入了大量的可视化分析来证明注意力模块的作用,有兴趣的建议阅读原文。可能作者自己也没有想到这篇论文会在视觉领域引起另一个全新的风尚,开辟出一条新的看起来前途一片光明的道路。

图片来自 A Survey of Visual Transformers

网址:https://arxiv.org/abs/2111.06091

2. Vision Transformer

论文题目:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

论文地址:https://arxiv.org/abs/2010.11929

ViT 是第一篇成功将 Transformer 引入到视觉领域且成功的尝试,开辟了视觉 Transformer 先河。其结构图如下所示:

其做法非常简单,简要概况为:

  • 将图片分成无重叠的固定大小 Patch (例如 16x16),然后将每个 Patch 拉成一维向量, n 个 Patch 相当于 NLP 中的输入序列长度(假设输入图片是 224x224,每个 patch 大小是 16x16,则 n 是 196),而一维向量长度等价于词向量编码长度(假设图片通道是 3, 则每个序列的向量长度是 768)。
  • 考虑到一维向量维度较大,需要将拉伸后的 Patch 序列经过线性投影 (nn.Linear) 压缩维度,同时也起到特征变换功能,这两个步骤可以称为图片 Token 化过程 (Patch Embedding)。
  • 为了方便后续分类,作者还额外引入一个可学习的 Class Token,该 Token 插入到图片 token 化后所得序列的开始位置。
  • 将上述序列加上可学习的位置编码输入到 N 个串行的 Transformer 编码器中进行全局注意力计算和特征提取,其中内部的多头自注意模块用于进行 Patch 间或者序列间特征提取,而后面的 Feed Forward (Linear GELU Dropout Linear Dropout) 模块对每个 Patch 或者序列进行特征变换。
  • 将最后一个 Transformer 编码器输出序列的第 0 位置( Class Token 位置对应输出)提取出来,然后后面接 MLP 分类后,然后正常分类即可。

可以看出,图片分类无需 Transformer 解码器,且编码器几乎没有做任何改动,针对图像分类任务,只需单独引入一个 Image to Token 操作和 Class Token 的概念即可。

如何理解 Token?个人觉得任何包括图片更加高级的语义向量都可以叫做 Token,这个概念在 NLP 中应用非常广泛,表征离散化后的高级单词语义,在图像中则可以认为是将图像转化为离散的含更高级语义的向量。

ViT 证明纯 Transformer 也可以取得非常好的效果,相比 CNN 在数据量越大的情况下优势更加明显,但是 ViT 也存在如下问题:

  • 不采用超大的 JFT-300M 数据集进行预训练,则效果无法和 CNN 媲美,原因应该是 Transformer 天然的全局注意力计算,没有 CNN 这种 Inductive Bias 能力,需要大数据才能发挥其最大潜力。
  • ViT 无法直接适用于不同尺寸图片输入,因为 Patch 大小是固定的,当图片大小改变,此时序列长度就会改变,位置编码就无法直接适用了,ViT 解决办法是通过插值,这种做法一般会造成性能损失,需要通过 Finetune 模型来解决,有点麻烦。
  • 因为其直筒输出结构,无法直接应用于下游密集任务。

后面的文章对上述缺点采用了各种各样的改进,并提出了越来越先进的处理手段,推动了视觉 Transformer 的巨大进步。

3. 全局概述

由于内容非常多,为了更容易理解,我在拆解模块的基础上对每个模块进行分析,而不是对某篇文章进行概括。综述部分的分析流程按照结构图顺序描述,我将近期图像分类 Vision Transformer 发展按照 ViT 中是否包括自注意层模块来划分,包括:

  1. Attention-based, 这类算法是目前主流研究改进方向,包括了 Transformer 中最核心的自注意力模块。
  2. MLP-based, 这类算法不需要核心的自注意力模块,而是简单的通过 MLP 代替,也可以取得类似效果。
  3. ConvMixer-based,这类算既不需要自注意力模块,也不是单纯依靠 MLP,而是内部混合了部分 Conv 算子来实现类似功能。
  4. General architecture analysis,在这三类算法基础上也有很多学者在探讨整个 Transformer 架构,其站在一个更高的维度分析问题,不局限于是否包括自注意力模块,属于整体性分析。

在上述三个方向中,Attention-based 是目前改进最多,最热门的,也是本综述的核心。本文按照 3 个分类顺序依次分析,最后进行通用架构分析。通过 General architecture analysis 部分可以深化Attention-based、MLP-based 和 ConvMixer-based 三者的联系和区别。本文仅仅涉及 Attention-based 部分。

3.1 Attention-based

Attention-based 表示这类算法必然包括注意力模块,我们将按照广度优先顺序进行一次分析。

继 ViT 后,我们将其发展分成两条线路:训练策略和模型改进。

其中训练策略表示目前主流对 ViT 模型的训练改进方式,而模型改进则是对各个部件进行改进。

  • 训练策略包括两篇论文:DeiT 和 Token Labeling。两者提出的出发点一致,都是为了克服 ViT 需要 JFT-300M 大数据集进行预训练的缺点。DeiT 是通过引入蒸馏学习解决,而 Token Labeling 通过引入显著图然后施加密集监督解决。后续发展中大部分算法都是参考了 DeiT 的训练策略和超参设置,具有非常大的参考价值。
  • 模型改进方面,我将其分成了 6 个组件以及其他方面的改进,6 个组件包括:
  1. Token 模块,即如何将 image 转 token 以及 token 如何传递给下一个模块
  2. 位置编码模块
  3. 注意力模块,这里一般都是自注意力模块
  4. Fead Forward (FFN) 模块
  5. Norm 模块位置
  6. 分类预测头模块

下面按照训练策略和模型改进顺序分析。

3.1.1 训练策略

训练策略解决 ViT 需要大数据先预训练问题以及超参有待优化问题。

3.1.1.1 DeiT

如果说 ViT 开创了 Transformer 在视觉任务上面的先河,那么 DeiT 的出现则解决了 ViT 中最重要的问题:如果不采用超大的 JFT-300M 数据集进行预训练,则效果无法和 CNN 媲美。在单个节点 8 张 V100 且无需额外数据的情况下,用不到 3 天的时间训练所提的 ViT(86M 参数),在 ImageNet 上单尺度测试达到了 83.1% 的 top-1 准确率。

DeiT 核心是引入蒸馏手段加上更强的 Aug 和更优异的超参设置。其蒸馏的核心做法如下所示:

ViT 的 Class Token 是加到图片输入序列的前面,那么蒸馏 Token 可以插到输入序列的后面,当然插入到哪个位置其实无所谓,你也可以插入到 Class Token 后面,经过 Transformer 编码器输出的序列相比 ViT 也会多一个,然后额外的一个输出 Token 经过线性层输出相同类别通道,最后进行蒸馏学习。

对于蒸馏学习来说,做法通常有两个:

  • Soft 蒸馏,即学生模型和教师模型预测的 Softmax 概率分布值计算 KL Loss。
  • Hard 蒸馏,即教师模型预测的 Softmax 概率分布值中,值最大对应的类别作为标签,然后和学生模型预测的 Softmax 概率分布值计算 CE Loss。

蒸馏学习中,通常教师模型会选择一个比学生模型性能更强的且已经提前训练好的模型,教师模型不需要训练,通过蒸馏 loss 将教师模型知识以一种归纳偏置的方式转移给学生模型,从而达到提升学生模型性能的目的。因为引入了额外的蒸馏 Token,而且该 Token 训练任务也是分类,所以实际上 DeiT 在推理时,是将 Class Token 和 Distillation Token 的预测向量求平均,再转换为概率分布。

为了证明 Distillation Token 的有效性,而不是只由于多了一个 Token 或者说多了一个可学习参数导致的,作者还做了对比试验,不加 Distillation Token,而是再加一个 Class Token,相当于有两个分类头,两个 Token 独立且随机初始化,实验发现他们最终收敛后两个分类 Token 的相似度达到 0.999,并且性能更弱,这样证明了加入 Distillation Token 的意义。

通过大量实验,作者总结了如下结论:

  • 蒸馏做法确实有效,且 Hard 蒸馏方式效果会更好,泛化性能也不错。
  • 使用 RegNet 作为教师网络可以取得更好的性能表现,也就是说相比 Transformer,采用卷积类型的教师网络效果会更好。

除了上述蒸馏策略,还需要特别注意 DeiT 引入了非常多的 Aug 并且提供了一套更加优异的超参,这套参数也是后续大部分分类模型直接使用的训练参数,非常值得学习,如下所示:

总而言之, DeiT 算法非常优异,实验也非常多(建议去阅读下),最大贡献是通过蒸馏策略省掉了 ViT 需要 JFT-300M 数据集进行预训练这个步骤,并且提供了一套非常鲁棒且实用的超参配置,深深地影响了后续的大部分图像分类视觉 Transformer 模型。

3.1.1.2 Token Labeling

DeiT 不是唯一一个解决 ViT 需要大数据量问题的算法,典型的还有 Token Labeling,其在 ViT 的 Class Token 监督学习基础上,还对编码器输出的每个序列进行额外监督,相当于将图片分类任务转化为了多个输出 Token 识别问题,并为每个输入 Patch 的预测 Token 分配由算法自动生成的基于特定位置的监督信号,简要图如下所示:

从上图明显可以看出,相比 ViT 额外多了输出 Token 的监督过程,这些监督可以当做中间监督,监督信息是通过 EfficientNet 或者 NFNet ( F6 86.3% Top-1 accuracy) 这类高性能网络对训练图片提前生成的显著图,每个显著图维度是和类别一样长的 C 维,辅助 Loss 和分类一样也是 CE Loss。当然最终实验结果表明性能比 DeiT 更优异,而且由于这种密集监督任务,对于下游密集预测任务泛化性也更好。

在此基础上 DeiT 已经证明通过对 ViT 引入更多的强 Aug 可以提升性能,例如引入 CutMix,但是本文的做法无法直接简单增加 CutMix,为此作者还专门设计了一个 MixToken,大概做法是在 Pathc Embedding 后,对 Token 进行了相应的 CutMix 操作。性能表如下所示:

LV-ViT 即为本文所提模型。相比 DeiT,作者认为本文做法更加优异,体现在:

  • 不需要额外的教师模型,是一个更加廉价的做法。
  • 相比于单向量监督,以密集的形式监督可以帮助训练模型轻松发现目标物体,提高识别准确率,实验也证明了对下游密集预测任务(例如语义分割)更友好。

下表是对训练技术的详细分析:

简而言之,Token Labeling 的核心做法是通过引入额外的显著图来监督每个 patch 输出的预测 token,虽然不需要教师模型,但是依然需要利用更优异的模型对所有训练图片生成显著图。

3.1.2 模型改进

在 DeiT 提出后,后续基于它提出了大量的改进模型,涉及到 ViT 的方方面面。前面说过 ViT 模型主要涉及到的模块包括:

  1. Token 模块,即如何将 image 转 token 以及 token 如何传递给下一个模块
  2. 位置编码模块
  3. 注意力模块,这里一般都是自注意力模块
  4. Fead Forward (FFN) 模块
  5. Norm 模块位置
  6. 分类预测模块
3.1.2.1 Token 模块

Token 模块包括两个部分:

  1. Image to Token 模块即如何将图片转化为 Token,一般来说分成有重叠和无重叠的 Patch Embedding 模块
  2. Token to Token 模块即如何在多个 Transformer 编码器间传递 Token,通常也可以分成固定窗口 Token 化过程和动态窗口 Token 化过程两个

下面是完整结构图:

3.1.2.1.1 Image to Token 模块

首先需要明确:Patch Embedding 通常包括图片窗口切分和线性嵌入两个模块,本小结主要是说图片窗口切分方式,而具体实现不重要,常用的 2 种实现包括 nn.Conv 和 nn.Unfold,只要设置其 kernel 和 stride 值相同,则为非重叠 Patch Embedding,如果 stride 小于 kernel 则为重叠 Patch Embedding。

(1) 非重叠 Patch Embedding

ViT 和目前主流模型例如 PVT 和 Swin Transformer 等都是采用了非重叠 Patch Embedding,其简要代码为:

代码语言:javascript复制

# 非重叠只需要设置Conv kernel_size 和 stride 相同即可
_conv_cfg = dict(
    type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)
_conv_cfg.update(conv_cfg)
self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims)


x = self.projection(x).flatten(2).transpose(1, 2)

通过设置 16x16 的 kernel 和 stride 可以将图片在空间维度进行切割,每个 patch 长度为 16x16x3,然后将每个 Patch 重排拉伸为一维向量后,经过线性层维度变换,输出 shape 为 (B, Num_Seq, Dim)。

在 TNT 中作者提出了一种更精细的非重叠 Patch Embedding 模块,如下图所示:

他的基本观点是自然图像的复杂度相较于自然文本更高,细节和颜色信息更丰富,而 ViT 的非重叠 Patch Embedding 做法过于粗糙,因为后续自注意力计算都是在不同 Patch 间,这会导致 Patch 内部的局部自注意力信息没有充分提取,而这些信息在图像中也是包含了不同的尺度和位置的物体特征,是非常关键的。故我们不仅要考虑 Patch 间自注意力,还要考虑 Patch 内自注意力,为此作者在 外层 Transformer 中又嵌入了一个内层 Transformer,相应的非重叠 Patch Embedding 也分成两步:整图的非重叠 Patch Embedding 过程和 Patch 内部更细粒度的非重叠 Patch Embedding 过程。

通过上图大概可以看出其具体做法,内部相当于有两个 Transformer,第一个 Transformer (Outer Transformer )和 ViT 完全一样,处理句子 Sentences 信息即图片 Patch 级别信息,第二个 Transformer (Inner Transformer,也需要额外加上 Inner Transformer 所需要的位置编码) 处理更细粒度的 Words 信息即图片 Patch 内再切分为 Patch,为了能够将两个分支信息融合,内部会将 Inner Transformer 信息和 Outer Transformer 相加。将上述 Transformer block 嵌入到 PVT 模型中验证了其对下游任务的适用性,通过进一步的可视化分析侧面验证了分类任务上 TNT 相比 DeiT 的优异性。

(2) 重叠 Patch Embedding

在常规的 CNN 网络中一般都是采用重叠计算方式,为此是否采用重叠 Patch Embedding 会得到更好的性能?

直接将非重叠 Patch Embedding 通过修改 Unfold 或者 Conv 参数来实现重叠 Patch Embedding 功能的典型算法包括 T2T-ViT 和 PVTv2,这两个算法的出发点都是非重叠 Patch Embedding 可以加强图片 Patch 之间的连续性,不至于出现信息断层,性能应该会比重叠 Patch Embedding 高。PVTv2 内部采用 Conv 实现,而 T2T ViT 是通过 Unfold 方式实现(论文中称为 soft split)。

前面说过 CNN 网络中一般都是采用重叠计算方式,那么是否可以用 ResNet Stem 替换非重叠 Patch Embedding过程,性能是否会更好?

在 Early Convolutions Help Transformers See Better 论文中,作者进行了深度分析,虽然作者只是简单的将图片 Token 化的 Patch Embedding 替换为 ResNet Conv Stem,但是作者是从优化稳定性角度入手,通过大量的实验验证上述做法的有效性。作者指出 Patch Embedding 之所以不稳定,是因为该模块是用一个大型卷积核以及步长等于卷积核的卷积层来实现的,往往这个卷积核大小为 16*16,这样的卷积核参数量很大,而且随机性很高,从某种程度上造成了 Transformer 的不稳定,如果用多个小的卷积来代替则可以有效缓解。结构如下所示:

考虑了和 ViT 公平对比,新引入的 Conv Stem 计算量约等于一个 transformer block,故后续仅仅需要 L-1 个 transformer block。作者通过大量分析实验得到一些经验性看法:

(a) ViT 这类算法对 lr 和 wd 超参的选择非常敏感,而替换 Stem 后会鲁棒很多。

(b) ViT 这类算法收敛比较慢,而本算法会快很多,例如都在 100 epoch 处本文性能远优于 ViT。

ViT_p 即为 Patch Embedding 模式,ViT_c 即为 Conv Stem 模式,可以看出在不同 flops 下模型收敛速度都是 ViT_c 快于 ViT_p,虽然到 400 epoch 时候性能都非常接近。

(c) ViT 这类算法只能采用 AdamW 训练,而本文更加通用,采用 SGD 后性能没有显著下降。

众所周知,ViT 类模型都只能用 AdamW 训练,其占据显存是 SGD 的 3 倍,所以一般在 CNN 网络中都是用过 SGD 模型,性能通常不错,而通过替换 Patch Embedding 后也可以用 SGD 训练了。

(d) 仅仅采用 ImageNet 训练 ViT 性能难以超越 CNN,而本文可以进一步提升 ViT 性能。

与上述论文持相同观点的也包括 ResT 、Token Learner、CSWin Transformer 等算法,他们都采用了完全相同的做法。更进一步在 PS-ViT 中为了能够方便后续的渐进采样模块稳定提取更好的特征点,作者在 Image to Token 模块中不仅仅引入了 ResNet 的 Conv Stem 模块,还在后面再使用了 ResNet 第一个 stage 的前两个残差 block,在 Token to Token 模块中会详细说明 PS-ViT。

在 CeiT 中作者出发点是 CNN 中的诸多特性已经被证明是很成功的,纯粹的 Transformer 需要大量的数据、额外的监督才能达到和 CNN 相同的精度,出现这种问题的原因可能是 NLP 中的 Transformer 直接搬到图像任务中可能不是最合适的,应该考虑部分引入 CNN 来增强 Transformer。具体来说,在图片转 Token 方案中提出 Image-to-Tokens (I2T) 模块,不再是从图片中直接进行 Patch Emeding ,而是对 CNN 和 Pool 层所提取的底层特征进行 Patch Embedding,借助图像特征会比直接使用图片像素更好。

上图的上半部分是 ViT 的 Patch Embedding 过程,下图是 CeiT 所提出的做法,核心就是引入卷积操作提取底层特征,然后在底层特征上进行 Patch Embedding 操作。

既然采用 Conv Stem 可以解决很多问题,那么理论上经过精心设计的 Conv 结构也必然是有效的,例如 ViTAE 中就采用了空洞卷积做法,本质上是希望能够利用卷积提供多尺度上下文信息,这有助于后续模块信息提取,如下图所示:

对图片或者特征图应用多个不同空洞率的卷积提取信息后,进行拼接和 GeLU 激活函数后,直接拉伸为一维向量,从而转换为序列,并且由于空洞卷积可以实现下采样功能,故也可以有效地减少后续注意力模块计算量。

3.1.2.1.2 Token to Token 模块

大部分模型的 Token to Token 方案和 Image to Token 做法相同,但是也有些算法进行了相应改造。经过整理,我们将其分成两种做法:

  1. 固定窗口 Token 化
  2. 动态窗口 Token 化

固定窗口是指 Token 化过程是固定或者预定义的规则,典型的重叠和非重叠 Patch Embedding 就是固定窗口,因为其窗口划分都是提前订好的规则,不会随着输入图片的不同而不同,而动态窗口是指窗口划分和输入图片语义相关,不同图片不一样,是一个动态过程。

(1) 固定窗口 Token 化

这个做法通常和 Image to Token 模块完全一样,也可以分成非重叠 Patch Embedding 和重叠 Patch Embedding,大部分算法都属于这一类,例如 PVT、Swin Transformer 等。

(2) 动态窗口 Token 化

动态窗口 Token 化过程典型代表是 PS-ViT 和 TokenLearner。

前面说过,Vision Transformer with Progressive Sampling (PS-ViT) 中为了方便后续的渐进采样模块能够稳定提取更好的特征点,在 Image to Token 模块中不仅仅引入了 ResNet 的 Conv Stem 模块,还在后面再使用了 ResNet 第一个 stage 的前两个残差 block。在特征图 F 后,作者在 Token to Token 环节引入了一个渐进式采样模块,其出发点是 ViT 采用固定窗口划分机制,然后对每个窗口进行 Token 化,这种做法首先不够灵活,而且因为图片本身就是密集像素,冗余度非常高,采用固定划分方法对于分类来说可能就某几个窗口内的 Token 实际上才是有意义的,假设物体居中那么物体四周的 Token 可能是没有作用的,只会增加无效计算而已。基于此作者设计一个自适应采样的 Token 机制,不再是固定的窗口采样,而是先初始化固定采样点,如下图红色点所示,然后通过 refine 机制不断调整这些采样点位置,最终得到的采样点所对应的 Token 就是最有代表力的。其完整分类网络结构图如下所示:

得到特征图 F 后,经过渐进采样模块,不断 refine 采样点,最终输出和采样点个数个序列,将该序列作为 ViT 模型的输入即可。简单来看渐进采样模块起到了 Token to Token 作用。其中的渐进采样模块结构图如下所示:

详细计算过程如下:

  1. 首先图片经过 ResNet Conv Stem ResNet 第一个 stage 的前两个残差块进行特征提取,得到 F。
  2. 在特征图或者原图上先设置初始均匀固定间隔采样点 pt,上图是 9 个采样点,表示最终序列长度是 9。
  3. 利用 pt 值对 F 进行采样,提取对应位置的特征向量,加上位置编码输入到编码器中,输出 T_t。
  4. 将 T_t 经过一个 FC 层生成 offset,将该 offset 和初始位置 pt 相加就可以得到 refine 后的 p_t 1。
  5. 将 3-4 步骤重复 N 次,下一个采样模块的输入包括 refine 后的 pt、特征图 F 和上一个采样模块的输出 T,三者相加。
  6. 经过 N 次 refine 后,将该 token 序列拼接上 class token,然后再经过 M 个编码器模块。
  7. 最后对 class token 对应位置输出 token 进行分类训练即可。

可以发现,和 ViT 的主要差异就在于其采样点不是固定的均匀间隔,而是基于语义图自适应,从而能够在减少计算量的前提下进一步提升性能。PS-ViT 在 top-1 精度方面比普通 ViT 高 3.8%,参数减少约 4 倍,FLOP 减少约 10 倍,性能比较优异。

基于类似出发点,TokenLearner 提出可以基于空间注意力自适应地学习出更具有代表性的 token,从而可以将 ViT 的 1024 个 token 缩减到 8-16 个 token,计算量减少了一倍,性能依然可以保持一致,这也侧面说明了 ViT 所采样的固定窗口 token 有大量冗余,徒增计算量而已。其核心示意图如下所示:

假设想仅仅采样出 8 个 token,首先采用 Conv Stem 提取图片特征,然后分别输入到 8 个空间注意力分支中,空间注意力分支首先会应用一系列卷积生成空间 attention 图,然后逐点和输入特征相乘进行特征加权,最后通过空间全局 pool 层生成 1x1xC 的 Token,这样就将 HXWXC 的特征图转换为了 8 个通道为 C 的 Token。

为了进一步提高信息,作者还额外提出一个 TokenFuser 模块,加强 Token 和 Token 之间的联系以及恢复空间结构,整个分类网络的结构如下所示:(a) 为不包括 TokenFuser 的改进 ViT 结构,(b) 为包括 TokenFuser 的改进 ViT 结构。

从上述结构可以发现, TokenLearner 模块起到了自适应提取更具语义信息的 Token,并且还能够极大地减少计算量,而 TokenFuser 可以起到加强 Token 和 Token 之间的联系以及恢复空间结构的功能,TokenLearner Transformer TokenFuser 构成 Bottleneck 结构。其中 TokenFuser 示意图如下所示:

其接收两个输入,一个是 TokenLearner 前的保持了空间信息的 1024 个 token 特征,一个是 TokenLearner 后经过自适应采样的 8 个 token 特征,然后以注意力模式两者进行乘加操作,融合特征以及恢复空间结构。

作者的分类实验依然采用了 JFT-300M 数据集进行预训练,然后在 ImageNet 1k上面微调,也就是说和最原始的 ViT 进行比较。

TokenFuser 也进行了相应的对比实验。

at 6 表示 TokenLearner 插入到第 6 个 Transformer Encoder 后。

3.1.2.2 位置编码模块

位置编码模块是为 Transformer 模块提供 Patch 和 Patch 之间的相对关系,非常关键。在通用任务的 Transformer 模型中认为一个好的位置编码应该要满足以下特性:

  • 不同位置的位置编码向量应该是唯一的
  • 不能因为不同位置位置编码的值大小导致网络学习有倾向性
  • 必须是确定性的
  • 最好能够泛化到任意长度的序列输入

ViT 位置编码模块满足前 3 条特性,但是最后一条不满足,当输入图片改变时候需要微调,比较麻烦。基于此也出现了不少的算法改进,结构图如下所示:

按照是否显式的设置位置编码向量,可以分成:

  1. 显式位置编码,其中可以分成绝对位置编码和相对位置编码。
  2. 隐式位置编码,即不再直接设置绝对和相对位置编码,而是基于图片语义利用模型自动生成能够区分位置信息的编码向量。

隐式位置编码对于图片长度改变场景更加有效,因为其是自适应图片语义而生成。

3.1.2.2.1 显式位置编码

显式位置编码,可以分成绝对位置编码和相对位置编码。

(1) 绝对位置编码

绝对位置编码表示在 Patch 的每个位置都加上一个不同的编码向量,其又可以分成固定位置编码即无需学习直接基于特定规则生成,常用的是 Attention is all you need 中采用的 sincos 编码,这种编码方式可以支持任意长度序列输入。还有一种是可学习绝对位置编码,即初始化设置为全 0 可学习参数,然后加到序列上一起通过网络训练学习,典型的例如 ViT、PVT 等等。

(2) 相对位置编码

相对位置编码考虑为相邻 Patch 位置编码,其实现一般是设置为可学习,例如 Swin Transformer 中采用的可学习相对位置编码,其做法是在 QK 矩阵计完相似度后,引入一个额外的可学习 bias 矩阵,其公式为:

Swin Transformer 这种做法依然无法解决图片尺寸改变时候对相对位置编码插值带来的性能下降的问题。在 Swin Transformer v2 中作者做了相关实验,在直接使用了在 256 * 256 分辨率大小,8 * 8 windows 大小下训练好的 Swin-Transformer 模型权重,载入到不同尺度的大模型下,在不同数据集上进行了测试,性能如下所示 (Parameterized position bias 这行):

每个表格中的两列表示没有 fintune 和有 fintune,可以看出如果直接对相对位置编码插值而不进行 fintune,性能下降比较严重。故在 Swin Transformer v2 中引入了对数空间连续相对位置编码 log-spaced continuous position bias,其主要目的是可以更有效地从低分辨权重迁移到高分辨率下游任务。

相比于直接应用可学习的相对位置编码,v2 中先引入了 Continuous relative position bias (CPB),

B 矩阵来自一个小型的网络,用来预测相对位置,该模块的输入依然是 Patch 间的相对位置,这个小型网络可以是一个 2 层 MLP,然后接中间采用激活函数连接。

其性能如上表的 Linear-Spaced CPB 所示,可以发现相比原先的相对位置编码性能有所提升,但是当模型尺度继续增加,图片尺寸继续扩大后性能依然会下降比较多,原因是预测目标空间是一个线性的空间。当 Windows 尺寸增大的时候,比如当载入的是 8*8 大小 windows 下训练好的模型权重,要在 16*16 大小的 windows 下进行 fine-tune,此时预测相对位置范围就会从 [-7,7] 增大到 [-15,15],整个预测范围的扩大了不少,这可能会出现网络不适应性。为此作者将预测的相对位置坐标从 linear space 改进到 log space 下,这样扩大范围就缩小了不少, 可以提供更加平滑的预测范围,这会增加稳定性,提升泛化能力,性能表如上的 Log-Spaced CPB 所示。

在 Swin Transformer 中相对位置编码矩阵 shape 和 QK矩阵计算后的矩阵一样大,其计算复杂度是 O(HW),当图片很大或者再引入 T 时间轴,那么计算量非常大, 故在 Improved MViT ,作者进行了分解设计,分成 H 轴相对位置编码,W 轴相对位置编码,然后相加,从而将复杂度降低为 O(H W)。

关于绝对位置编码和相对位置编码到底哪个是最好的,目前还没有定论,在不同的论文实验中有不同的结论,暂时来看难分胜负。但是从上面可以分析来看,在 ViT 中不管是绝对位置编码和相对位置编码,当图片大小改变时候都需要对编码向量进行插值,性能都有不同程度的下降 ( Swin Transformer v2 在一定程度上解决了)。

3.1.2.2.2 隐式位置编码

当图片尺寸改变时候,隐式位置编码可以很好地避免显式位置编码需要对对编码向量进行插值的弊端。其核心做法都是基于图片语义自适应生成位置编码

在论文 How much position information do convolutional neural networks encode? 中已经证明 CNN 不仅可以编码位置信息,而且越深的层所包含的位置信息越多,而位置信息是通过 zero-padding 透露的。既然 Conv 自带位置信息,那么可以利用这个特性来隐式的编码位置向量。大部分算法都直接借鉴了这一结论来增强位置编码,典型代表有 CPVT、PVTv2 和 CSwin Transformer 等。

CPVT 指出基于之前 CNN 分类经验,分类网络通常需要平移不变性,但是绝对位置编码会在一定程度打破这个特性,因为每个位置都会加上一个独一无二的位置编码。看起来似乎相对位置编码可以避免这个问题,其天然就可以适应不同长度输入,但是由于相对位置编码在图像分类任务中无法提供任何绝对位置信息(实际上相对位置编码也需要插值),而绝对位置信息被证明非常重要。以 DeiT-Tiny 模型为例,作者通过简单的对比实验让用户直观的感受不同位置编码的效果:

2D PRE 是指 2D 相对位置编码,Top-1@224 表示测试时候采用 224 图片大小,这个尺度和训练保持一致,Top-1@384 表示测试时候采用 384 图片大小,由于图片大小不一致,故需要对位置编码进行插值。从上表可以得出:

  • 位置编码还是很重要,不使用位置编码性能很差。
  • 2D 相对位置编码性能比其他两个差,可学习位置编码和 sin-cos 策略效果非常接近,相对来说可学习绝对位置编码效果更好一些(和其他论文结论不一致)。
  • 在需要对位置编码进行插值时候,性能都有下降。

基于上述描述,作者认为在视觉任务中一个好的位置编码应满足如下条件:

  • 模型应该具有 permutation-variant 和 translation-equivariance 特性,即对位置敏感但同时具有平移不变性。
  • 能够自然地处理变长的图片序列。
  • 能够一定程度上编码绝对位置信息。

基于这三个原则,CPVT 引入了一个带有 zero-padding 的卷积 ( kernel size k ≥ 3) 来隐式地编码位置信息,并提出了 Positional Encoding Generator (PEG) 模块,如下所示:

将输入序列 reshape 成图像空间维度,然后通过一个 kernel size 为 k ≥ 3, (k−1)/2 zero paddings 的 2D 卷积操作,最后再 reshape 成 token 序列。这个 PEG 模块因为引入了卷积,在计算位置编码时候会考虑邻近的 token,当图片尺度改变时候,这个特性可以避免性能下降问题。算法的整体结构图如下所示:

基于 CPVT 的做法,PVTv2 将 zero-padding 卷积思想引入到 FFN 模块中。

通过在常规 FFN 模块中引入 zero-padding 的逐深度卷积来引入隐式的位置编码信息(称为 Convolutional Feed-Forward)。

同样的,在 CSWin Transformer 中作者也引入了 3x3 DW 卷积来增强位置信息,结构图如下所示:

APE 是 ViT 中的绝对位置编码,CPE 是 CPVT 中的条件位置编码,其做法是和输入序列 X 相加,而 RPE 是 Swin Transformer 中所采用的相对位置编码,其是加到 QK 矩阵计算后输出中,而本文所提的 Locally-Enhanced Positional Encoding (LePE),是在自注意力计算完成后额外加上 DW 卷积值,计算量比 RPE 小。

LePE 做法对于下游密集预测任务中图片尺寸变化情况比较友好,性能下降比较少。

除了上述分析的诸多加法隐式位置编码改进, ResT 提出了另一个非常相似的,但是是乘法的改进策略,结构图如下所示:

对 Patch Embedding 后的序列应用先恢复空间结构,然后应用一个 3×3 depth-wise padding 1的卷积来提供位置注意力信息,然后通过 sigmoid 操作变成注意力权重和原始输入相乘。代码如下所示:

代码语言:javascript复制
class PA(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x 是已经恢复了空间结构的 patch embedding 
        return x * self.sigmoid(self.pa_conv(x))     

作者还指出,这个Pixel Attention (PA) 模块可以替换为任意空间注意力模块,性能优异,明显比位置编码更加灵活好用。

3.1.2.3 自注意力模块

Transformer 的最核心模块是自注意力模块,也就是我们常说的多头注意力模块,如下图所示:

注意力机制的最大优势是没有任何先验偏置,只要输入足够的数据就可以利用全局注意力学到泛化性能不错的特征。当数据量足够大的时候,注意力机制是 Transformer 模型的最大优势,但是一旦数据量不够就会变成逆势,后续很多算法改进方向都是希望能够引入部分先验偏置辅助模块,在减少对数据量的依赖情况下加快收敛,并进一步提升性能。同时注意力机制还有一个比较大的缺点:因为其全局注意力计算,当输入高分辨率图时候计算量非常巨大,这也是目前一大改进方向

简单总结,可以将目前自注意力模块分成 2 个大方向:

  1. 仅仅包括全局注意力,例如 ViT、PVT 等。
  2. 引入额外的局部注意力,例如 Swin Transformer。

如果整个 Transformer 模型不含局部注意力模块,那么其主要改进方向就是如何减少空间全局注意力的计算量,而引入额外的局部注意力自然可以很好地解决空间全局注意力计算量过大的问题,但是如果仅仅包括局部注意力,则会导致性能下降严重,因为局部注意力没有考虑窗口间的信息交互,因此引入额外的局部注意力的意思是在引入局部注意力基础上,还需要存在窗口间交互模块,这个模块可以是全局注意力模块,也可以是任何可以实现这个功能的模块。其结构图如下所示:

3.1.2.3.1 仅包括全局注意力

标准的多头注意力就是典型的空间全局注意力模块,当输入图片比较大的时候,会导致序列个数非常多,此时注意力计算就会消耗大量计算量和显存。以常规的 COCO 目标检测下游任务为例,输入图片大小一般是 800x1333,此时 Transformer 中的自注意力模块计算量和内存占用会难以承受。其改进方向可以归纳为两类:减少全局注意力计算量以及采用广义线性注意力计算方式。

(1) 减少全局注意力计算量

全局注意力计算量主要是在 QK 矩阵和 Softmax 后和 V 相乘部分,想减少这部分计算量,那自然可以采用如下策略:

  1. 降低 KV 维度,QK 计算量和 Softmax 后和 V 相乘部分计算量自然会减少。
  2. 减低 QKV 维度,主要如果 Q 长度下降了,那么代表序列输出长度改变了,在减少计算量的同时也实现了下采样功能。

(a) 降低 KV 维度

降低 KV 维度做法的典型代码是 PVT,其设计了空间 Reduction 注意力层 (SRA) ,如下所示:

其做法比较简单,核心就是通过 Spatial Reduction 模块缩减 KV 的输入序列长度,KV 是空间图片转化为 Token 后的序列,可以考虑先还原出空间结构,然后通过卷积缩减维度,再次转化为序列结构,最后再算注意力。假设 QKV shape 是完全相同,其详细计算过程如下:

  • 在暂时不考虑 batch 的情况下,KV 的 shape 是 (H'W', C),既然叫做空间维度缩减,那么肯定是作用在空间维度上,故首先利用 reshape 函数恢复空间维度变成 (H', W', C)。
  • 然后在这个 shape 下应用 kernel_size 和 stride 为指定缩放率例如 8 的二维卷积,实现空间维度缩减,变成 (H/R, W/R, C), R 是缩放倍数。
  • 然后再次反向 reshape 变成 (HW/(R平方), C),此时第一维(序列长度)就缩减了 R 平方倍数。
  • 然后采用标准的多头注意力层进行注意力加权计算,输出维度依然是 (H'W', C)。

而在 Twins 中提出了所谓的 GSA,其实就是 PVT 中的空间缩减模块。

同时基于最新进展,在 PVTV2 算法中甚至可以将 PVTv1 的 Spatial Attention 直接换成无任何学习参数的 Average Pooling 模块,也就是所谓的 Linear SRA,如下所示:

同样参考 PVT 设计,在 P2T 也提出一种改进版本的金字塔 Pool 结构,如下所示:

(b) 即为改进的 Spatial Attention 结构,对 KV 值应用不同大小的 kernel 进行池化操作,最后 cat 拼接到一起,输入到 MHSA 中进行计算,通过控制 pool 的 kernel 就可以改变 KV 的输出序列长度,从而减少计算量,同时金字塔池化结构可以进一步提升性能(不过由于其 FFN 中引入了 DW 卷积,也有一定性能提升)。

从降低 KV 空间维度角度出发,ResT 算法中也提出了一个内存高效的注意力模块 E-MSA,相比 PVT 做法更近一步,不仅仅缩减 KV 空间维度,还同时加强各个 head 之间的信息交互,如下所示:

其出发点有两个:

  • 当序列比较长或者维度比较高的时候,全局注意力计算量过大。
  • 当多头注意力计算时,各个头是按照 D 维度切分,然后独立计算最后拼接输出,各个头之间没有交互,当 X 维度较少时,性能可能不太行

基于上述两点,作者引入 DWConv 缩放 KV 的空间维度来减少全局注意力计算量,然后在 QK 点乘后引入 1x1 Conv 模块进行多头信息交互。其详细做法如下:

  1. 假设输入序列 X Shape 是 nxd,n 表示序列长度,d 表示每个序列的嵌入向量维度。
  2. 假设想将特征图下采样 sxs 倍,可以将 X 输入到 kernel 为 (s 1,s 1),stride 为 (s, s), padding 为 (s//2, s//2) 的 DW 卷积和 LN 层中,假设输出变成 (h'w', d)。
  3. 将其经过线性映射,然后在 d 维度切分成 k 个部分,分别用于 k 个头中。
  4. QK 计算点积和 Scale 后,Shape 变成 (k, n, n'),然后对该输出采用 1x1 卷积在头的 k 维度进行多个 head 之间的信息聚合。
  5. 后续是标准的注意力计算方式。

其核心代码如下所示:

代码语言:javascript复制
class Attention(nn.Module):
    def __init__(self,
                 dim=32,
                 num_heads=8,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=2):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # (sr_ratio 1)x (sr_ratio 1) 的 DW 卷积
        self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio   1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim)
        self.sr_norm = nn.LayerNorm(dim)

        # 1x1 卷积
        self.transform_conv = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, stride=1)
        self.transform_norm = nn.InstanceNorm2d(self.num_heads)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        
        # 1 空间下采样
        x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
        x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
        x_ = self.sr_norm(x_)
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        
        # 2 输出维度为 (B,num_head, N, N')
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # 3 在 num_head 维度进行信息聚合,加强 head 之间的联系
        attn = self.transform_conv(attn)
        attn = attn.softmax(dim=-1)
        attn = self.transform_norm(attn)

        # 4 子注意力模块标准操作
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

(b) 降低 QKV 维度

Multiscale Vision Transformers (MViT) 也考虑引入 Pool 算子来减少全局注意力计算量。MViT 主要为了视频识别而设计,因为视频任务本身就会消耗太多的显存和内存,如果不进行针对性设计,则难以实际应用。正如其名字所言,其主要是想参考 CNN 中的多尺度特性(浅层空间分辨率大通道数少,深层空间分辨率小通道数多)设计出适合 Transformer 的多尺度 ViT。自注意力模块如下所示:

相比 Transformer 自注意力模块,其主要改变是多了 Pool 模块,该模块的主要作用是通过控制其 Stride 参数来缩放输入的序列个数,而序列个数对应的是图片空间尺度 THW。以图像分类任务为例,

  • 任意维度的序列 X 输入,首先和 3 个独立的线性层得到 QKV,维度都是 (HW, D)。
  • 将QKV 恢复空间维度,变成 (H, W, D),然后经过独立的 3 个 Pool 模块,通过控制 stride 参数可以改变输出序列长度,变成 (H', W', D),设置 3 个 Pool 模块不同的 Stride 值可以实现不同大小的输出。
  • 将输入都拉伸为序列格式,然后采用自注意力计算方式输出 (H'W‘, D)。
  • 为了保证输出序列长度改变而无法直接应用残差连接,需要在 X 侧同时引入一个 Pool 模块将序列长度和维度变成一致。

由于 MViT 出色的性能,作者将该思想推广到更多的下游任务中(例如目标检测),提出了改进版本的 Imporved MViT,其重新设计的结构图如下所示:

Imporved MViT 在不同的下游任务提升显著。

(2) 广义线性注意力计算方式

基于 NLP 中 Transformer 进展,我们可以考虑将其引入到 ViT 中,典型的包括 Performer,其可以通过分解获得一个线性时间注意力机制,并重新排列矩阵乘法,以对常规注意力机制的结果进行近似,而不需要显示构建大小呈平方增长的注意力矩阵。在 T2T-ViT 算法中则直接使用了高效的 Performer。

在 NLP 领域类似的近似计算方式也有很多,由于本文主要关注 ViT 方面的改进,故这部分不在展开分析。

3.1.2.3.2 引入额外局部注意力

引入额外局部注意力的典型代表是 Swin Transformer,但是卷积模块工作方式也可以等价为局部注意力计算方式,所以从目前发展来看,主要可以分成 3 个大类:

  • 局部窗口计算模式,例如 Swin Transformer 这种局部窗口内计算。
  • 引入卷积局部归纳偏置增强,这种做法通常是引入或多或少的卷积来明确提供局部注意力功能。
  • 稀疏注意力。

结构图如下所示:

需要特别注意的是:

  1. 引入局部窗口注意力后依然要提供跨窗口信息交互模块,不可能只存在局部注意力模块,因为这样就没有局部窗口间的信息交互,性能会出现不同程度的下降,也不符合 Transformer 设计思想( Patch 内和 Patch 间信息交互)。
  2. 局部窗口计算模式和引入卷积局部归纳偏置增强的划分依据是其核心出发点和作用来划分,而不是从是否包括 Conv 模块来区分。

(1) 局部窗口计算模式

局部注意力的典型算法是 Swin Transformer,其将自注意力计算过程限制在每个提前划分的窗口内部,称为窗口注意力 Window based Self-Attention (W-MSA),相比全局计算自注意力,明显可以减少计算量,但是这种做法没法让不同窗口进行交互,此时就退化成了 CNN,所以作者又提出移位窗口注意力模块 Shifted window based Self-Attention (SW-MSA),示意图如下所示,具体是将窗口进行右下移位,此时窗口数和窗口的空间切分方式就不一样了,然后将 W-MSA 和 SW-MSA 在不同 stage 之间交替使用,即可实现窗口内局部注意力计算和跨窗口的局部注意力计算,同时其要求 stage 个数必须是偶数。

大概流程为:

  1. 假设第 L 层的输入序列 Shape 是 (N, C),而 N 实际上是 (H, W) 拉伸而来。
  2. 将上述序列还原为图片维度即(H, W, C), 假设指定每个窗口大小是 7x7,则可以将上述图片划分为 HW/49 个块,然后对每个块单独进行自注意力计算(具体实现上可以矩阵并行),这样就将整个图像的全局自注意力计算限制在了窗口内部即 W-MSA 过程。
  3. 为了加强窗口间的信息交流,在 L 1 层需要将 W-MSA 换成 SW-MSA,具体是将 L 层的输出序列进行 shift 移位操作,如上图所示,从 4 个窗口就变成了 9 个窗口,此时移位后的窗口包含了原本相邻窗口的元素,有点像窗口重组了,如果在这 9 个窗口内再次计算 W-MSA 其输出就已经包括了 L 层窗口间的交互信息了。

上述只是原理概述,实际上为了保证上述操作非常高效,作者对代码进行了非常多的优化,相对来说是比较复杂的。值得注意的是 Swin Transformer 相比其他算法(例如 PVT )非常高效,因为整个算法中始终不存在全局注意力计算模块( SW-MSA 起到类似功能),当图片分辨率非常高的时候,也依然是线性复杂度的,这是其突出优点。凭借其局部窗口注意力机制,刷新了很多下游任务的 SOTA,影响非常深远。

在 Swin Transformer v2 中探讨了模型尺度和输入图片增大时候,整个架构的适应性和性能。在大模型实验中作者观察到某些 block 或者 head 中的 attention map 会被某些特征主导,产生这个的原因是原始 self-attention 中对于两两特征之间的相似度衡量是用的内积,可能会出现某些特征 pair 内积过大。为了改善这个问题,作者将内积相似度替换为了余弦相似度,因为余弦函数本身的取值范围本身就相当于是被归一化后的结果,可以改善因为些特征 pair 内积过大,主导了 attention 的情况,结构图如下所示:

Swin Transformer 算法在解决图片尺度增加带来的巨大计算量问题上有不错的解决方案,但是 SW-MSA 这个结构被后续诸多文章吐槽,主要包括:

  1. 为了能够高效计算,SW-MSA 实现过于复杂。
  2. SW-MSA 对 CPU 设备不友好,难以部署。
  3. 或许有更简单更优雅的跨窗口交互机制。

基于这三个问题,后续学者提出了大量的针对性改进,可以归纳为两个方向:

  1. 抛弃 SW-MSA,依然需要全局注意力计算模块,意思是不再需要 SW-MSA,跨窗口交互功能由全局注意力计算模块代替,当然这个全局注意力模块是带有减少计算量功能的。
  2. 抛弃 SW-MSA,跨窗口信息交互由特定模块提供,这个特定模块就是改进论文所提出的模块。

(a) 抛弃 SW-MSA,依然需要全局注意力计算模块

Imporved MViT 在改进的 Pool Attention 基础上,参考 Swin Transformer 在不同 stage 间混合局部注意力 W-MSA 和 SW-MSA 设计,提出 HSwin 结构,在 4 个 stage 中的最后三个 stage 的最后一个 block 用全局注意力 Pool Attention 模块(具体间 3.1.2.3.1 小节),其余 stage 的 block 使用 W-MSA ,实验表明这种设计比 Swin Transformer 性能更强,也更简单。

同样 Twins 也借鉴了 W-MSA 做法,只不过由于其位置编码是隐式生成,故不再需要相对位置编码,而且 SW-MSA 这种 Shift 算子不好部署,所以作者的做法是在每个 Encoder 中分别嵌入 Locally-grouped self-attention (LSA) 模块即不带相对位置编码的 W-MSA 以及 GSA 模块GSA 就是 PVT 中使用的带空间缩减的全局自注意力模块,通过 LSA 计算局部窗口内的注意力,然后通过全局自注意力模块 GSA 计算窗口间的注意力,结构图如下所示:

(b) 抛弃 SW-MSA,跨窗口信息交互由特定模块提供

参考 CNN 网络设计思想,可以设计跨窗口信息交互模块,典型的论文包括 MSG-T 、Glance-and-Gaze Transformer 和 Shuffle Transformer。

MSG-Transformer 基于 W-MSA,通过引入一个 MSG Token 来加强局部窗口之间的信息交互即在每个窗口内额外引入一个 MSG Token,该 Token 主要作用就是进行窗口间信息传递,所设计的模块优点包括对 CPU 设备友好,推理速度比 SWin 快,性能也更好一些。结构图如下所示:

  1. 假设将图片或者输入序列划分为 4x4 个窗口,在每个窗口内部再划分为 2x2 个 shuffle 区域。
  2. 在每个窗口中会额外拼接一个可学习的 MSG Token (三角形),故一共需要拼接 2x2 个可学习的 MSG Token。
  3. 将拼接后的所有 token 经过 layer norm、Swin Transformer 中的 W-MSA 和残差连接后,同一个窗口内的 token 会进行注意力计算,从而进行窗口内信息融合。
  4. 单独对 2x2 个 MSG Token 进行 shuffle 操作,交互 2x2 个 token 信息。
  5. 然后对输出再次进行 layer norm、Channel MLP 和残差连接后输出即可。

在第 3 步的 W-MSA 计算中,可以认为同一个窗口内会进行信息流通,从而 2x2 个 MSG Token 都已经融合了对应窗口内的信息,然后经过第 4 步骤 MSG Token 交换后就实现了局部窗口间信息的交互。MSG Token 信息交互模块完成了 Swin Transformer 中 SW-MSA ,相比 SW-MSA 算子,不管是计算复杂度还是实现难度都非常小。Shuffle 计算过程和 ShuffleNet 做法完全一样,如下所示:

将 Swin Transformer 中的 block 全部换成 MSG-Transformer block ,通过实验验证了本结构的优异性。

Shuffle Transformer 也是从效率角度对 Swin Transformer 的 SW-MSA 进行改进,正如其名字,其是通过 Shuffle 操作来加强窗口间信息交流,而不再需要 SW-MSA,由于其做法和 ShuffleNet 一致就不再详细说明,核心思想如下所示 (c):

将 Swin Transformer 中的 SW-MSA 全部换成 Shuffle W-MSA,在此基础上还引入了额外的 NWC 模块,其是一个 DW Conv,其 kernel size 和 window size 一样,用于增强邻近窗口信息交互,Shuffle Transformer 的 block 结构如下所示:

在 ImageNet 数据集上,同等条件上 Shuffle Transformer相比 Swin 有明显提升,在 COCO 数据集上,基于 Mask R-CNN,Shuffle Transformer 和 Swin 性能不相上下。

因为 Swin Transformer 不存在 NWC 模块,作者也进行了相应的对比实验:

这也进一步验证了引入适当的 CNN 局部算子可以在几乎不增加计算量的前提下显著提升性能。

MSG-Transformer 和 Shuffle Transformer 都是通过直接替换 SW-MSA 模块来实现的,Glance-and-Gaze Transformer (GG-Transformer) 则认为没有必要分成两个独立的模块,只需要通过同一个模块的两个分支融合就可以同时实现 W-MSA 和 SW-MSA 功能。结构图如下所示:

其提出一种新的局部窗口注意力计算机制,相比常规的近邻划分窗口,其采用了自适应空洞率窗口划分方式,上图是假设空洞率是 2 即每隔 1 个位置,这样就可以将图片划分为 4 个窗口,由于其采样划分方式会横跨多个像素位置,相比 Swin Transofrmer 划分方式具有更大的感受野,不断 stage 堆叠就可以实现全局感受野。在 Glance 分支中采用 MSA 局部窗口计算方法计算局部注意力,同时为了增强窗口之间的交互,其将 V 值还原为原先划分模式,然后应用 depth-wise conv 来提取局部信息,再通过自适应空洞划分操作的逆操作还原,再加上 Attention 后的特征。

Glance 分支用于在划分窗口内单独计算窗口内的局部注意力,由于其自适应空洞率窗口划分方式,使其能够具备全局注意力提取能力,而 Gaze分支用于在划分的窗口间进行信息融合,具备窗口间局部特征提取能力。将 Swin Transformer 中的 block 全部换成 GG-Transformer block ,通过实验验证了其性能优于 Swin Transformer 。

在改进 Swin Transformer 的窗口注意力计算方式这方面,CSWin Transformer 相比其余改进更加独特,其提出了十字架形状的局部窗口划分方式,如下图所示:

假设一共将图片划分成了 9 个窗口,本文所提注意力的计算只会涉及到上下左右中 5 个窗口,同时为了进一步减少计算量,又分成 horizontal stripes self-attention 和 vertical stripes self-attention,每个自注意力模块都只涉及到其中 3 个窗口,这种计算方式同时兼顾了局部窗口计算和跨窗口计算,一步到位。所谓 horizontal stripes self-attention 是指沿着 H 维度将 Tokens 分成水平条状 windows,假设一共包括 k 个头,则前 k/2 个头用于计算 horizontal stripes self-attention,后面 k/2 个头用于计算 vertical stripes self-attention。两组self-attention是并行的,计算完成后将 Tokens 的特征 concat 在一起,这样就构成了CSW self-attention,最终效果就是在十字形窗口内做 Attention,可以看到 CSW self-attention 的感受野要比常规的 Window Attention 的感受野更大。可以通过控制每个条纹的宽度来控制自注意力模块的复杂度,默认 4 个 stage 的条纹宽度分别设为 1, 2, 7, 7(图片空间维度比较大的时候采用较小的条纹宽度,减少计算量)。

(2) 引入卷积的局部归纳偏置能力

上述都是属于 Swin Transformer 改进,在引入卷积局部归纳偏置增强方面,典型算法为 ViTAE 和 ELSA,ViTAE 结构图如下所示:

其包括两个核心模块:reduction cell (RC) 和 normal cell (NC)。RC 用于对输入图像进行下采样并将其嵌入到具有丰富多尺度上下文的 token 中,而 NC 旨在对 token 序列中的局部性和全局依赖性进行联合建模,可以看到,这两种类型的结构共享一个简单的基本结构。

对于 RC 模块,分成两个分支,第一条分支首先将特征图输入到不同空洞率并行的卷积中,提取多尺度特征的同时也减少分辨率,输出特征图拼接 GeLU 激活,然后输入到注意力模块中,第二条分支是纯粹的 Conv 局部特征提取,用于加强局部归纳偏置,两个分支内容相加,然后输入到 FFN 模块中。

对于 NC 模块,类似分成两个分支,第一条是注意力分支,第二条是 Conv 局部特征提取,用于加强局部归纳偏置,两个分支内容相加,然后输入到 FFN 模块中。

基于上述模块,构建了两个典型网络,如下所示:

至于为何要如此设置以及各个模块的前后位置,作者进行了大量的实验研究:

ELSA: Enhanced Local Self-Attention for Vision Transformer 基于一个现状:Swin Transformer 中所提的局部自注意力(LSA)的性能与卷积不相上下,甚至不如动态过滤器。如果是这样,那么 LSA 的重要性就值得怀疑了。最近也有很多学者发现了个奇怪的现象,例如 Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight 中深入分析了注意力和 Conv 的关系,特别是 DW Conv,但是大部分都没有深入探讨 LSA 性能如此平庸的具体原因,本文从这个方面入手,并提出了针对性改进。

奇怪现象

作者以 Swin Tiny 版本为例,将其中的局部窗口注意力模块 LSA 替换为 DW Conv、decoupled dynamic filter (DDF),从上图可以看出 DWConv 和 DDF 性能都比 LSA 强的,特别是 DW Conv,在参数量和 FLOPs 更小的情况下性能会比 Swin Transformer 高。

原因分析

作者试图从 2 个角度来统一分析 LSA、DWConv 和 DDF,分别是通道数设置和空间处理方式 spatial processing。

通道数设置 channel setting 可以用于统一建模 LSA、DWConv 在通道数上的差异。DWConv 是逐通道(深度)上计算 Conv,也就是说在不同的通道上应用了不同的滤波器参数,而 DDF 和 LSA 是将通道数分组,然后在同一个组内采用相同的滤波器参数,LSA 的通道数分组实际上就是常说的多头机制即DwConv 可以视作一种特殊的多头策略。如果将 DWConv 的通道数设置为 LSA 的头个数,那么通道数设置原则上就是相同效果了。对比结果如下所示:

从上述曲线可以看出:

  • 在相同通道配置下(例如1x、2x,这里是特指4 个 stage 中 head 个数,典型的 1x 是 3, 6, 12, 24),DwConv 版本仍与 LSA 版本具有相似性能。
  • 从上述 1x, 2x 结果可以看出,通常来说 DWConv 的通道数肯定比 LSA 的 Head 个数多,这可能是 DWConv 性能比 LSA 高的原因,但是当设置为相同的 Head 个数时,DWConv 性能甚至比 LSA 类型更好一些。,这说明通道配置并非导致前述现象的主要原因。
  • 当 Head 个数配置大于 1x 时,LSA 性能反而下降,但是 DWConv 性能确实能够提升,这说明两个问题:(1) LSA 中直接提升头数并不能改善通道容量与性能,但是 DWConv 可以;(2) LSA 中需要一种新的策略以进一步提升通道容量和性能

空间处理方式 spatial processing 是指如何得到滤波器或者注意力图并对空域进行信息聚合。DwConv采用静态的滤波器即一旦训练完成,不管输入啥图片都是采用固定 kernel,而其他两者则采用动态滤波器。为了方便统一建模分析,作者采用统一的表示式来说明,如下所示:

Conv 和 DwConv 计算公式为:

当仅仅使用 rbj-i 相对位置偏置,Norm 设置为恒等变换, Θ 表示滑动窗口,计算方式则可以表示 Conv 和 DwConv,j-i 表示相对偏移。

Dynamic filters 计算公式为:

W 是 1x1 卷积,当仅仅使用 qirkj-i ( rkj-i 表示 k 的相对位置嵌入向量 ), Norm 设置为恒等变换则可以表示 Dynamic filters,注意 Dynamic filters 有非常多种做法,以上写的是其中一种。

LSA 计算公式为:

我们可以很容易地将公式 1 退化为 LSA 模式, Ω 表示非重叠局部窗口计算方式。

自从可以将 DwConv 、Dynamic filters 和 LSA 统一起来,并且将其分成三个核心不同部分:参数形式 parameterization, 规范化方式 normalization 和滤波器应用方式 filter application。

参数形式 parameterization

从上表可以看出:

  • 动态滤波器的参数策略要比标准 LSA 策略具有更优的性能(Net2 vs Net1)
  • 动态滤波器变种策略 (Net6) 具有与 SwinT 相当的性能
  • LSA 参数策略与动态滤波器参数策略的组合 (Net7) 可以进一步提升模型性能

规范化方式 Normalization

  • 当采用 Net7 的参数形式组合 Identity 时,模型训练崩溃
  • 相比 FilterNorm,Softmax 规范化具有更优的性能
  • 规范化方式并非 LSA 平庸的原因

滤波器应用方式 filter application

当将滤波器采用滑动窗口形式时,Net6 与 Net7 均得到了显著性能提升 。这意味着:近邻处理方式是空域处理的关键

基于上述实验,使 LSA 变平庸的因素可以分为两个因素:

  • 相对位置嵌入是影响性能的一个关键因素。
  • 另外一个关键因素是滤波器使用方式,即滑动窗口 vs 非重叠窗口。

DwConv 能够与 LSA 性能相媲美的原因在于:它采用了滑动窗口处理机制。当其采用非重叠窗口机制时,性能明显弱于 LSA(见 Table1 中的 Net4)。

动态滤波器性能优于 LSA 的原因在于相对位置嵌入与近邻滤波器使用方式。两者的集成 (Net7) 取得了最佳的性能。

对比非重叠局部窗口与滑动窗口,局部重叠的峰值性能要弱于滑动窗口。局部窗口的一个缺点在于:窗口间缺乏信息交互,限制了其性能;而滑动窗口的缺陷在于低吞吐量。那么,如何避免点乘同时保持高性能就成了新的挑战

基于上述分析,作者提出一种增强的局部自注意力模块 ELSA,性能可以超越了 SwinT 中的 LSA 与动态滤波器。

将该模块替换掉 Swin Transformer 中的前三个 stage 中的 LSA 模块即可。ELSA 的关键模块为Hadamard 注意力与 Ghost 头模块,表达式如下所示:

相比 Dot product,Hadamard product ( Pytorch 中的 A*B )可以有效地提取高阶信息,而 ghost 模块是受启发于 GhostNet,可以在有限的容量下提取更丰富的通道特征。作者在分类、目标检测和语义分割任务上都进行了验证。

ELSA 的解决办法看起来过于复杂,整个分析过程也感觉有点复杂,可能会存在更简单的改进策略。

(3) 稀疏注意力(局部紧密相关和远程稀疏)

和 Performer 一样,NLP 领域内也有很多对自注意力模块的改进中是引入局部注意力的,典型的例如 Sparse Transformers,其出发点是实际场景中或许用不到真正的全局注意力,只要提供稍微等价的全局效果就行,不一定真的理论上达到全局。基于此可以实现局部紧密相关和远程稀疏的 Sparse Transformers,后续改进版本也非常多,例如 LongFormer。同样的,本文对 NLP 领域发展不展开描述。

3.1.2.4 FFN 模块

FFN 模块比较简单,主要是进行特征变换,ViT 中代码如下所示:

代码语言:javascript复制
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

FFN 的改进方向和前面类似,也是希望引入 CNN 局部特征提取能力,加速收敛、提升性能,主要可以归纳为如下所示:

在引入 Conv 增强局部信息特征信息提取方面,也有不少论文进行了尝试。

LocalViT 中在其他模块都不改动情况下,仅仅在 FFN 中引入 1x1 卷积和带 padding 的 3x3 DW 卷积来增强局部特征提取能力,实验结果能够带来不少的性能提升,如下所示:

(a) 是简单的卷积 forward 模块,(b) 是反转残差块,(c) 是本文所提结构,仅仅需要通过替换 FFN 中间的 MLP 层为卷积层。同样采用类似做法的有 PVTv2 和 HRFormer 等等。

CeiT 中也提出了非常类似的结构,其出发点也是希望引入 CNN 来加强局部特征能力,结构如下所示:

主要包括线性投影、恢复空间结构、3x3 DW 卷积、flatten 和线性投影。

在 3.1.2.2 位置编码模块中说过, PVTv2 在 FFN 中引入了零填充的逐深度通道卷积来自动感知位置信息,从而可以省掉位置编码。从 FFN 角度来看,除了有省掉位置编码的好处外, 还能够加强局部特征提取能力,相比 LocalViT 和 CeiT ,直接抛弃位置编码的做法更加彻底,更加合理

3.1.2.5 Norm 位置改动

Norm 通常是 Layer Norm,按照该模型放在自注意力和 FFN 模块的前面还是后面,可以分成 pre norm 和 post norm 方式,如下所示:

绝大部分模型例如 ViT、Swin Transformer 等都是 pre norm 方式,如下所示:

代码语言:javascript复制
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x)   x
            x = ff(x)   x

但是 Swin Transformer v2 却改成了 post norm。Swin Transformer V2 主要探讨如何 scale up 视觉模型,并最终在 4 个数据集上重新刷到了新的 SOTA。将视觉 Transformer 的模型参数调大,作者认为会面临以下几个挑战:

  • 增大视觉模型时候可能会带来很大的训练不稳定性
  • 在很多需要高分辨率的下游任务上,还没有很好地探索出对低分辨率下训练好的模型迁移到更大scale 模型上的方法
  • GPU 内存代价太大

作者发现,将 Swin Transformer 模型从 small size 增大到 large size 后,网络深层的激活值会变得很大,与浅层特征的激活值有很大的 gap,如下图所示,可以看出来随着模型 scale 的增大,这个现象会变得很严重。

作者发现使用 post-norm 操作后,上面所观察到的问题可以得到很明显的改善,并且为了更进一步稳定 largest Swin V2 的训练,在每 6 个 transformer block 后还额外加了一层 layer normalization。也就是说在大模型 Transformer 中使用 post Norm 能够更加稳定训练过程。

3.1.2.6 分类预测头模块

在 ViT 中通过追加额外一个 Class Token,将该 Token 对应的编码器输出输入到 MLP 分类头(实际上是一个线性投影层)进行分类,如下所示:

代码语言:javascript复制



self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
        
# x 是 patch embedding 输出        
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)  

...   
  
# x 是最后一个 Transformer 输出     
self.mlp_head(x[:, 0])     

为何不能和我们常规的图像分类一样,直接聚合所有特征,而需要单独引入一个 Class Token ?这个问题自然有人进一步探索,经过简单总结,可以归纳为如下结构图:

从目前来看主要包括 2 种做法:Class Token 和常规分类一样直接聚合特征。

Class Token 方面的改进典型代码是 CeiT ,作者的出发点是 ViT 仅仅使用最后一个 Transformer 输出做分类,但是实验观察不同的 Transformer 层输出特性不一样,如果能够聚合多个输出层的特征进行分类,也许会更好,具体做法如下所示:

首先收集不同 Transformer 层的 Class Token 对应的输出,然后通过一个 Layer-wise Class-token Attention (LCA) 模块来通过注意力自适应聚合所需要的 token。LCA 可以看出一个简单的 Transformer 模块,其包括多头注意力模块 MSA 和前向网络 FFN,在聚合信息后,取输出序列的最后一个输出进行分类训练。为了节省计算量, MSA 只计算最后一个 Class Token 和其他 Class Token 之间的单向注意力,而没有必要算全局。

随着视觉 Transformer 的快速发展,后人发现其实直接采用 Avg Pool 聚合特征性能也很不错,甚至会更好,典型的代表是 Swin Transformer 和 CPVT-GAP,这种做法无需引入额外的 Class Token,和我们 CNN 中常用的分类套路一致,如下所示:

3.1.2.7 其他

这里说的其他包括两个部分内容,如下所示:

3.1.2.7.1 多尺度特征图输出

在 CNN 时代的各种下游任务(例如目标检测、语义分割)中,多分辨率多尺度特征已经被广泛证明非常重要。不同尺度特征可以提供不同的感受野,适合提取不同物体尺度的特征。然而 ViT 仅仅是为图像分类而设计,无法很好地应用于下游任务,这严重制约了视觉 Transformer 的广泛应用,故迫切需要一种能够类似 ResNet 在不同 stage 输出不同尺度的金字塔特征做法。

ViT 要输出多尺度特征图,最常见做法是 Patch Merging。Patch Mergeing 含义是对不同窗口的 Patch 进行合并,在目前主流的 PVT、Twins、Swin Transformer 和 ResT 中都有广泛的应用,以 PVT 为例详细说明,结构图如下所示:

假设图片大小是 (H, W, 3),暂时不考虑 batch。

  1. 考虑将图片切割为 HW/(4X4) 个块,每个块像素大小是 4x4x3, 此处 stride=4。
  2. 将每个 4x4x3 像素块展开,变成 1 维向量,然后经过线性投影层,输出维度变成 C1,此时特征图 shape 是 (HW/(4X4), C1) 即每个像素块现在变成了长度为 C1 的向量,这两个步骤合并称为 Patch Embedding。
  3. 将上一步输出序列和位置编码相加,输入到编码器中,输出序列长度不变。
  4. 将这个输出序列恢复成空间结构,其 shape 是 (H/4, W/4, C1),此时特征图相比原始图片就下采样了 4x4 倍。
  5. 在下一个 stage 中改变 stride 数目,然后重复 1-4 步骤就又可以缩减对应 sxs 倍,假设设置 4 个 stage 的 stride 为 [4, 2, 2, 2],那么 4 个 stage 输出的 stride 就是 [4, 8, 16, 32],这个就和 ResNet 输出 stride 完全对齐。

CSWin Transformer 则更加彻底,直接将上述 stride 下采样过程替换为 kernel 为 3x3,stride为 2 的卷积,做法和 ResNet 中下采样一致。我们将上述改变 stride 导致序列长度变少从而缩减空间特征图的过程统称为 Patch Merging。

除了上述这种相对朴素的做法,还有一些其他做法。例如 MViT ,其不存在专门的 Patch Merging 模块,而是在注意力模块中同时嵌入下采样功能,如下所示:

只要在每个 stage 中改进 Pool 模块的 stride 就可以控制实现 ResNet 一样的多尺度输出,从而实现多分辨率金字塔特征输出。

ViTAE 则直接引入空洞卷积,然后设置不同的空洞率来实现下采样从而输出不同尺度的特征图。

3.1.2.7.2 训练深层 Transformer

探讨如何训练更深的 Transformer 典型算法是 CaiT 和 DeepViT。前面的诸多 ViT 改进都是在编码层为 6 的基础上进行设计的,是否可以类似 CNN 分类网络设计更深的 Transformer,性能也能出现一致性提升,不至于过早饱和?

(1) CaiT

在 CaiT 算法中,作者从 Transformer 架构和优化关系会相互影响相互作用角度出发进行探讨。在 ResNet 中,作者也说过较深的网络能不能顺利的优化才是其成功的关键,而不是在于特征提取能力多强,对应到 ViT 中依然是首先要考虑如何更容易优化,这会涉及到 Normalize 选择、权重初始化、残差块初始化等模块,不能小看这些模块,一个良好的初始化设计都可能避免深模型的过早饱和问题。

除了上述影响,作者发现 ViT 的 class token 似乎有点不合理,ViT 在一开始就将 class token 和 patch embedding 输出序列拼接,然后输入给自注意力模块,这个自注意力模块需要同时完成两个任务:

  • 引导 attention 过程,帮助得到 attention map。
  • token 最后输入到 classifier 中,完成分类任务。

也就是说,将 class token 和 patch embedding 过早融合,但是两者优化方向不一样,可能存在矛盾。同时说到更顺利的优化,Normalize 的重要性应该是第一个能想到的。基于这两个出发点,作者进行了如下改进:

  • 提出 LayerScale,更合理的 Norm 策略使深层 Transformer 易于收敛,并提高精度。
  • 提出 class-attention layers,class token 和 patch embedding 在最后融合,并且通过 CA 模块来更加高效地将 patch embedding 信息融合到 class embedding 中,从而提升性能。

结构图如下所示:

(a) LayerScale

  • (a) 是 ViT 做法即先进行 Layer Normalization,再进行 Self-attention 或者 FFN,然后结果与 block 输入相加。
  • (b) 是 ReZero、Skipinit 和 Fixup 算法的做法,引入了一个可学习的参数 alpha 作用在 residual block 的输出,并移除了 Layer Normalization 操作,这个可学习参数初始化可以从 0 开始,也可以从 1 开始,实验表示都无法很好地解决深层 Transformer 优化问题。
  • (c) 是一个组合做法,实验发现这种组合做法有一定效果。
  • (d) 是本文所提 LayerScale,效果最好。

LayerScale 的做法是保留 Layer Normalization,并对 Self-attention 或者 FFN 的输出乘上一个对角矩阵,由于其对角矩阵,这个实际上就等效于通道注意力(对不同通道乘上不同的系数),这些系数的设置比较有讲究,在 18 层之前,它们初始化为 0.1,若网络更深,则在 24 层之前初始化为 10^-5,若网络更深,则在之后更深的网络中初始化为 10^-6,这样设计的原因是希望越深的 block 在一开始的时候更接近恒等映射,在训练的过程中逐渐地学习到模块自身所需要的特征,这个理论和 ResNet 中初始化设计非常类似。

(b) Class-Attention Layers

从上述结构图中可以看出, class token 是在后面融合,前面的 Transformer 层专注于 patch embedding 信息提取。CA 模块做法为:

Q 仅仅来自 class token,而 KV 来自 z=[class token, patch embedding],上述做法是常规的交叉注意力计算方式。作者实验发现 CA layer 使用 2 层就够了,前面依旧是 N 个正常的 Transformer Block。如果直接看代码可能会更清晰一些:

代码语言:javascript复制
def forward_features(self, x):
    B = x.shape[0]
    # x 是图片,经过 patch embed 变成序列
    x = self.patch_embed(x)
        
    # n 个 transformer 模块
    x = x   self.pos_embed
    x = self.pos_drop(x)
    
    # SA FFN
    for i , blk in enumerate(self.blocks):
        x = blk(x)
        
    cls_tokens = self.cls_token.expand(B, -1, -1) 
    # CA FFN    
    for i , blk in enumerate(self.blocks_token_only):
        cls_tokens = blk(x,cls_tokens)
        
    x = torch.cat((cls_tokens, x), dim=1)   
    x = self.norm(x)
    return x[:, 0]

上述就是 CaiT 算法主要改进,但是其实要想成功训练出来还需要诸多 trick,作者在论文中也贴了大量对比实验,有兴趣的建议直接阅读原文,本文仅仅描述了最重要的改进而已。更贴心的是,作者还列出了 DeiT-S 到 CaiT-36 的性能提升 trick 表,这个表格同样重要。

(2) DeepViT

CaiT 算法是从整体架构和优化联系角度入手,而 DeepViT 不一样,他通过分析得出深层 Transformer 性能饱和的原因是注意力崩塌,即深层的 Transformer 学到的 attention 非常相似。这意味着随着 ViT 的层次加深,self-attention 模块在生成不同注意力以捕获多样性特征方面变得低效。如下图所示,上面一行是 ViT 训练结果,随着模型深入,自注意力层学到的自注意力权重图趋向均匀且都非常类似,下面一类是作者改进后的可视化结果,根据多样性。

DeepViT 中作者也做了很多前期实验,来验证注意力崩塌现象。那么现在核心问题就是要使得各个注意力层所提取的注意力权重图根据多样性,独特性。

为了解决这个问题,作者提出了 2 种解决办法:

(a) 增加 self-attention 模块的嵌入维度

作者认为,增加嵌入维度,可以增加表征容量从而避免注意力崩塌。实验结果如上所示,随着 Embedding Dimension 的增长,相似的 Block 的数量在下降, Acc 在上升,说明这个策略实际上确实是有效的。但增加 Embedding Dimension 也会显著增加计算成本,带来的性能改进往往会减少,且需要更大的数据量来训练,增加了过拟合的风险,这个也不是啥实质性的解决办法。为此作者提出了第二个改进 Re-attention。

(b) Re-attention

只需要替换 Self-Attention 为 Re-Attention 即可,其结构图如下右图所示:

作者通过实验观察到同一个层的不同 head 之间的相似度比较小,这表明来自同一自注意力层的不同head 关注到了输入 token 的不同方面,基于此作者提出可以把不同的 head 的输出信息融合,然后利用它们再生出一个新的 attention map。具体为引入一个可学习 Linear Transformation,在注意力图生成后乘上 Linear Transformation,从而生成一个新的 attention map,然后进行 Norm 后乘上 V。

因为我们目的是加强同一个自注意力层间不同 head (通常是8,也就是 Linear Transformation 矩阵大小是 8x8) 之间的交流,所以这个可学习 Linear Transformation 是作用在 head 维度,同时 Norm 是 BatchNorm 而不是传统的 LayerNorm。

CaiT 和 DeepViT 都是关注深层 Transformer 出现的过早饱和问题,不过关注的地方不一样,解决办法也完全不同,但是效果类似,这或许说明还有很大的改进空间。

4. 总结

ViT 的核心在于 Attention,但是整个架构也包括多个组件,每个组件都比较关键,有诸多学者对多个组件进行了改进。我们可以简单将 ViT 结构分成 6 个部分:

  1. Token 模块,其中可以分成 Image to Token 模块 和 Token to Token 模块,Image to Token 将图片转化为 Token,通常可以分成非重叠 Patch Embedding 和重叠 Patch Embedding,而 Token to Token 用于各个 Transformer 模块间传递 Token,大部分方案都和 Image to Token 做法一样即 Patch Embedding,后续也有论文提出动态窗口划分方式,本质上是利用了图片级别的语义自动生成最有代表性的采样窗口。
  2. 位置编码模块,其中可以分成显式位置编码和隐式位置编码,显式位置编码表示需要手动设置位置编码,包括绝对位置编码和相对位置编码,而隐式位置编码一般是指的利用网络生成自适应内容的位置编码向量,其提出的主要目的是为了解决显式位置编码中所遇到的当图片尺寸变化时位置编码插值带来的性能下降的问题。
  3. 注意力模块,早期的自注意力模块都是全局注意力,计算量巨大,因此在图片领域会针对性设计减少全局注意力,典型做法是降低 KV 空间维度,但是这种做法没有解决根本问题,因此 Swin Transformer 中提出了局部窗口自注意力层,自注意力计算仅仅在每个窗口内单独计算,不再存在上述问题。
  4. FFN 模块,其改进方向大部分是引入 DW 卷积增强局部特征提取能力,实验也证明了其高效性。
  5. Normalization 模块位置,一般是 pre norm。
  6. 分类预测模块,通常有两种做法,额外引入 Class Token 和采用常规分类做法引入全局池化模块进行信息聚合。

随着研究的不断深入,大家发现 Attention 可能不是最重要的,进而提出了 MLP-based 和 ConvMixer-based 类算法,这些算法都是为了说明自注意力模块可以采用 MLP 或者 Conv 层代替,这说明 Transformer 的成功可能来自整个架构设计。MLP-based 和 ConvMixer-based 部分将会在下一篇文章中进行说明。

0 人点赞