Self-Attention真的是必要的吗?微软&中科大提出Sparse MLP,降低计算量的同时提升性能!

2021-09-27 10:00:07 浏览数 (1)

写在前面

Transformer由于其强大的建模能力,目前在计算机视觉领域占据了重要的地位。在这项工作中,作者探究了Transformer的自注意(Self-Attention)模块是否是其实现图像识别SOTA性能的关键 。为此,作者基于现有的基于MLP的视觉模型,建立了一个无注意力网络sMLPNet。

具体来说,作者将以往工作中用于token混合的MLP模块替换为一个稀疏MLP(sMLP)模块。对于二维图像token,sMLP沿轴向(横向或者纵向)应用一维MLP,参数在行、列维度共享。通过稀疏连接权重共享 ,sMLP模块显著降低了模型参数的数量和计算复杂度,避免了MLP模型的内在问题(如过拟合、参数量大、计算量大)。

当仅在ImageNet-1K数据集上训练时,sMLPNet在只有24M参数下达到81.9%的Top-1精度,比相同模型大小约束下的大多数CNN和视觉Transformer要好得多。当扩展到66M参数时,sMLPNet达到了83.4%的Top-1精度,这与SOTA的 Swin Transformer相当。

1. 论文和代码地址

Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?

代码语言:javascript复制
论文地址:https://arxiv.org/abs/2109.05422
代码地址:未开源
sMLP Block复现代码:https://github.com/xmu-xiaoma666/External-Attention-pytorch#5-sMLP-Usage

2. Motivation

自AlexNet提出以来,卷积神经网络(CNN)一直是计算机视觉的主导范式。随着Vision Transformer的提出,这种情况发生了改变。ViT将一个图像被划分为不重叠的patch,并用线性层将这些patch转换为token,然后输入到Transformer中进行处理。

Transformer编码器由多头自注意网络(Multi-Head Self-Attention)和前馈网络(FFN)组成,来实现空间信息混合和通道信息混合。当在一个非常大的数据上进行预训练时,ViT在图像识别任务上表现得非常好。接着DeiT进一步证明了只在ImageNet-1K上训练时,通过适当的数据增强和正则化技术,无卷积的Vision Transformer也可以实现SOTA的图像识别精度。

目前,无卷积的Vision Transformer主要存在两个核心的思想:首先,全局依赖性建模很重要 。不仅如此,它甚至可以取代卷积操作的局部建模。第二,自注意很重要 。尽管ViT和DeiT表现良好,学术界并没有完全接受这两种观点。

一方面,研究人员挑战了用全局建模代替局部建模的必要性既然局部偏置在自然图像中是有效的,为什么要通过全局自注意模块来学习它,而不是直接将它注入到网络中呢 ?此外,全局自注意对于输入token的数量具有二次计算复杂度。因此,网络结构不有利于高分辨率输入,对金字塔结构并不友好。

基于这一点,Swin Transformer通过限制局部窗口内的自注意操作,将局部偏置注入回网络中。这种设置还控制了计算的复杂度,并允许使用金字塔结构和多阶段处理。Swin Transformer的优越性能表明了局部偏置和多阶段处理的有效性。

另一方面,研究人员也挑战了Self-Attention的必要性 。MLP-Mixer也建模了全局依赖关系,但它采用了一个MLP块,而不是一个自注意模块来实现。MLP-Mixer的整体架构与ViT相似。输入图像被分成多个patch,然后线性层将patch映射到token中。该编码器包含用于空间混合和通道混合的交替层。

唯一的主要区别是,空间混合模块是由一个MLP块实现的。MLP-Mixer继承了ViT的所有缺点,且由于参数数量过多,容易发生过拟合 。因此,MLP-Mixer和SOTA模型的性能还是存在一定差距,尤其是在不预训练的情况下。因此,作者在本文中探究了:在解决了所有的缺陷后,一个无注意力的网络是否有可能实现在图像识别上的SOTA性能?

因此,作者设计了一个无注意力的网络,称为sMLPNet,它只使用卷积和MLP作为构建块。sMLPNet采用了与ViT和MLP-Mixer类似的体系结构,且通道混合模块与他们完全相同。在每个token混合模块中,采用深度卷积来利用局部偏置,并使用改进的MLP来建模全局依赖关系

具体来说,作者提出了具有轴向(即横向和纵向)全局依赖建模特征的稀疏MLP(sMLP)模块(如上图右所示)。sMLP显著降低了计算的复杂度,并允许采用金字塔结构进行多阶段处理。因此,sMLPNet能够在更小的模型上实现与Swin Transformer相同的图像识别性能。

在本文中,作者研究了Transformer的关键组成部分(即Self-Attention)是否是图像理解的真正关键因素 。基于过去视觉模型的设计思想,作者采用了在设计时采用了局部偏置和金字塔结构。此外,作者也采用了全局依赖建模的思想,但使用稀疏MLP模块来实现。

基于以上思想,作者建立了一个名为sMLPNet的无注意力网络,实现了SOTA图像识别性能。本文表明,自注意力可能不是视觉模型设计的核心组成部分。相反,正确使用局部偏置、金字塔结构和对计算复杂度的控制是设计高性能视觉模型的关键

3. 方法

3.1. Design Guidelines

在这项工作中,作者保留了CNN使用的一些重要的设计理念,并添加了受Transformer启发的新组件。设计指南如下:

1.采用类似于ViT、MLP-Mixer和Swin Transformer的架构,以确保一个公平的比较。

2.显式地将局部偏置注入到网络中。

3.探索不使用自注意模块的全局依赖关系。

