学习泛化能力:用于领域泛化的元学习

2021-03-19 14:32:49 浏览数 (1)

作者 | 何文嘉 编辑 | 李仲深

  • Abstract
  • Introduction
  • Methodology
    • MLDG in Supervised Learning
    • MLDG in Reinforcement Learning
  • Analysis of MLDG
  • Alternative Variants of MLDG
    • MLDG-GC
    • MLDG-GN
  • Experiments
    • Experiment I: Illustrative Synthetic Experiment
    • Experiment II: Object Recognition

‍Abstract

域偏移(Domain shift)是指在一个源域中训练的模型在应用于具有不同统计量的目标域时表现不佳的问题。领域泛化(Domain Generalization, DG)技术试图通过产生模型来缓解这一问题,通过设计将模型很好地推广到新的测试领域。提出了一种新的域泛化元学习方法。我们没有像以前的DG工作那样设计一个对域移位具有鲁棒性的特定模型,而是提出了DG的模型不可知论训练过程。我们的算法通过在每个小批中合成虚拟测试域来模拟训练过程中的训练/测试域偏移。元优化目标要求模型改进训练域性能的步骤也应该改进测试域性能。这一元学习过程训练模型具有良好的泛化能力的新领域。我们在最近的跨域图像分类基准上评估了我们的方法和达到的最先进的结果,并在两个经典的增强学习任务上展示了它的潜力。

原文:Li, Da, et al. "Learning to generalize: Meta-learning for domain generalization." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 32. No. 1. 2018.

链接:https://ojs.aaai.org/index.php/AAAI/article/download/11596/11455

Introduction

人类善于在许多不同的条件下解决任务。这在一定程度上是由于快速适应,但也是由于一生都遇到新的任务条件,提供了制定对不同任务环境强有力的战略的机会。如果一个人发现他们现有的策略在新的上下文中失败了,他们不仅会适应,而且还会进一步尝试更新他们的策略,使其更独立于上下文,以便下次他们到达新的上下文时,他们更有可能立即成功。我们希望人工学习代理在不同的条件(域)下解决许多任务,并类似地解决构建对域变化具有鲁棒性的模型的二阶任务,并在新域中表现良好。

标准学习方法在不同的条件下(即对具有不同统计数据的数据)用于培训时往往会崩溃。这被称为域或协变量偏移。解决这一问题的方法可分为领域适应(Domain Adaptation, DA)和领域推广(Domain Generalization, DG)。对DA的研究进展得相对较好,能够实现使用目标域中未标记或稀疏标记的数据,以快速使得在不同源域中训练的模型得以在目标域中适应。研究较少的DG则是用于构建即使在新的目标/测试域中也表现良好的模型。与DA相比,DG模型在训练后没有更新,其关键问题是它在新领域中的工作效果如何。现有的几种DG方法通常都是在多个源域上训练,并提出了提取一些描述已知域公共方面的与领域或模型无关的表征。

过去有关元学习的研究中与我们最相关的是MAML方法。MAML采用了一种元学习方法,通过在一组源任务上训练一个模型来进行少样本学习,该模型距离一个良好的任务特定模型只有几个梯度下降步的差距。这个元优化目标训练模型适合于对新的目标任务的少样本微调。但DG问题是不同的,因为我们要跨域而不是跨任务,而且DG假设为使用零样本适应目标域,而不是像MAML一样在目标问题是有少量的训练样本。

尽管不同的方法工具多种多样,但大多数现有的DG方法都建立在三种主要策略之上:

  • 最简单的方法是为每个源域训练一个模型。当测试域到来时,估计最相关的源域并使用该源域对应的分类器。
  • 第二种方法是假定任何域都由一个底层的全局共享因子和一个特定于域的组件组成。通过在源域训练期间分解特定于域和域无关组件,可以提取与域无关的组件作为一个模型(认为这一部分信息包含了源域与目标域的共同信息),该模型可能在一个新的源域上工作。
  • 最后一种方法,学习一个具备领域不变性的特征表示。如果这个表征可以学习用于最小化目标域与多个源域之间的差距,它应该提供一个独立于域的表示,其在新的目标域上也表现良好。

与这些研究相比,我们的研究MLDG(Meta-Learning Domain Generalization)是第一个通过元学习来解决领域泛化问题的研究结果。MLDG既可以用于监督学习也可以用于强化学习。

Methodology

MLDG in Supervised Learning

算法如Algorithm 1所示:

