MORA:LORA引导缺失模态多模态疾病诊断 !

2024-09-10 20:54:01 浏览数 (3)

多模态预训练模型能有效地从不同模态中提取和融合特征,其内存需求低,便于微调。然而,尽管具有高效性,但这些模型在疾病诊断中的应用尚未充分探索。 一个重要的挑战是,缺失的模态频繁出现,这会损害性能。此外,整个预训练模型的微调需要大量的计算资源。 为了解决这些问题,作者提出了一个计算高效的方法--即模态感知低秩自适应(MoRA) 。MoRA将每个输入映射到低内在维度,但在出现缺失模态的情况下,使用不同的模态感知上投影进行模态特定适应。 实际上,MoRA集成到模型中的第一块,在一种模态缺失的情况下可以显著提高性能。与训练整个模型相比,它只需要小于1.6%的可训练参数。 实验结果显示,MoRA在疾病诊断方面优于现有技术,表明具有优越的性能、鲁棒性和训练效率。 代码链接:https://github.com/zhiyiscs/MoRA。

1 Introduction

多模态预训练模型在通用的计算机视觉任务,包括分类和回归领域取得了巨大的成功[1, 2, 8]。在广泛的多样数据集上的预训练,使得多模态预训练模型能够理解不同模态(如图像、文本、音频和视频)之间的复杂模式和关系。此外,预先存在的知识减少了在采用这些模型作为下游任务时大量特定数据的需要。

近年来,研究行人通过在大型医学数据集上训练多模态模型,将预训练模型引入医疗领域[3, 4, 13]。然而,将这些模型应用于实际临床环境中的疾病诊断存在两个主要挑战。首先,实际疾病诊断中的情况相当常见,即患者的X光片(例如胸X光片)图像完整,但是相应的标注有缺失。然而,实验证明,在缺少模态的情况下,多模态预训练模型的性能急剧下降[8]。其次,大多数预训练模型基于大规模的 Transformer 模型,因此整个预训练模型的微调仍然非常昂贵。

大部分相关的先驱研究集中改进了模型结构,但这种方法无法直接应用于预训练模型微调。许多研究[10, 5]也采用插值,即根据其他完整的模态填补缺失的模态,输入。但是,当模态数量相对较少(例如两个或三个模态)时,插值非常不健壮,可能会导致结果恶化。对于微调多模态预训练模型,Lee等人[6]首先引入了多模态提示的概念,它使用MAPs(即在使用缺失模态时提高性能的提示)来提高训练和测试集中缺失模态时的性能。然而,MAPs在不同训练和测试中的缺失模态设置之间缺乏稳健性。在MAPs之上,Jiang等人[7]提出了特定的模态提示(MSPs),相对于MAPs,它们对不同的缺失设置更为稳健。然而,MSPs仍需要插到多个层才能达到最佳性能。

受到低秩自适应(LoRA)[9]的启发,作者提出模态感知低秩自适应(MoRA),以提高在面对数据集训练和测试集中缺失模态时的性能和健壮性。具体而言,MoRA通过将每个输入映射到低内生维数,同时使用不同的模态感知上项目来获得模态感知的自适应。这些自适应能够识别每个模态的独特特征,从而在某些模态缺失的情况下增强模型的健壮性和性能。与现有的微调方法相比,MoRA的一个关键优势在于其实现效率。它只需要集成到模型初始模块中,就可以在处理缺失模态时导致显著的增强。在微调过程中,所有需要训练的参数仅为MoRA和分类器。这样,作者的方法将训练参数限制为总模型参数的1.6%,允许模型在微调较小数据集(数千样本)时达到更好的性能。作者的实验结果证明了MoRA的有效性,该方法在不仅比现有的方法具有更高的准确性和健壮性,而且提高了训练效率。

作者的主要贡献是:

  • 作者将多模态预训练模型引入疾病诊断,并提出了MoRA来改善在训练和测试集中数据缺失时的性能和健壮性。
  • 与采用缺失模态的其他微调方法相比,作者的方法实现了最先进的性能和健壮性。
  • 作者进行了包含不同缺失率的模态的全面实验,以证明具有不同模态缺失比例时,作者的方法具有优越的性能和健壮性。

2 Method

Problem Definition