4.在金字塔结构中进行多阶段处理。

3.2. Overall Architecture

上图展示了本文网络的整体架构。与ViT、MLP-Mixer和Swin Transformer类似,空间分辨率为H×W的输入图像被分割为不重叠的patch。作者在网络中采用了4×4的patch大小,每个patch被reshape成一个48维的向量,然后由一个线性层映射到一个c维embedding,整张图像可以表示为的tensor。

整个网络由四个阶段组成。除第一阶段从线性embedding层开始外,其他阶段从patch合并层开始,将空间维数减少2×2,将通道维数增加2倍。patch合并层由一个线性层实现,它以每个2×2个相邻patch的concat特征作为输入,输出合并后的patch的特征。然后,将新的图像token输入到token混合模块和通道混合模块中。

token混合模块如上图所示。在这个模块中,作者使用了3x3的深度卷积来注入局部偏置。这个操作非常高效,因为它包含很少的参数,并且在推理过程中需要很少的FLOPs。

此外,作者还尝试用sMLP模块来建模全局依赖关系。稀疏性权重共享特性 使得sMLP比原来的MLP模块更不容易过拟合。sMLP的计算复杂度的大大降低,使其能够在第一阶段中以的空间分辨率进行操作。

通道混合模块由FFN实现,与MLP-Mixer中的实现方式完全相同。FFN由两个线性层和一个GeLU激活函数组成。第一线性层将维度从D扩展到αD,第二层将维度从αD缩减回D,其中α是一个可调的超参数。

3.3. Sparse MLP (sMLP)

作者设计了一个稀疏MLP来解决原始MLP的两个主要缺点。首先,减少参数的数量以避免过拟合 ,特别是当网络在中等大小的数据集上进行训练时。其次,降低计算的复杂度 ,特别是当token的数量很大的情况下,以实现金字塔结构中的多阶段处理。

在稀疏MLP中,作者使用稀疏连接权重共享 来实现,如上图所示。sMLP中的每个token只直接与同一行或同一列上的token交互。此外,所有行和所有列可以分别共享相同的投影权重。它由三条路径组成。除了中间所示的直连映射外,另外两条路径分别负责沿水平方向和垂直方向混合token。

设表示输入token的集合。在水平混合路径中,将特征reshape为,并对每一个行应用一个权重为的线性层来混合信息。

在垂直混合路径上也应用了类似的操作,线性层的特征为权重为。最后,将三条路径的输出融合在一起,产生与输入tensor维度相同的输出tensor。作者用FC层来实现这一步:

sMLP结构的pytorch代码如下所示:

如果该模块重复两次,每个token就可以聚合整个二维空间的信息。换句话说,sMLP虽然直接连接稀疏,但却也能有效地获得了全局感受野。

下面来比较一些计算复杂度,本文sMLP的复杂度为:

MLP-Mixer的token混合部分的复杂度为:

可以看出,本文的方法将复杂度控制在了内,而MLP-Mixer为,其中。这使得本文的方法可以处理更大的N,并最终在金字塔结构中实现多阶段处理。

3.4. Model Configurations

基于上面提出的结构,作者提出了三种不同大小的模型,所有通道混合的FFN中的膨胀参数为α=3,不同结构的超参数设置如下:

其中为隐藏层的通道数量。

4.实验

4.1. Ablation Study

Local and global modeling

从上表可以看出,去掉局部建模之后,图像识别精度显著下降至80.6%。这表明,DWConv是一种非常有效的建模局部依赖关系的方法。去掉全局建模之后,图像识别精度显著下降至80.7%,因此局部建模和全局建模在sMLPNet中都是重要的。

接着,作者在不同阶段删除了sMLP,结果如上表所示。所以看出,每个stage的sMLP都是重要的。

Fusion in sMLP

作者尝试了不同的方法来融合sMLP中的特征,可以看出,本文的方法和另外两种轻量级的操作相比具备性能上的优势。

Branches in sMLP

作者研究了三个分支中残差分支的作用,可以看出,加入残差分支能够带来性能上的提升。

Multi-stage processing in pyramid structure

作者还比较了单阶段和多阶段版本的MLP网络的性能,可以看出,多阶段版本可以达到更高的准确率。

4.2. Comparison with state-of-the-art

此外,作者还比较了本文的sMLPNet和SOTA模型的性能、参数和计算量。结果表明,一个无注意的模型可以达到SOTA的性能。

5. 总结

基于 sMLP块,作者构建了一个MLP视觉识别模型sMLPNet。本文提出的sMLP块具有稀疏连接权重共享 的特性,sMLP通过分别沿轴向(即横向、纵向)聚合信息,避免了传统MLP的二次模型大小和二次计算复杂度。实验结果表明,这极大地促进了MLP视觉模型的性能。

当前,基于Transformer的模型已经达到了更高的性能。在性能上,本文的方法确实和目前的SOTA模型还存在一定的差距。不过呢,个人认为,作者提出本文的模型也不是为了追求极致的性能,而是为了展示无注意网络的表现,挑战自注意机制的必要性。

通过一系列的实验,作者也承认了基于MLP的网络结构有其固有的局限性。由于FC层的固定性质,MLP模型难以处理任意分辨率的输入图像。这使得MLP模型很难应用于一些重要的下游任务,如目标检测和语义分割。

▊ 作者简介

研究领域:FightingCV公众号运营者,研究方向为多模态内容理解,专注于解决视觉模态和语言模态相结合的任务,促进Vision-Language模型的实地应用。

知乎/公众号:FightingCV

0 人点赞