下面逐步分析算法的主要步骤。定义

l(hat{y}, y)=-hat{y} log (y)

分类任务的交叉熵损失。

Meta-Train阶段。划分

overline{mathcal{S}} = S - V

为用于元训练的域,而

V

是用于元测试的域,

S

是所有域的总和。

y_j^{(i)}

表示领域

N_i

的第

j

个样本。模型的参数为

Theta

,下面公式中的

mathcal{F}

给出了元训练的损失函数。

nabla_{Theta}

Theta

关于

mathcal{F}

的梯度,则参数更新公式为

Theta^{prime}=Theta-alpha nabla_{Theta}

。再回顾Algorithm 1的5-7行就不难理解了,就是普通监督学习算法的常规流程,无需细说。

mathcal{F}(.)=frac{1}{S-V} sum_{i=1}^{S-V} frac{1}{N_{i}} sum_{j=1}^{N_{i}} ell_{Theta}left(hat{y}_{j}^{(i)}, y_{j}^{(i)}right)

Meta-Test阶段。看下面的公式就知道与 Meta-Train 几乎一样,只是使用的数据是划分出来的元测试域。这里对应的是Algorithm 1的的8行。

mathcal{G}(.)=frac{1}{V} sum_{i=1}^{V} frac{1}{N_{i}} sum_{j=1}^{N_{i}} ell_{Theta}left(hat{y}_{j}^{(i)}, y_{j}^{(i)}right)

Summary阶段。其中

α

是元训练的步长,

β

是元训练和元测试之间的权重。使用梯度下降算法求解下式就可以得到Algorithm 1的第9行。

underset{Theta}{operatorname{argmin}} mathcal{F}(Theta) beta mathcal{G}left(Theta-alpha mathcal{F}^{prime}(Theta)right)

Final-Test阶段。当模型在源域上的收敛后,在真正被保留的目标域上部署了最终的模型

Theta

MLDG in Reinforcement Learning

算法如Algorithm 2所示:

通过与Algorithm 1对比,可以看出大同小异,此处不多赘述。值得补充的是在强化学习中,对于DG而言与监督学习有对应关系。监督学习(SL)中的任务对应的是强化学习(RL)中的奖励函数(Reward Function),SL中的域(domain)映射到RL中周围环境不同的相同任务。因此,DG将实现一个具有改进泛化能力的agent,在其操作环境发生变化的情况下,如果允许获得奖励,则对应SL中的监督领域适应(Supervised Domain Adaptation),如果不允许获得奖励则对应SL中的无监督领域适应(Unsupervised Domain Adaptation)。

至此,如果还不能直观地了解整个算法的本质的话无需担心,形象化的理解在后面的 Analysis of MLDG 部分会详细说明。

Analysis of MLDG

为了更好地理解MLDG的原理,作者下面做了数学推导和解释。

MLDG的目标是

underset{Theta}{operatorname{argmin}} mathcal{F}(Theta) beta mathcal{G}left(Theta-alpha mathcal{F}^{prime}(Theta)right)

其中

mathcal{F}

是来自元训练域损失的综合。

mathcal{G}

是来自元测试域损失的综合。

mathcal{F}'(cdot)

是训练损失

mathcal{F}(Theta)

关于

Theta

的梯度,即

nabla_Theta

。这可以通俗理解为:“模型需要优化以使得在更新元训练域之后,模型在元测试域上的性能也很好”。

对于MLDG目标的另一个角度,我们可以对公式中的第二项进行一阶泰勒展开:

mathcal{G}(x)=mathcal{G}(dot{x}) mathcal{G}^{prime}(dot{x}) times(x-dot{x})

其中,

dot{x}

是一个接近

x

的任意点。多变量形式时

x

是一个向量,而

mathcal{G}(x)

是一个标量。我们可以假设

x = Theta-alpha mathcal{F}^{prime}(Theta)

并且选择

dot{x}

作为

Theta

,则有

mathcal{G}left(Theta-alpha mathcal{F}^{prime}(Theta)right)=mathcal{G}(Theta) mathcal{G}^{prime}(Theta) cdotleft(-alpha mathcal{F}^{prime}(Theta)right)

那么目标函数变为:

underset{Theta}{operatorname{argmin}} mathcal{F}(Theta) beta mathcal{G}(Theta)-beta alphaleft(mathcal{G}^{prime}(Theta) cdot mathcal{F}^{prime}(Theta)right)

