CNN 与 Transformer 的强强联合:AResNet-ViT在图像分析中的优势 !

2024-08-09 12:03:38 浏览数 (2)

作者针对残差CNN分支的注意力引导设计进行了消融实验。同时,作者还分别对CNN分支和Transformer分支进行了架构消融实验,以及将两个分支结合使用的实验。 此外,作者将提出的AResNet-ViT网络与经典分类模型的性能进行了比较,并对比了过去三年内发表的三篇论文的结果。 实验结果表明,AResNet-ViT网络以其结合CNN和Transformer的结构以及多注意力机制,在消融实验和对比实验中均取得了最高的评估指标值,包括ACC、TPR、TNR和AUC,这些值分别为0.889、0.861、0.896和0.925。 本研究指出,CNN和Transformer网络的融合可以有效提高分类模型的性能,为超声图像中乳腺结节的良恶性分类提供了一个鲁棒且高效的解决方案。

1 Introduction

乳腺结节,可能表现为囊性或实性肿块,在乳腺组织中经常遇到,是女性中的一种常见病症。这些结节被分为良性或恶性。良性乳腺结节不会对健康造成重大风险,而恶性乳腺结节则表明存在癌性增殖,从而对女性的整体身心健康构成重大威胁。

定期进行乳腺筛查,包括乳房X光摄影、乳腺超声和乳腺磁共振成像(MRI),在早期发现乳腺结节和诊断乳腺癌方面扮演着关键角色。尽管乳房X光摄影存在较大的辐射暴露和有限的成像角度,但它主要用于进一步筛查恶性结节。另一方面,MRI成像耗时长且费用高昂,不适合常规门诊检查。超声成像具有无辐射、成本低、便捷、快速以及在多角度成像方面的灵活性,已成为评估乳腺结节的主要手段[1]。然而,超声对乳腺结节的诊断准确性在很大程度上依赖于超声医生的临床经验。因此,医生经验水平的差异或视觉疲劳的影响常导致误诊或漏诊。

随着人工智能技术的不断发展,研究行人广泛探索了计算机辅助的超声乳腺结节诊断。他们的工作集中在开发智能算法,这些算法能够自动识别并区分超声图像中的结节区域为良性或恶性。这些算法利用深度学习和机器学习等技术来训练结节识别和分类的模型。这种AI辅助的诊断方法有望提高基于超声的乳腺结节评估的准确性和效率,为临床医生提供可靠的辅助工具,以支持临床决策和治疗计划。相反,某些恶性结节可能具有清晰的边界和小于1的纵横比,这些特征通常与良性结节相符,给AI识别带来困难。

在过去的十年中,基于深度学习的方法在自然图像分类中取得了显著的成功,并在医学图像识别领域引起了广泛关注。特别是在超声乳腺图像分类和识别领域,一些研究已经采用了基于CNN的深度学习模型来学习和提取超声图像中乳腺结节的特定特征。2016年,Huynh等人[2]使用ImageNet数据集对VGGNet、ResNet和DenseNet进行预处理,随后比较了这些网络在乳腺超声图像上的分类性能。2017年,Han等人[3]利用GoogLeNet算法来区分良性和恶性的超声乳腺结节。

布莱尔等人[4]在2018年将匹配层引入到预训练的VGG19网络中,旨在增强像素强度并提高乳腺结节分类的性能。2019年,陈思文等人[5]采用自适应对比度增强(ACE)方法进行预处理,并部署了AlexNet模型来区分乳腺结节的良恶性。齐等人[6]采用具有多尺度 Short-Cut 的深度卷积神经网络,以区分超声乳腺恶性结节和实性良性结节。2020年,庄等人[7]利用图像分解得到模糊增强和双边滤波的图像,丰富了乳腺病变的输入信息,并促进了乳腺超声图像的分类。曹等人[8]提出了一种噪声滤波网络(NFNet)用于结节分类。他们引入了双重softmax层以解决由于人工标记错误或数据质量问题导致的不准确标记问题。2021年,卡拉夫等人[9]使用带有注意力机制的VGG16模型来分类乳腺结节的良恶性,并结合二元交叉熵和双曲余弦损失来提高分类性能。萨克塞纳等人[10]利用一个增强的、包含12,000张图像的数据集来比较不同方法在乳腺结节分类中的表现。2022年,卢等人[11]利用预训练的ResNet18结合空间注意力,并结合三种不同的循环神经网络(RNNs)来预测乳腺结节的良恶性。康等人[12]提出了一种多分支网络,包括特征提取子模块、分类子模块和像素注意力子模块,通过注意力机制提高乳腺结节良恶性的分类。

