Domain Adaptation介绍
说起Domain Adaptation,首先要从迁移学习说起。迁移学习主要解决的是将一些任务(source domain)上学到的知识迁移到另一些任务(target domain)上,以提升目标任务上的效果。当目标任务有较充足的带标签样本时,迁移学习有多种实现方法。例如,采用Pretrain-Finetune的方式,先在源任务上Pretrain,再在目标任务上用一定量的数据Finetune;或者利用Multi-task Learning的方式,多个任务联合训练。然而,当目标任务没有带标签的数据,或者只有非常少量的带标签样本时,上述两种方法就无法采用了。因此,Domain Adaptation应蕴而生,主要解决目标任务没有数据或数据量非常少无法训练模型的场景。
Domain Adaptation的基础模型结构主要分为feature extractor和classifier两个部分。其中,feature extractor用来从source domain样本或target domain样本上提取特征表示,classifier用于根据feature extractor提取的特征进行具体的分类任务。Domain Adaptation的核心思路为,让feature extractor部分生成的source domain或target domain的特征表示是同分布的,即将source domain和target domain的特征表示对齐。这样后续的classifier就可以使用source domain数据上训练好的模型预测target domain的数据了,无需再用target domain有标签样本进行finetune,解决了target domain无有标签数据的迁移学习问题。
本文介绍了Domain Adaptation的基本原理和近几年来的顶会论文,带大家快速了解Domain Adaptation的SOTA方法。
Domain Adaptation基本方法介绍
正如上文所说,Domain Adaptation的核心思路是训练一个feature extractor,让其生成的source Domain和target Domain的特征分布一致。为什么让source Domain和target Domain的特征分布一致时Domain Adaptation效果最优,是在Analysis of representations for domain adaptation(NIPS 2017)中有理论支持的,两个domain分布差异是target domain预测误差的下界。业内主要有两类解决思路,分别是基于分布距离度量约束的方法和基于对抗学习的方法:
基于分布距离度量约束的方法:通过在模型的优化目标中引入feature extractor对于source domain和target domain生成表示的距离损失函数,达到约束两个domain生成特征分布一致的要求。在Domain Adaptation中常用的分布距离度量包括Maximum Mean Discrepancy、Wasserstein distance等,这些分布距离度量用来衡量两个分布的差异大小。
基于对抗学习的方法:同时训练feature extractor和一个discriminator,discriminator用来判断feature extractor生成的表示来自于哪个domain。经过两个优化任务的对抗训练,最终feature extractor生成的表示让discriminator无法分辨是来自source domain还是target domain,达到了不同domain生成特征同分布的目标。在整个过程中,discriminator替代了基于分布距离度量约束方法中衡量分布差异的距离度量函数。
可以看出,无论是哪种方法,核心都是衡量feature extractor生成的source domain和target domain表示的差距,并在模型优化目标中最小化该差距。目前业内的主流前沿方法集中在基于对抗学习的Domain Adaptation方法。下面我们将介绍多篇近几年顶会中对基于对抗学习的Domain Adaptation的优化方法。
基于对抗学习的Domain Adaptation
第一个基于对抗学习的Domain Adaptation工作是Unsupervised Domain Adaptation by Backpropagation(ICML 2015,UDA),该工作提出的框架至今仍然被各个工作作为基础。DAN模型的主体结构如下图,主要包括Feature Extractor、Domain Classifier(Discriminator)、Classifier三个部分。模型训练过程中,输入一批来自Source Domain或Target Domain的样本,经过Feature Extractor后生成特征表示。一方面,对于有label的Source Domain数据,特征表示会进入Classifier进行分类任务学习,损失函数为交叉熵损失。另一方面,Source Domain和Target Domain的特征表示都会进入Discriminator,Discriminator根据特征表示预测该样本是来自Source Domain还是Target Domain。模型的优化目标由下面两个优化目标组成:
其中D代表Discriminator,D的目标是分辨出Source Domain和Target Domain;G代表Feature Extractor,G的目标是既让图像分类任务预测的好,同时生成的feature能够欺骗Discriminator(最大化Domain分类loss,对抗Discriminator)。整个模型端到端训练,其中使用了gradient reversal layer(GRL),即在反向传播的过程中将Discriminator的梯度加上负号用来更新Feature Extractor的参数,达到对抗学习的目的。
基于对抗学习DA的后续优化
在上文介绍的基于对抗学习DA方法的基础上,学术界针对该方法的不同问题提出了不同的优化方法,主要包括生成任务相关的一致性表示、学习Domain-specific表示辅助Domain-invariant表示两个方面。
生成下游任务相关的对齐表示
基础的对抗学习DA方法在生成Source Domain和Target Domain一致性表示的时候,没有考虑下游的预测任务,只是强制约束两个Domain的特征表示分布一致。这带来的问题是,生成的特征表示可能无法在后续具体的分类任务中取得较好的效果。因此,一些研究针对如何在对齐Source Domain和Target Domain的基础上,生成对下游具体任务有区分性的表示。
Maximum Classifier Discrepancy for Unsupervised Domain Adaptation(CVPR 2018)认为,传统的对抗学习DA方法在对齐两个domain feature时,没有考虑target样本和具体task分类边界之间的关系。如下图所示,现有的方法直接将两个domain特征对齐,而对齐后的特征无法被source domain上训练超平面(模型)有效区分。因为source domain有label而target domain无label,因此target的分类实际上是在复用source domain的分类平面,如果target domain的特征没有考虑到source domain的分类超平面,就可能导致对齐后的feature无法有效使用source domain的分类器进行分类。本文优化的核心思路是,寻找那些用source分类器进行分类效果不好的target样本,让feature generator产生的target样本表示更好的被source分类器分类。
具体的做法示意图如下。首先生成两个不同的source domain分类器,然后让feature generator生成的表示能够尽可能减小两个分类器的预测结果的不一致性(disagreement),即如果一个target样本的表示被两个不同的分类器分类产生的结果不一致的话,这个样本很有可能是无法被source分类器很好区分的,因而feature generator对于这个样本生成的特征表示很有可能是不适用于下游任务的。本文对两个分类器引入Maximize Discrepancy目标,让两个分类器产生的结果尽可能差异大,同时让feature generator生成的表示经过两个分类器后尽可能差异小,这也是一个对抗学习的过程,同时Maximize Discrepancy让两个分类器尽可能产生差异,防止两个分类器最后学的一样而失去了判别样本disagreement的能力。在下图中,黑线表示source domain学到的准确分类器,阴影表示target domain中让两个分类器产生disagreement的样本。在Maximize Discrepancy过程中,两个分类器分歧增大;在Minimize Discrepancy中,feature generator产生的表示能更好的被source domain分类从而减小分歧。
ToAlign: Task-oriented Alignment for Unsupervised Domain Adaptation(NIPS 2021)提出另一种方法来生成便于下游任务分类的Domain对齐方法。在以往的思路中,通过直接将source domain的表示和target domain表示全部对齐。然而,有一些和分类任务不相关的表示,这些表示不被对齐对下游任务没有坏处,同时又能让和任务相关部分的表示更好的对齐。例如,当进行图像中实体分类时,对齐不同domain图像中的实体部分特征表示非常重要,而不同domain图像的背景区域不需要对齐。本文的思路为,识别出哪部分是和下游任务相关的,哪些是无关的,然后只对齐和下游任务相关部分的表示。两种思路的对比如下图所示。
为了识别表示中哪部分是和下游任务相关的,本文采用Grad-CAM方法。Grad-CAM/CAM通过图像分类层最后一层的输出的权重,衡量上一层生成的表示每一个channel的重要性,再对各个channel各个像素点的值加权,得到对于分类最重要的像素点,如下图所示。通过这种方式,可以识别出对于当前分类任务来说,哪些像素点是和任务最相关的。
识别出哪些像素点和任务最相关,接下来就可以在表示对齐的时候,通过像素的任务重要性加权实现有选择性的对齐。具体的实现比较简单,相比baseline模型,Discriminator的输入在生成的表示上乘每个表示(像素点)对于下游任务的权重,来实现对任务相关的额表示进行迁移,对任务不相关的表示不进行迁移。模型结构如下图:
学习Domain-specific表示辅助Domain-invariant表示
基本的domain adaptation方法对齐source domain和target domain的表示,但是没有考虑每个domain独有的信息。如果能将domain-specific的信息提取出来,有助于对domain-invariant部分的学习。Domain Separation Networks(NIPS 2016)提出使用private-share类型的网络结构实现公共部分和私有部分分离的结构。整个网络包含Encoder、Decoder、Classifier三个部分。Encoder包括一个两个Domain公用的Encoder,以及每个Domain特有的Encoder;Decoder用来根据shared encoder和private encoder还原图片,是两个domain公共的;classifier根据source domain的表示进行图像分类。
模型的损失函数主要包括4个部分,除了任务分类损失和decoder的reconstruction损失外,包含difference loss和similarity loss。Different loss用来让private-encoder和shared-encoder编码不同的表示,为了达到这个目的,通过正交损失进行约束(让两个向量正交):
Similarity loss让两个domain生成的表示更相似,采用经典的是用MMD度量两个Domain生成表示的距离,并用GRL方式实现训练。
Heuristic Domain Adaptation(NIPS 2020)中提出,学习domain-specific表示要比domain-invariant表示更容易,因此先学domain-specific表示,再用总的表示减去domain-specific表示,就可以得到domain-invariant表示。基于这个思路,本文提出了一种启发式的domain adaptation框架。在下面的模型结构图中,F(x)表示图像整体的表示,G(x)为通过对抗学习提取到的domain-invariant表示,H(x)为通过启发式方法学习到的domain-specific表示,G(x)=F(x)-H(x)。通过学习H(x),将doman specific部分从F(x)中有效的去除,得到G(x)。要注意的是,H(x)只需要去除掉G(x)中根据对抗学习学到表示中的domain specific部分(通过对抗学习已经能去掉一部分domain specific表示了,这部分表示可以不管),因此H(x)和G(x)应该具有相似的domain specific部分。
中心极限定理表明:对于混合信号,其概率密度比任何一个源信号的概率分布都接近高斯分布;反过来,最大化信号的非高斯性与最大化信号的统计独立性是一致的,ICA就是利用了这一原理进行独立成分分析。为了让F(x)更好的被拆分成两个独立的表示,本文提出引入对F(x)的非高斯性度量作为约束。最终,网络的损失函数由两部分组成,分别是Generator Loss(对抗学习DA的分类损失和domain分类损失)以及heuristic损失:
总结
本文介绍了Domain Adaptation的基本方法,主要介绍了基于对抗学习的Domain Adaptation框架,以及在此框架基础上的优化方法,包括4篇近5年的顶会论文。核心思路是,让feature generator生成的source domain和target domain表示同分布,以此实现source domain上训练的模型可以直接应用在target domain的目标。当我们面临目标域样本缺少有标签数据时,Domain Adaptation是一个有力方法。