机器之心专栏作者:陈小康
来自北京大学、香港大学和百度的研究者近日提出了一种名为CAE的新型 MIM 方法。
掩码建模方法,在 NLP 领域 (例如 BERT) 得到了广泛的应用。随着 ViT 的提出和发展,人们也尝试将掩码图像建模(MIM)应用到视觉领域并取得了一定进展。在此之前,视觉自监督算法主要沿着对比学习(contrastive learning)的思路去设计,而 MIM 无疑打开了新的大门。
来自北京大学、香港大学和百度的研究者近日提出了一种名为CAE的新型 MIM 方法。该方法通过对 “表征学习” 和 “解决前置任务(pretext task)” 这两个功能做完全分离,使得编码器学习到更好的表征,从而在下游任务上实现了更好的泛化性能。
论文地址:https://arxiv.org/abs/2202.03026
该研究回答了如下几个问题:
1.MIM 方法中,网络结构的哪个部分是学习表征的,哪个部分是解决前置任务的?2. 为什么之前典型的对比学习方法,在下游任务 (例如检测、分割) 上只能取得跟监督预训练方法类似的性能?3.MIM 方法为什么优于目前的对比学习方法?
1. 背景
MIM 是一种自监督表征学习算法。它的主要思路是,对输入图像进行分块和随机掩码操作,然后对掩码区域做一些预测。预测的目标可以是 Token ID (BEiT),也可以是 RGB 的值 (MAE)。编码器能够通过 MIM 学得一个好的表征,从而在下游任务上取得良好的泛化性能。
近期 MIM 有两个代表性工作:BEiT 和 MAE。
- BEiT 使用一个编码器做两件事:(1) 学习一个好的图像表征;(2) 解决前置任务:预测掩码 patch 的 Token ID。编码器的潜力并没有完全被挖掘,只有部分被用来学习表征。
- MAE 使用了编码器-解码器架构,编码器负责对可见 patch 进行表征学习,解码器将可见和掩码patch的表征(使用一个可学习的向量)作为输入,预测掩码 patch 的 RGB 值。但是,MAE 在解码器中也会对可见 patch 的表征进行更新,实际上解码器也负责了一部分学习表征的功能。
以上两种方法,都没有充分挖掘编码器的潜力,限制了预训练学习到的表征质量。
2. Context Autoencoder (CAE)
CAE 设计的核心思想是对 “表征学习” 和 “解决前置任务” 这两个功能做分离。研究者希望在预训练时,编码器只负责表征学习,解码器只负责解决前置任务,这样可以尽可能大地挖掘编码器的潜力。CAE 包括 4 个部分:(1) Encoder; (2) Latent contextual regressor; (3) Decoder; (4) Alignment模块。
输入图像通过随机掩码被划分成可见 patch 和掩码 patch 两个部分。具体来说:
- 编码器(Encoder)是一个 ViT 模型,负责学习可见 patch 的表征
。
- Latent contextual regressor 通过
预测掩码 patch 的表征
。Latent contextual regressor 由一系列交叉注意力(cross-attention)模块组成,query 是掩码 patch 的表征,key 和 value 是全部 patch 的表征。在计算 query-key 相似度时,该方法会引入每个 patch 对应的位置编码。在这个阶段,
不断更新、变得更加准确,而
不会更新,对图像特征的提取这个任务完全交给编码器。
- 解码器(Decoder)只拿
和对应的位置编码作为输入,通过
预测掩码 patch 的某些性质,比如 Token ID,或者 RGB 的值。该研究的实验与 BEiT 类似,使用 DALL-E tokenizer 对输入图像 token 化,得到解码器的目标。
- 潜在表征对齐(Latent representation alignment)通过对
添加约束,希望Latent contextual regressor 的输出和编码器的输出在同一编码空间中。该方法将图像的掩码 patch 也输入到编码器,获得这部分的表征
。
将作为
学习的目标。计算
的过程不会计算梯度。
- 损失函数。损失函数由两部分组成:(1) 对解码器预测的监督,使用交叉熵损失; (2) 对
和
的对齐的监督,使用 MSE损失。
3. 分析
3.1 CAE 关注每个 patch 的表征
CAE 基于可见 patch 的表征,从随机采样的掩码 patch 中做一些预测,这要求 CAE 关注每个 patch 的语义。这不同于典型的对比学习方法 (例如 MoCo v3, SimCLR),不是只关注图像的全局语义而忽略图像的细节和非主体区域 (比如背景)。
3.2 Latent contextual regressor 的输出和编码器的输出在同一编码空间中
该研究对 Latent contextual regressor 的输出做了约束,希望它能和编码器的输出尽可能接近。这样,解码器会基于编码器学到的编码空间做预测,将对图像的特征提取的重任完全交到了编码器手上,驱使编码器学习到好的表征。
为了验证这一点,该研究用 RGB 值作为解码器目标 (考虑到 Token ID 难以可视化,这里使用 RGB),训练 CAE。在测试的时候,该研究将全部 patch 输入到编码器,然后跳过 Latent contextual regressor,直接将编码器的输出送进解码器,预测全部 patch 的 RGB 的值。下图展示了预测结果,第一行是原图,第二行是预测,研究者发现仅使用编码器和解码器就可以将图片重建出来,说明编码器的输出和 Latent contextual regressor 的输出属于同一编码空间。
如果训练时不做对齐约束,那么就无法重建,如下图所示,输出都是乱码,说明编码器输出和 Latent contextual regressor 的输出不在一个编码空间中。这使得编码器学到的表征质量有所欠缺,在消融实验部分也有验证。
3.3 CAE 学到的表征可以区分不同类别的对象/stuff
CAE 基于可见 patch 的表征,在掩码 patch 区域做预测,这要求 CAE 对可见 patch 的内容有比较好的理解。举例来说,人们看到一只狗的头部,可以预测出它的身体部分;看到一小片天空,也能预测出它的周围大概率也是一片天空。因此,研究者认为 CAE 学到的表征可以区分不同类别的对象/stuff。为了验证这一点,研究者从 ADE20K 数据集随机采样一些图片输入到编码器。因为 ADE20K 提供了每个像素的类别标签 (150 类),因此该研究可以使用 t-SNE 对编码器输出的表征进行可视化。如下图所示,每个颜色代表一个类别,左图是 CAE,右图是随机初始化的编码器。研究者发现 CAE 可以有效区分不同类别的对象/stuff (因为是在 ImageNet-1K 进行预训练,所以区分得不够完美),而随机初始化的编码器无法做到这一点。
3.4 典型的对比学习为什么在下游任务只能取得跟监督预训练差不多的结果?
在对比学习中,随机剪裁(random crop)是一个非常重要的数据增强策略。典型的 对比学习(比如 MoCo v3)希望最大化来自同一图像的 2 个不同剪裁之间的全局语义相似度,而最小化来自不同图像的剪裁之间的相似度。
这样为什么能奏效呢?研究者首先分析了随机剪裁的性质。在 SimCLR 论文中提到,随机剪裁是对比学习方法中非常重要的数据增强策略。在 ImageNet-1K 数据集中,图像的主体对象大多处于图像的中心区域,而对图像进行随机剪裁,中心区域有很大的概率被囊括进去,例如下图展示的几个例子,几次剪裁基本都包括了图像的主体对象。
对同一图像的不同剪裁提取全局语义,实际上学到的是原始图像中主体对象的特征,正因如此,同一图像的不同剪裁之间才可能相似。在监督预训练中,受到图像分类标签的约束,网络学习到的也是图像主体区域的特征,这和对比学习学到的知识有很大的相似之处,因此在下游任务表现类似。
3.5 MIM 和对比学习的区别
MIM 方法 (例如 CAE) 基于可见 patch 的表征,对掩码 patch 区域做预测。在做随机掩码时,图像的每个 patch (例如背景区域的对象/stuff) 都有可能被考虑到,而不仅仅是图像的主体区域。为了做好掩码 patch 的预测,CAE 会学好每个 patch 的表征。
该研究对 CAE 以及 MoCo v3 的注意力图做了可视化。如下图所示,第一行是原图,第二行是 MoCo v3,第三行是 CAE。红色表示注意力值更高,蓝色表示注意力值低。处于蓝色边界内部的区域,通过这样的原则筛选:将注意力值从大到小排序后,保留累计和达到所有位置注意力值总和的 50% 的部分。可以看到,MoCo v3 的注意力图主要在图像的主体区域有高响应,而 CAE 能考虑到几乎所有 patch。
4. 实验
该研究使用 ViT-small 和 ViT-base 在 ImageNet-1K 上进行实验。输入图像的分辨率是 224 X 224,patch 大小是 16 X 16,一张图会被划分成 14 X 14 个 patch。每次有 75 个 patch 被随机掩码。
4.1 预训练评估
自监督学习广泛使用线性探测(linear probing)去评测预训练表征的好坏:将编码器的参数固定住,在之后加一个线性分类器进行图像分类。研究者认为线性探测不适合 MIM 方法,因为 MIM 方法通常会学到每个 patch 的表征,不仅包含主体对象的信息,还学到了背景等若干知识,这是多而杂的,不适合直接进行线性分类。因此,研究者提出了一种新的测试指标:注意力探测(attentive probing)。该研究在固定参数的编码器后加上一个简单的交叉注意力模块(没有 FFN)和一个线性分类器,通过注意力机制动态地选择适合做图像分类的信息。
该研究对注意力探测阶段使用的交叉注意力模块做注意力图可视化,发现可以关注到主体对象。
微调、线性探测、注意力探测的结果见下表。
研究者发现一些有趣的现象。(1) 对比学习方法 (MoCo v3, DINO) 的线性探测和注意力探测结果类似。这说明这类方法在预训练时已经将注意力放到了图像的主体对象上面,无需进一步动态筛选即可做好图像分类,这也与之前研究者对对比学习的分析一致。(2) MIM 方法 (例如 CAE) 的注意力探测相比线性探测有很大的提升。这说明 MIM 方法学到了每个 patch 的特征,而不仅仅是图像主体对象的,因此需要做一些筛选才利于图像分类。
4.2 消融实验
该研究对解码器和对齐模块进行消融实验,见下表。单加一个解码器能改进注意力探测的结果,但在下游任务 (分割、检测) 上的提升不明显。使用对齐模块之后能显著提升下游任务的性能,说明约束编码器的输出和 Latent contextual regressor 的输出在同一编码空间非常重要,能提升编码器学到的表征质量。
4.3 语义分割
该研究在 ADE20K 上进行语义分割的实验。网络使用 UperNet,迭代次数为 160K,输入图像分辨率为 512 X 512,使用单尺度测试。对比学习方法和监督预训练方法(DeiT)的结果类似,而 CAE 能取得明显更好的结果。跟其他 MIM 方法相比,CAE 的结果也更好,说明预训练阶段编码器被充分利用,学到的表征更好。
4.4 目标检测、实例分割该研究使用 Mask-RCNN 和 Cascade-RCNN 两种网络结构进行目标检测和实例分割的实验。其中,使用多尺度训练 12 epoch,测试阶段仅使用单尺度测试。实验结果和语义分割类似:对比学习方法和监督预训练方法的结果类似且较差,CAE 的结果更好。
5 总结
该研究提出了 CAE,设计的核心有两点:(1) 对 “表征学习” 和 “解决前置任务” 这两个功能做完全分离; (2) 在可见 patch 学习到的表征空间中对掩码 patch 做预测。以上两点都是为了驱使编码器学习更好的表征,从而在下游任务取得良好的泛化能力。
此外,该研究对监督预训练方法、对比学习和 MIM 方法进行了分析,认为对比学习和监督预训练主要关注图像的主体区域 (例如 ImageNet-1K 标签集中的对象),而 MIM 会关注图像的全部 patch,更有利于下游任务。
© THE END
转载请联系本公众号获得授权
投稿或寻求报道:content@jiqizhixin.com