尽管深度卷积神经网络(CNNs)与传统分类方法相比在性能和有效性方面取得了显著进展,但CNNs主要适合提取局部特征,可能在提取全局特征方面存在困难。2020年,多索夫茨基等人[13]提出了视觉 Transformer (ViT)网络,该网络利用自注意力机制提取全局特征,在图像分类任务中表现出卓越的性能。2021年,贝赫纳兹等人[14]采用ViT模型对超声乳腺结节进行分类,与卷积神经网络相比取得了更优的结果。这项研究强调了ViT模型在学习乳腺超声图像分类的全局特征方面的有效性。随后,其他研究也对原始ViT网络进行了改进,专门为乳腺超声图像分类定制[15-18],例如在2023年,谢里夫,B.[18]提出了一种混合多任务深度神经网络,称为Hybrid-MT-ESTAN,该网络结合了CNNs和Swin Transformer进行超声乳腺肿瘤的分类和分割。

超声图像中的局部特征捕捉结节的具体细节和特征,而全局信息和依赖关系反映了结节与周围组织之间的关系和区别。为了充分利用卷积神经网络(CNN)在提取局部特征方面的优势以及视觉 Transformer (Vision Transformer)在提取全局特征方面的能力,本研究提出了将CNN与Vision Transformer结合构建分类网络模型。

本研究的主要贡献概括如下:

(1) 提出的双分支网络架构,命名为AResNet-ViT,无缝整合了CNN和Transformer,以利用局部和全局特征信息,从而显著提高了分类模型的性能。

(2) 设计了用于局部特征提取的注意力引导残差网络(AResNet),旨在捕捉结节的形状、纹理、边缘和高级语义特征。

(3) 利用视觉 Transformer (ViT)捕捉超声图像中像素间的全局依赖关系,使得能够为结节图像生成全面的全局特征表示。

2 Method

AResNet-ViT的双分支架构如图1所示,包含两个分支。网络的上方分支采用由多个注意力引导的残差网络,有效捕捉乳腺结节的局部细节和纹理特征。这种能力提高了对结节内部微小变化的敏感性,有助于准确判定结节的良恶性。另一方面,网络的下方分支采用基于多头自注意力的视觉 Transformer (ViT)来捕捉结节的整体形状、边界以及结节与周围组织的关系,增强了对结节本身和整体图像特征的理解。通过结合并编码从局部特征提取分支和全局依赖特征提取分支中提取的特征,网络能够有效利用局部和全局信息,提高乳腺结节分类的准确性。网络的每个分支输出一个一维特征,随后将这些特征进行拼接,并由全连接的多层感知机(MLP)进行编码。最后,通过Sigmoid激活函数获得分类结果。

Local feature extraction

为了提高网络关注并学习超声乳腺结节内部特征的能力,作者提出了一种名为AResNet的局部引导注意力基础残差网络,作为局部特征提取分支。该架构基于ResNet18框架构建,包含四个残差块,每个块都融入了注意力机制,如图1所示。在残差块1和2的结构中,网络强调超声图像中如纹理和边缘等复杂细节。鉴于图像尺寸较大且复杂细节丰富,融合空间注意力机制变得至关重要,以帮助网络有效捕捉和理解结节内部信息。超声乳腺结节的分割 Mask 提供位置信息,并可作为空间注意力的引导。因此,在残差块1和2中,作者引入了超声乳腺结节分割 Mask 注意力(ROI-mask注意力,RA)[19]。