为了简化,作者将以两个模态表示的疾病诊断作为一种情形来解释作者的方法。例如,可以表示为和(例如,图像和文本)。作者将这个数据集表示为。在这里,表示同时存在两种模态的部分,称为模态完整子集。相反,和表示模态不完整的子集,如仅图像或仅文本的患者,其中一种模态是缺失的。如图1所示,数据集包括完整患者(表示为),仅文本患者(表示为)和仅图像患者(表示为)。

为了保留多模态输入的格式以便在多模态预训练模型中进行多模态,作者只是将空字符串或像素(例如,对于文本或图像)分配给缺失模态的病人,并生成,。因此,整个患者数据集可以被改革为。

Modality-Aware Low-Rank Adaptation

低秩自适应在大语言模型微调中得到了广泛应用。其主要机制是将预训练模型的权重冻结,并将可训练的秩分解矩阵注入到预训练模型中。LoRA理论认为,在适应时,权重更新的内在秩较低。对于预训练权重矩阵,LoRA通过用低秩分解来限制其更新,其中,,且秩。在训练期间,被冻结,不接收梯度更新,而和包含可训练参数。请注意,和与相同的输入相乘,并分别对它们的输出向量进行对应坐标的求和。对于,改进的前向传播结果如下:

其中和分别表示输入和输出特征。作者将视为对输入的适应(称为)。在实际中,秩总是设置为一个较小的数字(例如,4),因此LoRA可以在有限计算资源内进行训练。

在LoRA之上,作者提出了一个考虑模态意识的LoRA,它与LoRA的区别在于引入了模态意识的适应。作者使用一个单一的下行投影将所有输入投影到低秩维度,得到低秩特征。对于每个模态,作者分别指定一个特定的模态意识上行投影(记作和)。在获取和后,MoRA根据缺失情况计算适应性。具体来说,如果患者有某模态的数据,MoRA会将其对应的添加到适应性中,反之亦然。因此,对于子集,其相应的模态意识适应如下:

其中,,。选定的适应性将被插入到多模态预训练模型的第一个块中,以提高对缺失模态的鲁棒性。在初始阶段,作者使用随机高斯初始化,并用零初始化和,因此训练开始时适应性为零。

Overall Framework

遵循[8, 6, 7]中的实现,作者使用多模态预训练 Transformer ViLT [1]作为作者的基础模型,该模型旨在处理两种模态:图像和文本。作者方法的结构如图1所示。可训练参数用火焰表示,而固定参数用柜子表示。患者具有不同缺失模态的图像和文本。对于缺失的模态,作者使用一个占位输入(对于缺失的文本,它是空字符串;对于缺失的图像,它是零矩阵)。这用于保持预训练模型输入标记的总数。作者使用固定的预训练嵌入过程将数据转换为输入标记。作者在预训练模型的前一个块(ViLT中的 Transformer 块)中实际应用了MoRA。

3 Experiments

Datasets

胸部X光片(CXR)数据集[20] 来自于印第安纳大学的开放数据源。在这个数据集中,有3794名患者的胸部X光片图像、相应的标注和多位专家诊断出的多种疾病。总共有120种不同的疾病,作者选择了出现最频繁的前20种作为诊断目标。请注意,这个数据集中包含两个X光片图像投影:正面和侧面。在本论文中,作者主要专注于正面投影。

眼科疾病智能识别(ODIR)数据集[21] 来源于旨在反映从医院收集的实际患者集合的眼科数据库。它包括3500名患者的数据,特别 curated,以帮助诊断眼科疾病。这个数据集涵盖了各种模态,包括人口统计信息,双眼的临床文本标注和每只眼睛的眼底图像。

作者的数据集的预测疾病划分和类型如表1所示。

Implementation Details

作者的代码主要基于PyTorch,并使用PyTorch Lightning进行训练和测试推理封装。所有实验均在NVIDIA RTX A4000 GPU上进行。考虑到作者的模型同时预测多种疾病,作者为每个疾病设定了一个单独的二进制交叉熵损失。

对于MoRA和所有 Baseline 方法,作者采用相同的设置进行比较性能。作者将ViLT的所有参数冻结,并采用相同的可训练分类器(包括两个线性层)。作者使用AdamW优化器进行训练,批次大小为4,权重衰减为2e-2。作者将最大学习率设置为5e-3,学习率在总训练步骤的2%处进行 Warm up ,然后线性减小到零。作者使用相同的训练、验证和测试拆分对每种模型进行训练,并为每种模型进行40个周期。如果在5个周期内结果没有改进,则训练将提前终止。作者使用F1-Macro分数来评估多疾病预测的性能。

