多模态图像融合2
论文题目:Multimodal Fusion Learning with Dual Attention for Medical Imaging
论文下载地址:https://arxiv.org/pdf/2412.01248v1
源码下载地址:https://github.com/misti1203/DRIFA-Net
一、主要内容
多模态融合学习在皮肤癌和脑肿瘤等多种疾病的分类中显示出巨大的前景。然而,现有的方法面临三个关键的限制。首先,他们往往缺乏推广到其他诊断任务,由于他们的重点是一个特定的疾病。其次,他们没有充分利用来自不同模式的多个健康记录来学习鲁棒的补充信息。最后,他们通常依赖于单一的注意机制,而忽略了多种模式内和跨模式的多重注意策略的好处。为了解决这些问题,本文提出了一种双鲁棒的信息融合注意机制(DRIFA),该机制利用两个注意模块,即多分支融合注意模块和多模态信息融合注意模块。DRIFA可以与任何深度神经网络集成,形成一个多模态融合学习框架,记作DRIFA- net。本文发现DRIFA的多分支融合注意学习了每种模式的增强表征,如皮肤镜检查、巴氏涂片检查、MRI和ct扫描。而多模态信息融合注意模块学习了更精细的多模态共享表示,提高了网络跨多任务的泛化能力,提高了整体性能。此外,为了估计DRIFA-Net预测的不确定性,本文采用了一个集合蒙特卡罗退出策略。在五个具有不同模式的公开数据集上进行的广泛实验表明,我们的方法始终优于最先进的方法。
二、整体架构
DRIFA与ResNet18集成,创建了一个多分支多模态融合学习网络,称为DRIFA- net(如图1所示)。
图1。DRIFA-Net的详细架构。关键组成部分包括:(A)目标特异性多模态融合学习(TMFL)阶段,其次是(B)不确定性量化(UQ)阶段。TMFL阶段包括一个鲁棒的残差注意(RRA)块,如图(C)所示,并利用多分支融合注意(MFA),一个用于进一步改进局部表示的额外MFA模块,一个用于改进多模态表示学习的多模态信息融合注意(MIFA)模块,以及用于处理多个分类任务的多任务学习(MTL)。在(UQ)阶段,评估DRIFA-Net预测的可靠性。
1、目标特异性多模态融合学习(TMFL)
DRIFA-Net依赖于目标特定的多模态融合(TMFL),追求学习增强的共享多模态表示,以在目标特定的分类任务中实现更好的性能。TFML利用鲁棒的剩余注意(RRA)块,该块集成了本文提出的多分支融合注意(MFA)模块,可以有效地学习各种精炼的局部模式。此外,TMFL还结合了本文提出的多模态信息融合注意(MIFA)模块来学习增强的多模态表示。最后,采用目标特定的多任务学习(MTL)方法,在TMFL阶段同时处理多个分类任务。在下面,本文将讨论RRA, MIFA和MTL块-它们代表了TMFL中感兴趣的突出元素。
(1)RRA:鲁棒的残差注意块
RRA块包含了本文在每个卷积层之后应用的MFA模块,并利用跳过连接策略。这种方法旨在学习不同的局部表示,从而提高本文的学习网络的性能。
多分支融合注意模块(MFA)
图2。(a)多分支融合注意(MFA)模块。关键组件包括用于多种局部信息增强的分层信息融合注意(HIFA)和用于改进特定信道表示学习的通道局部信息注意(CLIA)。
多分支融合注意模块(Multi-branch Fusion Attention Module, MFA)旨在从输入特征中学习增强的局部表示。图2说明了这一点。具体来说,MFA模块集成在网络所有分支的每个RRA块中。另一个MFA模块用于进一步细化这些表示(图 1(a)),从而提高模型学习更详细的局部模式的能力。
为了增强局部信息的获取能力,MFA采用了两个注意模块:
HIFA (Hierarchical information Fusion attention)模块丰富了不同的局部信息,
CLIA (Channel-Wise local information attention)模块压缩了不同的通道信息。
HIFA模块集成到第一个分支中以捕获各种本地特征,而CLIA模块应用于第二个分支中以细化通道信息,如图2所示。此外,采用调制策略选择性地强调输入数据中的关键表征并抑制不相关表征,从而提高学习网络的整体性能。
MFA模块旨在通过将输入特征映射x∈RH×W×C(其中H、W和C分别表示通道的高度、宽度和数量)转换为x′= x⊗a⊗ωc来增强不同的局部表示学习。这里,⊗表示元素智能乘法,a是增强的局部注意图,ωc是通道智能可学习参数,用于在训练期间调整每个通道的重要性。
为了设计HIFA模块,本文使用p- 1x1卷积层表示为ψp, p- 2全球平均池化(GAP)层表示为β, p- 2全球最大池化(GMP)层表示为γ来学习不同的局部信息。这个过程包括四个关键步骤:
首先,通过卷积层和GAP层对输入特征进行处理,获取初始局部信息lp=0。
其次,使用后续卷积层和GMP层对第p层的特征进行细化,提取额外的局部信息lp=1。
第三,将精炼的特征融合并通过进一步的卷积层,每个卷积层后面都有GAP或GMP,以捕获不同的局部信息lp。
最后,将得到的局部信息变体分层融合以获得增强的多样化局部模式。
为了设计CLIA模块,本文使用第q个1x1个具有sigmoid激活函数(σ)的卷积层,然后使用第q个平均池化层来压缩信道信息。此外,使用跳过连接策略将结果信息与最初压缩的信道信息融合,增强了对信道局部信息的学习,从而提高了模型的性能。
最后,为了结合从HIFA和CLIA中学习到的局部信息,本文使用可学习的权重和
来调整每个学习到的局部信息分量的重要度。然后将精炼的信息融合在一起,以增强对不同局部细节的捕获。一个sigmoid激活函数σ被应用于生成注意图A,它通过捕获细粒度的细节来突出关键特征并提高网络性能,如下面的等式所示:
(2)MIFA:多模态信息融合注意模块
图3。(a)多模态信息融合注意(MIFA)模块。该模块包括多模态全局信息融合注意(MGIFA)(如图b所示)和多模态局部信息融合注意(MLIFA)(如图c所示)。
为了学习多模态共享注意特征图A,本文设计了两个注意模块:多模态全局信息融合注意(MGIFA)和多模态局部信息融合注意(MLIFA)。此外,还采用了类似于MFA模块的融合策略。
MGIFA和MLIFA模块都包含各种池层。为了学习不同的全局上下文,本文使用了全局最小池化(α)、全局最大池化(γ)和全局平均池化(β)。而对于学习不同的局部细粒度细节,则使用最小池化τ、最大池化τ和平均池化δ。
为了增强系统对不同全局和局部信息的学习能力,设计了多模态全局和局部信息融合方法。具体而言,该方法将一种模态的每个池化层与其他模态的相应相似池化层融合在一起,以在全局和局部上下文中学习互补信息。由此产生的互补信息增强了本研究的学习网络模型的每个分支中所有模式的学习,从而增强了更好的性能。例如,为了增强从全局平均池化中获得的全局信息,将从模态1中学习到的全局信息与其他m−1模态的全局信息融合。该策略统一应用于从各自池化层学习到的所有信息,旨在增强全局和局部信息的多样性。进一步,对每个结果信息应用一个完全连接的层,然后对所有结果信息进行融合,得到丰富的全局g’和局部表示l’,如下式所示:
与MFA模块类似,为了将从MGIFA和MLIFA中学习到的信息(例如,和
)组合在一起,我们使用可学习权重(
和
)来调整这些信息的重要性,从而改进模式。在多模态融合学习设置中,融合操作(加法)之后的sigmoid激活σ生成多模态共享注意图A,捕获不同的全局上下文和细粒度细节。
(3)MTL:目标特定的多任务学习
在MTL阶段,本文利用来自TMFL阶段的共享表征跨越m种不同的医学成像模式。这通过学习鲁棒的互补信息增强了DRIFA-Net的泛化能力,从而改进了对多个模态特定测试集的预测。MTL方法利用DRIFA-Net θ(·)来映射输入特征[
,…,
]从m个模态到t个分类任务[
,…,
]。MTL损失函数
结合了特定于任务的交叉熵损失
,定义如下:
其中θ(,
) = [
,…,
]→[
,…,
],
表示每个任务特定交叉熵损失的权重因子,确保有效的任务性能平衡。
2、不确定性量化(UQ)
本文使用集合蒙特卡罗退出策略评估DRIFA-Net中的预测不确定性。这种方法通过平均z个集合模型的随机预测来计算软概率,每个模型都利用随机抽样dropout掩模(λ)将随机性引入DRIFA-Net。本文的方法涉及e = 20次蒙特卡罗采样迭代,通过θ(·)生成不同的预测。在测试过程中,DRIFA-Net在多个模态特定的测试集上执行了20次,从这些运行的平均结果中推断出预测的不确定性,如下式所示:
其中Ω表示softmax分类器。
三、总结
本文提出了一种双信息融合注意方法来增强多模态融合学习,使其适用于不同医学成像模式的疾病分类任务,如宫颈癌、皮肤癌、肺癌和脑肿瘤。通过将多分支融合注意和多模态信息注意模块相结合,超越了现有的最先进的方法。未来的工作将集中在扩展我们的方法,以获得更多的医学成像模式和优化计算效率。