标题&作者团队
Paper: https://arxiv.org/abs/2101.07525
Code: https://github.com/zengarden/momentum2-teacher
【导读】本文是Face 的研究员提出的一种自监督方法,它针对现有老师-学生模式自监督方案中BN统计信息更新需要大batch问题,结合老师模型无需梯度传播的机制,提出了一种新颖的MomentumBN。所提方法仅需比较的batch即可取得与原始方案大batch相当的性能;甚至,MomentumBN可以进一步提升BYOL、MoCo等方法的性能。本文值得自监督学习方向的同学研究一番。
Abstract
本文提出一种新颖的Mementum2Teacher方法用于学生-老师模式的自监督学习,所提方法采用momentum方式对网络参数与BN的统计参数进行更新:老师模型的权值参数根据学生模型的权值参数采用momentum方式进行更新,老师模型的BN统计信息则根据自身的历史信息通过momentum方式更新。所提Momentum2Teacher方法极为简单&有效,在batch=128的情形下,所提方法可以在ImageNet的线性评估模式下取得74.5%的精度,且不需要在特定的硬件(如TPU)或者低效的跨GPU操作(比如SyncBN、ShuffleBN)。
本文的主要贡献包含以下几点:
- 提出一种新颖的Momentum2Teacher自监督方法,它具有高效性、硬件友好性(小batch),同时可以取得与大batch相当的性能;
- 所提核心MomentumBN有益于所有学生-老师模式的自监督方法,它可以进一步改善MoCo与BYOL的性能;
- 所提方法在ImageNet(线性评估)取得了74.5%的top-1精度。
Method
在正式介绍本文核心之前,作者做了这样一个实验:统计信息的重要性;然后再引出本文的核心MomentumBN;最后给出Momentum2Teacher方案。
Importance of Stable Statistics
为更好分析学生-老师框架中BN统计信息的重要性,作者采用STL10数据集、BYOL方法作为基线,进行了四组不同BN统计的性能对比,结果见下表。
从上表对比结果我们可以看到以下四点发现:
- SyncBN非常重要。当移除掉SyncBN后,BYOL的性能直接从88.06%下降到了84.16%。这说明了BN中的稳定统计信息的重要性。
- SyncBN会降低训练速度。SyncBN会涉及到跨GPU信息交互,可以看到速度降低高达4倍。
- 同时对学生模型与老师模型添加SyncBN并非必要。仅仅对其中一个添加SyncBN可以分别得到87.80%和87.12%。这说明:老师与学生模型中的BN可以解耦设计。
- 稳定的老师模型非常重要。老师模型中添加SyncBN要比学生模型中添加SyncBN更加(87.80% vs 87.12%)。
Momentum BN
SyncBN是通过简单的大batch获得稳定的统计信息,而在学生-老师框架中,老师模型不会进行梯度传播。如果我们可以利用该特性,那么我们可以采用更小的batch或者稳定的统计信息。
BN操作的两个重要统计信息定义如下:
然后再通过线性变换得到最后的输出:
在学生-老师框架中,上述两个统计信息是通过当前batch的样本进行统计,而
则是学生模型的momentum更新。由于我们不需要对
进行更新,而老师模型可以视作学生模型的时序集成,因此我们提出了如下的BN统计信息的momentum更新:
其中
表示momentum系数。所提过程是通过momentum方式更新BN统计,故而将该过程称之为MomentumBN。
Lazy update 信息泄露是自监督学习的一个主要问题。在BYOL方法中,它通过如下方式统计统计同一个样本两种变换
的损失:
在完成损失
后,
的统计信息将被送入到老师模型中。当计算
损失时,如果我们直接采用MomentumBN,那么会将
的统计信息包含在老师模型中。这会使得学习过程变得更为琐碎,进而影响模型的性能。
为解决上述问题,我们提出了Lazy update机制。首先,分别对
执行MomentumBN,计算方式如下:
然后,通过上述统计信息执行BN的线性变换;最后再通过如下方式进行更新:
前面图示的Table1中也给出了采用MomentumBN替换老师模型中的SyncBN的性能对比,可以看到:MomentumBN取得了更好的性能(88.18 vs 87.8)且无需跨机器通讯;除了可以提升性能外,还可以加速模型训练。
Momentum2Teacher
从前面的图示Table1结果可以看到,在MomentumBN辅助老师模型时,batch=32的学生模型已经可以取得与batch=2048的SyncBN方案同等性能,这说明:相比老师模型,学生模型可以采用更小的batch进行BN统计。因此,本文提出了这样的组合:student with small batch teacher with MomentumBN
。由于所提方案中采用两次momentum机制,故将其称之为Momentum2Teacher,其结构示意图如下。
image-20210123170732137
Implementation
Baseline正如上述所示,我们采用BYOL作为基线,BYOL通过最大化样本X的两种增广
的相似性学习特征表达。
通过学生模型(包含encoder
, MLP
以及预测器
)进行处理,
则通过老师模型(包含encoder
, MLP
以及预测器
)进行处理。老师模型的参数更新方式如下:
Image augmentations 在数据增广方面,我们采用了与BYOL和SIMCLR相同的方式,即随机裁剪、水平镜像、Resize、ColorDistortion、GaussianBlur、GrayScale等。
Architecture 在STL10数据及上,本文采用ResNet18作为encoder,输出特征为512维,MLP的第一个线性层输出为512,第二个输出为128.
Training 64个2080TiGPU用于模拟SyncBN的跨机器通讯。SGD优化器,cosine学习率机制。momentum系数m从
开始并逐渐衰减到0;MomentumBN中的momentum系数
则从
开始并逐渐衰减到0,两者的衰减方式如下:
Experiments
前面在STL10数据集上已经验证了所提方案的有效性。接下来,我们将在ImageNet数据集上进行更进一步的分析论证,默认采用ResNet50。
Effectiveness
首先,我们来看一下MomentumBN在BYOL框架中的有效性,结果见下表。从中可以看到:(1) batch=128时,移除掉SyncBN后,BYOL的性能从72.5下贱管道了61.5;(2) 而MomentumBN则可以将其性能提升到72.9;(3) Momentum2Teacher的训练速度与无SyncBN的BYOL相当。
image-20210123164031356
然后,我们再来看一下MomentumBN在MoCoV2框架中的有效性,结果见下表。可以看到:更稳定的统计同样有益于MoCo,验证了MomentumBN的泛化性能,与此同时,MomentumBN可以加速训练。
image-20210123164109939
Small Batch-Size
接下来,我们再来看一下不同batch对于Momentum2Teacher的影响性。
image-20210123164427947
image-20210123164441108
从上面的图示可以看到:Momentum2Teacher的batch可以非常小。当batch=32时,所提方法可以取得了与batch=512的BYOL相当的性能;而BYOL的性能会随batch的减小而训练下降。
State-of-the-art
再接下来,我们看一下所提方法与其他知名方法(包含MoCo、SimCLR、BYOL等)的性能对比。
Transfer Learning
最后,我们再来看一下所提方案训练的模型在下游任务上的性能对比,结果如下,结果非常赞。
image-20210123164921321
全文到此结束,更多消融实验分析建议查看原文。