为了让式

(7)

最小化,这里揭示了:

  1. 根据前两项
mathcal{F}(Theta) beta mathcal{G}(Theta)

需要最小化元训练域损失

mathcal{F}(Theta)

和元测试域的损失

mathcal{G}(Theta)
  1. 根据最后一项
-beta alphaleft(mathcal{G}^{prime}(Theta) cdot mathcal{F}^{prime}(Theta)right)

,需要最大化元训练域损失的梯度与元测试域的损失的梯度的乘积

mathcal{G}^{prime}(Theta) cdot mathcal{F}^{prime}(Theta)

对于第一点很好理解,最小化损失以取得尽可能好的表现,而第二点,可以这么理解,

a cdot b=|a|_{2}|b|_{2} cos (delta)

δ

是向量

a

b

之间的角度,如果向量

a

b

是单位归一化的,

a cdot b

精确地计算余弦相似度。类似的,虽然

mathcal{F}(Theta)

mathcal{G}(Theta)

没有归一化,但如果这两个向量方向相似,点积仍然更大,所以第二点目标其实倾向于使得

mathcal{F}(Theta)

mathcal{G}(Theta)

之间的夹角最大化,即两个向量方向尽可能相似。由于

mathcal{F}(Theta)

mathcal{G}(Theta)

是两组域的损失梯度,因此“相似方向”意味着每组域的优化方向是相似的。因此,总体目标可以看作是:“调整参数,使得元训练和元测试的两个领域的损失最小化,并使它们以协调的方式下降(两者的损失都逐渐减小,而不是一个逐渐减小,另一个逐渐增大而导致的最终呈现整体的减小)”。而

arg min _{Theta} mathcal{F}(Theta) mathcal{G}(Theta)

这样的优化目标是不能体现上述的第二点的。所以,MLDG通过找到一条最小化路径来减少过度拟合到一个域,其中两个子问题在路径的所有点具有一致的梯度方向。

Alternative Variants of MLDG

根据上一部分的数学推导,可以顺其自然地修改原来的优化目标从而得到MLDG的一些变体。

MLDG-GC

(7)

中的第二项归一化,使得第二项的意义是最大化

mathcal{F}(Theta)

mathcal{G}(Theta)

之间的夹角:

underset{Theta}{operatorname{argmin}} mathcal{F}(Theta) beta mathcal{G}(Theta)-beta alpha frac{mathcal{F}^{prime}(Theta) cdot mathcal{G}^{prime}(Theta)}{left|mathcal{F}^{prime}(Theta)right|_{2}left|mathcal{G}^{prime}(Theta)right|_{2}}

MLDG-GN

关于“相似方向”梯度的另一个观点是,一旦元训练域收敛,也不再需要更新元测试域上的参数。在一个好的解决方案中,元测试梯度应该是接近于零。利用这种直觉就可以得到 MLDG-GN,即使得第二项的目标是让测试域上的损失梯度尽可能为0:

underset{Theta}{operatorname{argmin}} mathcal{F}(Theta) betaleft|mathcal{G}^{prime}left(Theta-alpha mathcal{F}^{prime}(Theta)right)right|_{2}^{2}

Experiments

Experiment I: Illustrative Synthetic Experiment

通过从对角线分类器中采样弯曲偏差来合成九个域。我们把其中的八个作为元学习的来源,并为最后的最后一次测试。

(a) 中的9张图代表9个采样得到的域,而 (b) 中是4种模型对比的结果,图中的MLP-ALL代表的是用于训练的聚合所有源域的简单基线。

这些结果表明,MLDG方法有助于避免过度拟合特定的源域,并学习一个更可推广的模型。

Experiment II: Object Recognition

作者使用了PACS多领域识别benchmark,这个数据集共有9991张图像,分为7个类别(“狗”、“大象”、“长颈鹿”、“吉他”、“房子”、“马”和“人”)和4个不同风格描述领域(“照片”、“艺术绘画”、“卡通”和“素描’)。多样化的描绘风格提供了一个显著的领域差距。其目标是在一组域中进行训练,并识别在一个不联合的域中的对象。结果如下图所示:

消融实验:

两种MLDG变体的对比:

关于强化学习的实验由于篇幅优先就不详细描述了,有兴趣的读者可以自行阅读原文。‍


代码

https://github.com/lishuya17/MONN

参考文献

https://www.sciencedirect.com/science/article/pii/S2405471220300818

0 人点赞