其中,代表输入,F(x)表示通过卷积块学习的特征,R(x)表示结节 Mask 特征图,Y(x)表示在分割 Mask 注意力指导下输出的学习特征,而C用于匹配残差块和分割 Mask 图的维度。

残差块3和4基于来自残差块1和2的信息进一步提取高级语义特征。这些块中的每个输出通道代表一个独特的高级语义表示,对整体高级语义的贡献各不相同。因此,如图2所示,在残差块3和4中采用了通道注意力(CA)模块,以增强网络对通道输出的关注并放大信息丰富的通道表示。CA模块对输入特征图进行全局平均池化和全局最大池化操作。两种池化操作得到的的一维特征向量随后被合并,并使用多层感知机(MLP)进行编码。然后,将编码结果通过Sigmoid激活函数获取代表每个通道权重的向量。该向量与输入模块的深层特征进行逐元素相乘。该模块的主要目的是为每个通道分配不同的权重,从而放大能有效捕捉超声乳腺结节所展现的高级语义特征的通道特定信息。

Global feature extraction

卷积神经网络(CNNs)主要强调局部感受野进行信息过滤,但在处理超声乳房图像时忽视了全局像素 Level 的自相关性。为了增强网络获取全面全局上下文信息的能力,本研究引入了一种视觉 Transformer (ViT)网络,该网络利用了多头自注意力机制。如图1的下方分支所示,ViT网络提取全局图像特征和像素 Level 的自相关性。该网络由12个 ConCat 的 Transformer 块组成。每个 Transformer 块独立执行自注意力和前馈神经网络操作,以迭代地从输入序列中提取特征。这种设计使得模型可以在不同层次上进行多次自注意力和特征提取的迭代,从而增强了模型的表现力和性能。

全局特征提取的过程如下:首先,将大小为224x224的输入图像划分为16x16大小的块。每个图像块通过线性映射转换为一维向量,并通过添加位置编码来保留块之间的空间信息。带有位置编码的数据随后被送入 Transformer 块进行逐层操作以执行特征编码。通过结合自注意力机制,网络能够捕获图像中不同位置之间的相互依赖关系,有助于全面理解整个图像的上下文。这进而提高了网络对超声乳房图像的整体特征和相关性提取的能力。

Loss function and evaluation metrics

2.3.1 Loss function

由于乳腺超声图像分类是一个二元分类任务,作者采用了二元交叉熵(BCE)损失函数,如公式2所示。

实验硬件环境如下:配备了56核的Intel Xeon(R) CPU E5-2680 v4 @ 2.40GHz处理器,两张NVIDIA GeForce RTX 2080Ti GPU显卡,每张显卡具有11GB的视频内存。

作者设置了自适应矩估计优化器(Adam)的训练参数,学习率为0.0001,批量大小为4。为了防止过拟合,作者采用了早停法。具体来说,如果在验证数据集上的损失函数连续20次迭代未降低,训练将被停止。

Evaluation Metrics

所有实验的性能均通过准确率(ACC)、真正率(TPR)、真负率(TNR)和曲线下面积(AUC)进行评估。准确率提供了对模型分类性能的整体评价。真正率代表了将恶性结节正确分类为恶性的概率。真负率代表了将良性结节准确标记为良性的概率。曲线下面积测量的是接收者操作特征(ROC)曲线下的区域,其中真正率在垂直轴上,假正率(FPR)在水平轴上。AUC值介于0到1之间,值越高表示分类性能越好。这些评估指标在公式(3)至(5)中定义。

其中,TP表示真实标签为乳腺病变并被分类为乳腺病变的像素数量;TN表示真实标签为非乳腺病变并被分类为非乳腺病变的像素数量;FP表示真实标签为非乳腺病变但被分类为乳腺病变的像素数量;FN表示真实标签为乳腺病变但被分类为非乳腺病变的像素数量。

所有实验的性能评估采用了准确率(ACC)、真正率(TPR)、真负率(TNR)和曲线下面积(AUC)。准确率全面反映了模型的分类表现。真正率表示将恶性结节正确识别为恶性的概率。真负率则表示将良性结节准确标记为良性的概率。曲线下面积(AUC)是接收者操作特征(ROC)曲线下的区域,真正率在纵轴上,假正率(FPR)在横轴上。AUC的值介于0到1之间,数值越高,分类性能越好。这些评估指标在公式(3)至(5)中给出定义。