Comparisons with the previous method

在这部分,作者在训练数据集和测试数据集中进行不同的缺省设置实验,以比较MoRA与三种先前的方法。表2中显示了F1-Macro实验结果。可以观察到,即使在训练数据集和测试数据集中的缺失率不同的情况下, MoRA在大多数缺失场景中实现了最佳结果。值得提及的是,根据原始MSP和MAP文章的设置,作者将它们插入到第1至第6个块中,而作者将MoRA仅插入到第1个块。结果表明,MoRA可以通过插入第1个块实现更好的性能。也可以从表中看到,该模型与文本的鲁棒性显著弱于图像。这在实际多模态学习中是合理的:一种模态的重要性大于其他模态。因此,提高这种重要模态的鲁棒性至关重要。从表中可以看出,当文本严重缺失时,MoRA的性能明显更好。

作者还比较了不同方法在训练过程中的GPU内存需求和训练时间。如表3所示,在1000次训练步骤时,MoRA需要的GPU内存相对较小,训练时间较短。这是因为MoRA只需要插入预训练模型的第一层,导致可训练参数较少。

Ablation Study

对不同缺失模式的鲁棒性: 作者进一步进行了实验,以分析作者提出的 method 在不同缺失模式率下的鲁棒性。为了澄清,作者对每个模态的缺失率保持相同,并认为总缺失率是 。作者在 ODIR 上训练 MoRA,其中 ,这意味着 65% 图像模态和 65% 文本模态样本。作者在不同的缺失率下进行测试,并在图2 中展示了结果。当缺失率较小时,作者的方法和 Baseline 结果没有显著不同。随着缺失率继续增加,作者的模型表现出更大的鲁棒性。这表明作者的模型在极端模态缺失情况下更能应对。

已插入块的影响: 根据 [6]。MAPs 对已插入的块非常敏感。作者还进行了实验,以分析插入位置对 MoRA 的影响。作者在 ODIR 上训练 MoRA,其中 65% 图像模态和 65% 文本模态样本,但固定了秩 r。作者试图将 MoRA 插入到不同的块中,以检查性能。根据表4 的实验结果,实验表明插入到几个块中的性能接近插入到第一个块中。可以看出,与 MAPs 相比,MoRA 对插入的块的数量不是非常敏感。然而,在实验中,作者发现插入到第一个块对于 MoRA 的有效性至关重要。这可能是因为第一个层次可以直接获取输入 Token 的信息,这对 MoRA 确认缺失模块的状态和促进后续微调非常有帮助。因此,在实际使用中,MoRA 最适合插入到第一个块,以便进行微调,这可以使用尽可能少的训练参数来实现良好结果。这也是 MoRA 与 MAPs 的优势。

秩 r 对性能的影响: 作者检查秩 r 对性能的影响。作者在 ODIR 上训练 MoRA,其中 65% 图像模态和 65% 文本模态样本,但设置了不同的秩 r。作者将在表5 中展示结果。正如表格所示,秩的增加使得性能得到改善。然而,结果表明当秩设置为 4 时,性能达到最佳。作者还测试了一个极端情况,即秩 r 等于输入 Token 的维数。在这种情况下,结果非常糟糕,甚至比没有 MoRA 的情况还要糟糕。这表明 MoRA 只有在秩非常小时才能发挥作用,这与 LoRA 的推导相一致。总体而言,MoRA 对 r 的选择不是很敏感。

4 Conclusion

在这篇论文中,作者提出了一种多模态预训练模型用于疾病诊断。

为了解决这些挑战,作者提出MoRA用于微调具有缺失模态的多模态预训练模型。

MoRA将每个输入映射到相同的低内在维度,但利用不同的模态感知上投影来获得针对特定模态缺失情况的模态感知适应。

作者在两个具有不同模态缺失率的疾病诊断任务上进行了实验,结果表明了作者的方法的优势。MoRA不仅提高了 robustness 和性能,还能节省计算资源。

在未来的工作中,作者将扩展作者的方法到更大的预训练模型,并探索将大型多模态预训练模型引入疾病诊断的可行性。

参考

[1].MoRA: LoRA Guided Multi-Modal Disease Diagnosis with Missing Modality.

0 人点赞