【Medical Image Analysis 1区TOP】用于MRI重建的全局感受野傅里叶卷积块
文章目录
- 一、论文信息
- 二、论文概要
- 三、实验动机
- 四、创新之处
- 五、实验分析
- 六、核心代码
- FCB即插即用
- 七、实验总结
- 八、写作
一、论文信息
- 论文题目:Fourier Convolution Block with global receptive field for MRI reconstruction
- 中文题目:用于MRI重建的全局感受野傅里叶卷积块
- 论文链接:https://www.sciencedirect.com/science/article/pii/S1361841524002743
- 代码链接:https://github.com/Haozhoong/FCB/tree/master
- 作者:Haozhong Sun, Yuze Li, Zhongsen Li, Runyu Yang, Ziming Xu, Jiaqi Dou, Haikun Qi, Huijun Chen
- 单位:清华大学 生物医学工程系,北京,中国,上海科技大学 生物医学工程学院,上海,中国
- 核心速览:本文提出了一种傅里叶卷积块(Fourier Convolution Block, FCB),它能够在频域实现全局感受野,替代传统卷积层,从而显著提升MRI欠采样重建的图像质量。该方法不仅超越了基于CNN的大核卷积和Vision Transformer,还保持较低计算复杂度。
二、论文概要
MRI重建中,欠采样会导致混叠伪影,需要强大的全局感受野模型来去除伪影。然而,传统CNN受限于局部卷积,Vision Transformer虽能捕获全局依赖,但计算量大。本文提出的FCB利用傅里叶变换将卷积转化为频域逐点乘法,在保持低计算开销的同时获得全局感受野。作者将FCB嵌入UNet、MoDL、VSNet、E2EVar等网络,在FastMRI脑部和膝关节数据集上验证,结果显示:FCB显著提升了PSNR、SSIM,并在噪声、不同采样模式下依旧表现稳定。
三、实验动机
-
MRI采集速度慢,欠采样会导致图像中存在全局分布的混叠伪影。
-
CNN虽然效果良好,但受限于局部卷积的有效感受野,难以消除全局伪影。
-
Vision Transformer和大核CNN虽能扩展感受野,但训练困难、计算/显存开销大。
-
因此,作者希望设计一种既具备全局感受野,又能保持高效和低复杂度的新型卷积单元。
四、创新之处
-
提出傅里叶卷积块(FCB):在频域实现卷积,天然具备全局感受野,同时可自适应学习局部与全局特征。
-
轻量化设计:结合深度可分离卷积和捷径结构,降低FCB参数量与计算开销。
-
再参数化训练策略:先用3×3卷积训练局部特征,再转化为FCB进行微调,解决大感受野模型训练难的问题。
-
Plug-and-Play特性:可嵌入任意CNN结构,不依赖特定架构。
-
对比实验全面:与CNN、ViT、大核卷积、频域方法(如FFC、FasterFC)及k-space方法全面比较,表现均优于对手。
五、实验分析
数据集:FastMRI脑部(T1, T2, FLAIR 等)与膝关节数据,分为训练/验证集,采用8×和12×加速率。
指标:PSNR、SSIM,配合Wilcoxon显著性检验。
结果:
-
在四个主流CNN架构中,FCB版本均显著提升性能(p < 0.001)。
-
在不同采样模式(Poisson、Cartesian、Radial)下,FCB均带来提升。
-
在添加10%、20%高斯噪声时,FCB仍保持较强鲁棒性。
-
对比ViT(ViT, ReconFormer, SwinMR)、k-space方法(Ksp, MDRecon)、频域方法(FFC, FasterFC),FCB表现最好,且推理速度优于ViT、大核卷积。
可视化结果:
基线CNN出现伪影、细节丢失,而FCB增强模型能恢复更多细节(骨纹理、脑结构边缘)。
FCB显著扩大了有效感受野,与MRI采样模式的点扩散函数(PSF)更接近。
消融实验:验证了循环padding、深度卷积、shortcut、再参数化等设计的必要性,每一步都显著提升性能。
六、核心代码
FCB即插即用
import torch
import numpy as np
import torch.nn as nnfrom numpy.random import RandomStatedef complexinit(weights_real, weights_imag, criterion):output_chs, input_chs, num_rows, num_cols = weights_real.shapefan_in = input_chsfan_out = output_chsif criterion == 'glorot':s = 1. / np.sqrt(fan_in + fan_out) / 4.elif criterion == 'he':s = 1. / np.sqrt(fan_in) / 4.else:raise ValueError('Invalid criterion: ' + criterion)rng = RandomState()kernel_shape = weights_real.shapemodulus = rng.rayleigh(scale=s, size=kernel_shape)phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)weight_real = modulus * np.cos(phase)weight_imag = modulus * np.sin(phase)weights_real.data = torch.Tensor(weight_real)weights_imag.data = torch.Tensor(weight_imag)class FCB(nn.Module):def __init__(self, input_chs:int, num_rows:int, num_cols:int, stride=1, init='he'):super(FCB, self).__init__()self.weights_real = nn.Parameter(torch.Tensor(1, input_chs, num_rows, int(num_cols//2 + 1)))self.weights_imag = nn.Parameter(torch.Tensor(1, input_chs, num_rows, int(num_cols//2 + 1)))complexinit(self.weights_real, self.weights_imag, init)self.size = (num_rows, num_cols)self.stride = stridedef forward(self, x):#对输入张量做2D实数傅里叶变换(频域表示)x = torch.fft.rfftn(x, dim=(-2, -1), norm=None)#根据复数乘法的基本规则:#频域卷积公式:(a+bi)(c+di)=(ac-bd)+(ad+bc)ix_real, x_imag = x.real, x.imagy_real = torch.mul(x_real, self.weights_real) - torch.mul(x_imag, self.weights_imag) #实部y_imag = torch.mul(x_real, self.weights_imag) + torch.mul(x_imag, self.weights_real) #虚部#对输出进行逆傅里叶变换,得到时域信号 x = torch.fft.irfftn(torch.complex(y_real, y_imag), s=self.size, dim=(-2, -1), norm=None)if self.stride == 2:x = x[...,::2,::2]return xdef loadweight(self, ilayer):weight = ilayer.weight.detach().clone()fft_shape = self.weights_real.shape[-2]weight = torch.flip(weight, [-2, -1])pad = torch.nn.ConstantPad2d(padding=(0, fft_shape - weight.shape[-1], 0, fft_shape - weight.shape[-2]),value=0)weight = pad(weight)weight = torch.roll(weight, (-1, -1), dims=(-2, - 1))weight_kc = torch.fft.fftn(weight, dim=(-2, -1), norm=None).transpose(0, 1)weight_kc = weight_kc[..., :weight_kc.shape[-1] // 2 + 1]self.weights_real.data = weight_kc.realself.weights_imag.data = weight_kc.imagif __name__ == "__main__":# 构建张良 (batch = 1 通道c = 32, 高=50, 宽=50)x = torch.randn(1, 32, 50, 50)model = FCB(input_chs=32,num_rows=50, num_cols=50)output = model(x)print(f"输入张量的形状: {x.shape}")print(f"输出张量的形状: {output.shape}")
通道分组折半卷积模块(Channel Grouping Half-convolution,CGHF):
实际意义:①传统CNN难以针对性提取多层级特征:果蔬图像存在三类关键特征(表面属性如颜色、形状、纹理、深度),无法“差异化”捕捉不同层级特征,导致相似类别难以区分。②全通道卷积导致计算与参数冗余:传统卷积对特征图所有通道进行相同操作,产生大量冗余计算,特征部分通道信息重复,无需全通道密集计算。
实现方式:①将特征图通道划分为3组,每组独立卷积实现对判别性特征的捕获,保证不同特征(颜色、纹理、深层语义)的多样性和完整性。②仅对一半通道进行卷积,另一半保持不变,减少 FLOPs和内存访问量。
七、实验总结
本文通过引入傅里叶卷积块(FCB),为MRI欠采样重建提供了一种兼具全局感受野与计算高效性的卷积设计。实验证明:
-
FCB能有效去除全局分布的混叠伪影。
-
FCB嵌入CNN后普遍提升性能,优于ViT、大核CNN及频域对手方法。
-
该模块具备Plug-and-Play特性,适用于多种CNN架构。
-
在推理效率上,FCB明显优于ViT,运行时间与11×11卷积CNN相当但性能更佳。
局限性在于:参数量仍高于普通卷积,FFT/IFFT重复计算可能影响效率。未来研究可探索纯频域CNN或优化频域激活函数,以进一步降低开销。
八、写作
- 语义分割任务/医学影像分割:①实际问题:病灶目标区域往往很小,且存在组织形态与纹理的多尺度差异;轻量化模型在边缘设备上运行时易掉精度。解决方案:①Channel Grouping:分组学习器官边界(浅层)、组织纹理(中层)、病灶语义(深层)。②Half-Convolution:减少 MRI/CT 图像处理计算量,同时保留关键特征。【研究对象可任意替换】
- 目标检测任务/缺陷检测、遥感识别:①实际问题:工业/遥感场景需实时检测,边缘设备算力不足;缺陷特征分层次(表面划痕-浅层特征、内部裂纹-深层特征、材质纹理-中层特征),传统CNN易漏检。解决方案:①Channel grouping:将特征图通道分3组,分别对应“表面缺陷、材质纹理、内部结构”,每组独立提取缺陷特征,避免不同缺陷特征干扰;②Half-convolution:对每组通道仅50%做卷积,剩余50%用于保留原始特征,降低设备算力压力。【研究对象可任意替换】