其中,TP是指被正确分类为乳腺病变的真实乳腺病变像素数量;TN是指被正确分类为非乳腺病变的真实非乳腺病变像素数量;FP是指被错误分类为乳腺病变的真实非乳腺病变像素数量;FN是指被错误分类为非乳腺病变的真实乳腺病变像素数量。

Ablation experiments

3.3.1 Effectiveness of the attention mechanism

为了验证注意力引导模块的合理性和有效性,进行了五组消融实验,相应的结果如表1所示。"网络1"指的是未添加任何注意力的ResNet18网络。"网络2"在ResNet18网络的前两个残差块完成后加入分割 Mask 注意力,而"网络3"在最后两个残差块完成后加入分割 Mask 注意力。"网络4"在ResNet18网络的所有残差块完成后集成分割 Mask 注意力。最后,"网络5"在"网络2"的基础上,进一步在最后两个残差块完成后加入通道注意力。所有实验使用同一组参数。

3.3.2 双分支架构的有效性

为了评估双分支架构中每个单独分支以及组合架构在超声乳腺结节分类中的性能,针对四个实验组进行了消融实验。第一组仅使用ResNet18网络进行分类,第二组采用ViT网络对良性及恶性乳腺结节进行分类。第三组使用ResNetA网络,在ResNet18网络的前两个残差块后加入分割 Mask 注意力机制,并在最后两个残差块后加入通道注意力,进行乳腺结节分类实验。第四组和第五组基于ViT网络架构,并并行融合ResNet网络和AResNet网络以分类良性及恶性乳腺结节。消融实验的结果如表2所示。

从表2可以看出,单一网络(ResNet18或ViT)的性能指标低于ResNet18和ViT的组合(ResNetViT网络),这说明网络集成可以学习到更与乳腺结节相关的特征。此外,与ResNet18分类网络相比,AResNet分类网络在准确度(ACC)、真正率(TPR)、真负率(TNR)和曲线下面积(AUC)上分别提高了0.061、0.072、0.062和0.066,这表明AResNet分类网络在引导和学习结节区域特征方面表现更优。同时,AResNet和ViT的并行融合进一步提升了性能指标,尤其是在TNR上改进最为显著。这表明AResNet-ViT模型在识别表现为恶性但实际上是良性的样本方面具有更高的识别能力,这对于临床诊断至关重要,因为它们是最容易误判的情况。

The heat-maps of classification results

图3展示了使用AResNet-ViT模型获得的测试样本的视觉分类结果,其中顶部一行显示原始的超声乳腺结节图像,底部一行呈现的是AResNet-ViT模型生成的特征注意力 Heatmap 。特征注意力 Heatmap 为输入数据的每个位置分配权重,指示模型更关注的区域或特征。这种可视化使作者能够识别输入数据中模型认为最重要的特定区域。从图中可以看出,结节区域受到了模型的主要关注,这体现在 Heatmap 中的高权重区域。

此外,在乳腺超声图像中,当结节内部的超声特征与周围组织相似时,模型能够准确区分结节区域与背景。同时,对于具有重叠表现的良性及恶性结节样本,AResNet-ViT模型的预测结果与金标准相符,表明模型能够实现精确分类。

对比分析

为了探究AResNet-ViT是否优于现有经典模型以及该领域内其他发表的方法,作者进行了对比分析。分析分为两部分:首先与四个已确立的经典模型(VGG16 [21],ResNet34 [22],DenseNet [23],InceptionV3 [24])进行初步比较,随后与三种近期发表的方法进行比较。除参考文献[25]中使用的数据集外,包括本研究提出的方法在内的所有其他方法均使用相同的BUSI数据集。

