作者:Shanghua Gao 等 论文题目:Masked Diffusion Transformer is a Strong Image Synthesizer 来源:ICCV2023 论文链接:https://arxiv.org/abs/2303.14389 内容整理:王怡闻 尽管扩散模型在图像合成方面取得成功,但我们观察到它通常缺乏上下文推理能力,无法学习图像中物体部分之间的关系,导致学习过程较慢。为了解决这个问题,我们提出了一种掩码扩散Transformer(Masked Diffusion Transformer,MDT),明确增强了DPMs在图像中物体语义部分之间上下文关系学习的能力。在训练过程中,MDT在潜在空间上操作,对某些标记进行掩码。然后,设计了一个不对称掩码扩散Transformer,以从未被掩码的标记中预测被掩码的标记,同时保持扩散生成过程。我们的MDT可以从不完整的上下文输入中重建图像的完整信息,从而使其能够学习图像标记之间的相关关系。实验结果显示,MDT实现了出色的图像合成性能,例如在ImageNet数据集上实现了新的最先进的FID分数,并且比先前的最先进方法DiT学习速度快了大约3倍。
引言
在这项工作中,我们首先观察到DPMs通常难以学习图像中物体部分之间的关联关系,导致训练过程缓慢。为了解决这个问题,提出了一种有效的掩码扩散变换器(Masked Diffusion Transformer,MDT),以提高DPMs的训练效率。MDT引入了一个蒙面潜在建模方案,专门为基于Transformer的DPMs设计,以明确增强上下文学习能力并改进图像语义之间的关联关系学习。MDT在潜在空间中进行扩散过程以节省计算成本。它对某些图像标记进行掩码,并设计了一个不对称的掩码扩散变换器(AMDT),以一种扩散生成的方式预测被掩码的标记。MDT可以从其上下文不完整的输入中重建图像的完整信息,学习图像语义之间的关联关系。
通过这种掩码潜在建模方案,我们的MDT可以从其上下文不完整的输入中重建图像的完整信息,学习图像语义之间的关联关系。如下图所示,MDT通常在几乎相同的训练步骤生成了狗的两只眼睛(和两只耳朵),这表明它通过使用掩码潜在建模方案正确地学习了图像的相关语义。相比之下,DiT不能轻松地合成具有正确语义关系的狗。这种比较显示了MDT相对于DiT的优越关系建模和更快的学习能力。
图1
实验结果表明,MDT在图像合成任务上表现出更高的性能,大大改善了训练过程中的效率。它在ImageNet数据集上表现突出,并比最先进的DPMs(即DiT)在训练期间的学习速度快了约3倍。
图2
方法
图3
训练阶段的潜在掩蔽迫使扩散模型从其上下文不完整的输入中重建图像的完整信息。因此,该模型学习到了图像潜在标记之间的关系,特别是图像中语义之间的相关关系。例如,模型首先应该很好地理解狗图像的小部分(标记)之间的正确关联关系。然后,它应该使用其他未被掩蔽的标记作为上下文信息来生成被掩蔽的“眼睛”标记。此外,MDT通常在几乎相同的训练步骤中学习生成图像的相关语义,比如几乎在相同的步骤中生成狗的两只眼睛(两只耳朵)。而DiT(带有变换器骨干的DDPM)首先学会生成一只眼睛(一只耳朵),然后在大约10万次训练步骤后才学会生成另一只眼睛(耳朵)。这表明MDT在学习图像语义的相关关系方面具有卓越的学习能力。
在接下来的部分,我们将介绍MDT的两个关键组件:1) 潜在掩蔽操作,和2) 不对称掩蔽扩散变换器。
潜变量掩码
在潜在扩散模型(Latent diffusion model,LDM)中,MDT采用了在潜在空间而非原始像素空间中执行生成学习的方法,以减少计算成本。在训练过程中,首先向图像的潜在嵌入
添加高斯噪声。然后,按照[31]的方法,我们将带有噪声的嵌入
划分为一系列大小为
的标记,并将它们连接成一个矩阵
,其中
是通道数,
是标记的数量。接下来,我们以比率
随机掩蔽标记,并将剩余的标记连接成
,其中
。因此,我们可以创建一个二进制掩蔽
,其中一个(零)表示被掩蔽的(未被掩蔽的)标记。最后,我们将未被掩蔽的标记
输入到我们的扩散模型进行处理。我们只使用未被掩蔽的标记
,原因如下:
- 模型应该专注于学习语义,而不是预测被掩蔽的标记。如第5.3节所示,这比用可学习的掩蔽标记替换被掩蔽标记,然后处理所有标记可以获得更好的性能;
- 与处理所有
个标记相比,这节省了训练成本。
不对称掩蔽扩散变换器
图4
位置感知编解码器
在MDT中,从未被掩蔽的标记中预测被掩蔽的潜在标记需要考虑所有标记的位置关系。为了增强模型中的位置信息,我们提出了位置感知编码器和解码器,有助于学习被掩蔽的潜在标记。具体而言,编码器和解码器通过添加两种类型的标记位置信息来定制标准的DiT块,分别包含N1和N2个定制块。
首先,编码器将传统的可学习全局位置嵌入添加到噪声潜在嵌入输入中。同样,解码器在输入中也引入了可学习的位置嵌入,但在训练和推理阶段采用不同的方法。在训练期间,边插值器已经使用了下面介绍的可学习全局位置嵌入,它可以将全局位置信息传递给解码器。在推理期间,由于边插值器被丢弃,解码器明确将位置嵌入添加到其输入以增强位置信息。
其次,编码器和解码器在计算自注意力的注意分数时,为每个块中的每个头部添加了本地相对位置偏差:
、
和
分别表示自注意力模块中的查询(query)、键(key)和数值(value)。
是键的维度。
是相对位置偏差,根据位置之间的相对差异
(即
位置与其他位置的差异)选择。可学习的映射
在训练期间会更新。
函数用于将分数转换为权重,用于加权值。
编码器接收未被掩蔽的噪声潜在嵌入,然后在训练和推理中将其输出馈送给边插值器或解码器。对于解码器,其输入可以是用于训练的边插值器的输出,或者用于推理的编码器输出和可学习的位置嵌入的组合。因为在训练期间,编码器和解码器分别处理未被掩蔽的标记和完整的标记,所以这个模型被称为"不对称"模型。这种模型的设计有助于处理掩蔽标记和未掩蔽标记的不同情况,从而提高了模型的性能。
边插值器
在训练期间,为了提高效率和性能,编码器仅处理未被掩蔽的标记
。然而,在推理阶段,由于没有掩蔽,编码器需要处理所有标记
。这意味着在训练和推理期间,至少在标记数量方面,编码器的输出(即解码器输入)存在很大差异。为了确保解码器始终在训练预测或推理生成中处理所有标记,由一个小型网络实现的边插值器的作用是在训练期间从编码器的输出中预测被掩蔽的标记,并在推理期间将其移除。
在训练阶段,编码器处理未被掩蔽的标记,以获取其输出标记嵌入
。然后,如图3所示,边插值器首先使用一个共享的可学习掩蔽标记来填充掩蔽位置,这些位置由掩蔽
指示,还添加了可学习的位置嵌入以获得嵌入
。接下来,我们使用编码器的基本块来处理
以预测一个插值嵌入
。
中的标记表示预测的标记。最后,我们使用一个带有掩蔽的快捷连接来组合预测
和
,得到
。总之,对于被掩蔽的标记,我们使用边插值器的预测;对于未被掩蔽的标记,我们仍然采用
中的相应标记。这可以实现:
- 增强训练和推理阶段之间的一致性,
- 消除解码器中的掩蔽重建过程。
由于在推理期间没有掩蔽,边插值器被位置嵌入操作替代,该操作添加了在训练期间学习的边插值器的可学习位置嵌入。这确保解码器始终处理所有标记,并在训练预测或推理生成中使用相同的可学习位置嵌入,从而提高图像生成性能。
训练
在训练过程中,我们将完整的潜在嵌入
和被掩蔽的潜在嵌入
都馈送到扩散模型中。我们观察到,仅使用被掩蔽的潜在嵌入会使模型过于集中于被掩蔽区域的重建,而忽视了扩散训练。由于不对称的掩蔽结构,使用被掩蔽的潜在嵌入的额外成本是很小的
实验
表1
表2
消融实验
表3
结论
这项工作提出了一种掩蔽扩散变换器,以增强DPMs的上下文表示并改善图像语义之间的关系学习。我们引入了一种有效的掩蔽潜在建模方案到DPMs中,并相应地设计了一个不对称的掩蔽扩散变换器。实验证明,我们的掩蔽扩散变换器在图像合成方面表现出更高的性能,并在训练过程中大大提高了学习进度,实现了在ImageNet数据集上图像合成的新SoTA。