从表3的前四行可以看出,与经典模型相比,作者的模型在分类结果上表现最为突出,这表明在预测结节的良性或恶性性质方面,作者的模型优于经典模型。具体而言,作者的分类模型显示出更高的真阳性率(TPR),表明其能够识别更多的病变区域,且漏诊率较低。此外,与经典模型相比,作者的模型显示出更高的真阴性预测值(TNP),范围在0.054至0.107之间。这意味着作者的模型在分类具有恶性结节特征但实际上为良性的挑战性样本时,表现出更高的准确性。

表3的第5-7行展示了与其他文献的比较。文献[26]引入了基于Transformer网络的额外嵌入方法来提高分类性能,但作者的方法在所有指标上都表现出更优越的性能。文献[27]采用了双通道输入,提取并融合了不同模态的超声乳腺结节图像和乳腺X射线图像的特征。虽然在真阳性率(TPR)和曲线下面积(AUC)上略优于作者的方法,但其结节分类的准确度(ACC)较低。此外,与文献[25]相比,作者的方法在各项指标上也表现出更佳的性能,该文献本身也承认在分类具有相似良性和恶性表现形式的挑战性样本时结果不理想。总之,作者提出的AResNet-ViT网络在四个评估指标中的准确度(ACC)、TPR、真阴性率(TNR)和AUC方面取得了最高性能。

4 讨论

在本研究中,作者提出了一种名为AResNet-ViT的混合CNN-Transformer架构,用于乳腺超声图像中乳腺结节的良恶性分类。AResNet-ViT模型结合了CNN提取局部特征的能力和Transformer建模全局特征的能力,从而实现了更具鉴别性的特征表示以进行准确分类。AResNet-ViT设计有一个双分支架构。其中一个分支专注于使用基于ResNet18框架的残差网络从图像中提取局部细节特征。这一分支包含四个残差块,每个块都融入了注意力机制。另一个分支利用视觉Transformer(ViT)进行全局特征提取。作者在残差网络的浅层和深层模块中分别使用分割 Mask 注意力和通道注意力,因为残差网络的浅层主要提取低级语义特征,更多关注结节位置信息,而深层残差网络提取高级语义特征,通道权重比结节位置更重要。在残差网络上进行的消融实验验证了同时使用这两种类型的注意力比单独使用分割 Mask 注意力或通道注意力能获得更高的评价指标。

为了评估不同架构在超声乳腺结节分类中的性能,进行了消融实验。研究比较了ResNet18、ViT、AResNet(带有分割 Mask 和通道注意力的ResNet18)以及AResNet与ViT的融合。结果显示,ResNet18与ViT的组合(ResNetViT)优于单个网络,表明了网络整合的优势。AResNet在准确度、真阳性率、真阴性率和曲线下面积方面均优于ResNet18。AResNet与ViT的融合进一步提升了性能,特别是在识别具有恶性特征但实际上为良性的样本方面,这对于准确的临床诊断至关重要。

热力图显示了分类结果,表明AResNet-ViT模型即使在结节内部超声特征与周围组织相似的情况下,也能准确区分结节区域与背景。这进一步证实了AResNet-ViT模型学习和识别结节区域特定特征的能力,显示出其精确的分类性能。与传统的模型以及超声乳腺结节分类领域近期发表的方法相比,作者的AResNet-ViT模型在所有评价指标上均表现出色,包括准确率(ACC)、真正率(TPR)、真负率(TNR)和曲线下面积(AUC),分别为0.889、0.861、0.896和0.925。结果表明,卷积神经网络与 Transformer 混合架构能显著提高超声乳腺结节的分类效果。此外,在卷积阶段集成注意力机制能增强局部特征的提取。

尽管作者的方法性能卓越,但仍存在一些局限性,例如,有效训练需要大量数据集。乳腺超声图像复杂,个体间差异很大,这使得在有限数据情况下构建健壮的分类器变得具有挑战性。未来的工作应集中收集更大且更多样化的数据集以提升模型的泛化能力。另外,混合模型的计算复杂度可能较高,这使得在实时临床环境中部署变得困难。因此,未来的工作还应优化模型以提高计算效率,减少推理时间是实际应用中至关重要的。

参考

[1].AResNet-ViT: A Hybrid CNN-Transformer Network for Benign and Malignant Breast Nodule Classification in Ultrasound Images.

0 人点赞