【ICLR 2024】MogaNet:多阶门控聚合网络
文章目录
- 一、论文信息
- 二、论文概要
- 三、实验动机
- 四、创新之处
- 五、实验分析
- 六、核心代码
- 源代码
- 七、实验总结
一、论文信息
- 论文题目:MogaNet: Multi-order Gated Aggregation Network
- 中文题目:MogaNet:多阶门控聚合网络
- 发表期刊:ICLR
- 论文链接:点击跳转
- 代码链接:点击跳转
- 作者:Siyuan Li, Zedong Wang, Zicheng Liu, Cheng Tan, Haitao Lin, Di Wu, Zhiyuan Chen, Jiangbin Zheng, Stan Z. Li、李思远,王泽东,刘子成,谭程,林海涛,吴迪,陈志远,郑江斌,李战书
- 单位:浙江大学计算机科学与技术学院,西湖大学未来产业研究中心 AI 实验室。
- 核心速览:MogaNet 提出了一种结合 多阶特征交互 与 门控聚合机制 的卷积神经网络架构。该方法能够显式地捕获 中阶交互特征(比低阶纹理和高阶全局特征更具判别力),同时在计算效率上优于 ViT 与现有 ConvNet。在 ImageNet-1K 上取得 87.8% Top-1 精度(MogaNet-XL),并在目标检测、语义分割、姿态估计和视频预测等任务上全面超越主流模型。
二、论文概要
-
核心思想:通过多阶门控聚合机制结合特征分解与通道重分配,高效建模低阶到高阶的多层次交互特征,从而提升卷积网络的表达能力与泛化性能。
-
研究发现现有 ConvNet 与 ViT 的表示能力在交互建模上存在 偏向极端(过低或过高阶交互) 的问题,导致泛化与判别性不足。
-
MogaNet 通过 多阶门控聚合(Moga Block) 与 通道聚合(Channel Aggregation, CA Block) 来解决该问题,强制网络编码原本被忽略的中阶交互。
-
在多个视觉基准(ImageNet-1K、COCO、ADE20K、COCO Pose、视频预测)上验证了其有效性与高效性。
三、实验动机
-
ViT 依赖自注意力获得全局交互,但计算量大且缺乏局部先验。
-
现代 ConvNet 虽然引入大卷积核,但依然偏向低/高阶交互,无法充分编码中阶语义。
-
因此,研究目标是:在保持计算高效的前提下,有效建模多阶交互,特别是中阶交互。
四、创新之处
- 多阶博弈论视角:论文首次从博弈论角度分析卷积网络的交互行为,指出现有方法过于依赖低阶或高阶交互而忽略中阶交互,从而提出以中阶交互为核心的设计原则。
- 多阶空间聚合:在空间维度引入不同膨胀率的深度卷积并结合 gating 机制,实现对局部细节、语义结构和全局信息的同时建模,提升特征表达的层次性与多样性
- 特征分解(FD):通过特征分解操作显式去除冗余或弱判别特征,突出更有价值的中阶特征,增强特征表征的判别性和鲁棒性
- 多阶通道重分配(CA Block):设计轻量化的通道聚合模块来动态重分配通道资源,高效筛选和利用多阶信息,比传统的 SE 或 MLP 更简洁高效
- 高效网络架构:基于上述模块构建四阶段的 MogaNet 主干网络,在保持计算和参数高效的同时,在分类、检测、分割等多个视觉任务上超越了 ConvNeXt 和 Swin Transformer 等主流模型。
总结:MogaNet 的创新在于用多阶博弈论交互指导架构设计,并通过多阶空间聚合+特征分解+通道重分配 模块,有效捕捉并利用中阶交互,从而在保证高效性的同时,显著提升了卷积网络的表现力和泛化能力。
五、实验分析
-
ImageNet-1K 分类:MogaNet-S 达到 83.4%,超越 Swin-T、ConvNeXt-T 等基线。MogaNet-XL 预训练后达到 87.8%,超过 ConvNeXt-XL 的同时减少 169M 参数。
-
COCO 检测 & 分割:在 Mask R-CNN 与 Cascade Mask R-CNN 上显著提升 AP(+2.3 mAP)。
-
ADE20K 语义分割:MogaNet-L 超越 ConvNeXt-L 和 RepLKNet。
-
姿态估计 & 视频预测:在 2D/3D 人体姿态与视频预测中优于 Transformer 与 ConvNet。
六、核心代码
源代码
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ElementScale(nn.Module):def __init__(self, embed_dims, init_value=0., requires_grad=True):super(ElementScale, self).__init__() # 调用父类构造函数初始化模块# 初始化一个可训练的缩放参数self.scale,用于对输入张量进行逐元素加权。# 例如,当embed_dims为64且初始值为0.5时,self.scale的形状为(1, 64, 1, 1),所有值均为0.5。self.scale = nn.Parameter(init_value * torch.ones((1, embed_dims, 1, 1)), # 初始缩放因子requires_grad=requires_grad # 决定该参数是否参与梯度更新)def forward(self, x):return x * self.scale # 返回输入与缩放因子相乘的结果class MultiOrderDWConv(nn.Module):"""基于膨胀深度卷积实现多阶特征提取模块。参数:embed_dims (int): 输入特征图的通道数。dw_dilation (list): 三个深度卷积层对应的膨胀因子。channel_split (list): 不同分支的通道分配比例。"""def __init__(self,embed_dims,dw_dilation=[1, 2, 3],channel_split=[1, 3, 4],):super(MultiOrderDWConv, self).__init__()# 根据channel_split计算每个分支的通道比例,例如:1/8, 3/8, 4/8self.split_ratio = [i / sum(channel_split) for i in channel_split]self.embed_dims_1 = int(self.split_ratio[1] * embed_dims) # 第二部分通道数:约3/8 * embed_dimsself.embed_dims_2 = int(self.split_ratio[2] * embed_dims) # 第三部分通道数:约4/8 * embed_dimsself.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2 # 第一部分通道数:剩余部分self.embed_dims = embed_dims# 检查参数长度和合理性:确保dw_dilation和channel_split均为3个元素,并且膨胀率在1到3之间,同时embed_dims能被分割比例整除assert len(dw_dilation) == len(channel_split) == 3assert1 <= min(dw_dilation) and max(dw_dilation) <= 3assert embed_dims % sum(channel_split) == 0# 基础深度卷积:对整个输入进行5x5卷积操作,填充值依据膨胀率计算self.DW_conv0 = nn.Conv2d(in_channels=self.embed_dims,out_channels=self.embed_dims,kernel_size=5,padding=(1 + 4 * dw_dilation[0]) // 2, # 根据膨胀因子计算所需填充groups=self.embed_dims, # 分组数等于通道数,实现深度卷积stride=1,dilation=dw_dilation[0], # 膨胀率为dw_dilation[0](通常为1))# 第二个深度卷积:对第一分支(通道占比约3/8)应用5x5卷积self.DW_conv1 = nn.Conv2d(in_channels=self.embed_dims_1,out_channels=self.embed_dims_1,kernel_size=5,padding=(1 + 4 * dw_dilation[1]) // 2,groups=self.embed_dims_1,stride=1,dilation=dw_dilation[1], # 膨胀率为2)# 第三个深度卷积:对第二分支(通道占比约4/8)应用7x7卷积self.DW_conv2 = nn.Conv2d(in_channels=self.embed_dims_2,out_channels=self.embed_dims_2,kernel_size=7,padding=(1 + 6 * dw_dilation[2]) // 2,groups=self.embed_dims_2,stride=1,dilation=dw_dilation[2], # 膨胀率为3)# 逐点卷积,用于融合各分支特征self.PW_conv = nn.Conv2d(in_channels=embed_dims,out_channels=embed_dims,kernel_size=1)def forward(self, x):x_0 = self.DW_conv0(x) # 对整个输入应用第一个5x5深度卷积# 从x_0中截取第二部分通道,输入到第二个深度卷积x_1 = self.DW_conv1(x_0[:, self.embed_dims_0: self.embed_dims_0 + self.embed_dims_1, ...])# 从x_0中截取最后一部分通道,输入到第三个深度卷积x_2 = self.DW_conv2(x_0[:, self.embed_dims - self.embed_dims_2:, ...])# 在通道维度上拼接第一部分、第二部分和第三部分的输出x = torch.cat([x_0[:, :self.embed_dims_0, ...], x_1, x_2], dim=1)x = self.PW_conv(x) # 使用1x1卷积融合拼接后的多分支特征return xclass MultiOrderGatedAggregation(nn.Module):"""实现多阶门控聚合的空间模块。参数:embed_dims (int): 输入特征图的通道数。attn_dw_dilation (list): 三个深度卷积层的膨胀因子。attn_channel_split (list): 各分支的通道分配比例。attn_force_fp32 (bool): 是否强制以FP32计算,默认为False。"""def __init__(self,embed_dims,attn_dw_dilation=[1, 2, 3],attn_channel_split=[1, 3, 4],attn_force_fp32=False,):super(MultiOrderGatedAggregation, self).__init__()self.embed_dims = embed_dimsself.attn_force_fp32 = attn_force_fp32# 第一个1x1卷积,用于初步特征投影self.proj_1 = nn.Conv2d(in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)# 门控分支:生成门控系数self.gate = nn.Conv2d(in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)# 值分支:通过多阶深度卷积提取特征self.value = MultiOrderDWConv(embed_dims=embed_dims,dw_dilation=attn_dw_dilation,channel_split=attn_channel_split,)# 最终融合的1x1卷积层self.proj_2 = nn.Conv2d(in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)# 使用SiLU激活函数分别激活门控和值分支self.act_value = nn.SiLU()self.act_gate = nn.SiLU()# 使用ElementScale对特征进行微调分解self.sigma = ElementScale(embed_dims, init_value=1e-5, requires_grad=True)# ai缝合大王def feat_decompose(self, x):# 通过1x1卷积先进行特征投影x = self.proj_1(x)# 对投影后的特征进行全局平均池化,获得全局统计信息x_d = F.adaptive_avg_pool2d(x, output_size=1)# 利用sigma对原始特征与全局均值的差异进行微调,并将结果加回原始特征中x = x + self.sigma(x - x_d)x = self.act_value(x) # 应用激活函数,此处设计可根据需求调整return xdef forward(self, x):shortcut = x.clone() # 保存输入以便后续残差连接# 蓝色框部分:通过特征分解模块调整特征x = self.feat_decompose(x)# 灰色框部分:分别生成门控系数F和值特征GF_branch = self.gate(x)G_branch = self.value(x)# 分别对F和值特征应用SiLU激活后逐元素相乘,并通过proj_2融合x = self.proj_2(self.act_gate(F_branch) * self.act_gate(G_branch))x = x + shortcut # 添加残差连接以保留原始信息return xif __name__ == '__main__':input = torch.randn(1, 64, 32, 32) # 生成一个随机输入张量,形状为(1, 64, 32, 32)MOGA = MultiOrderGatedAggregation(64) # 实例化多阶门控聚合模块,设定通道数为64output = MOGA(input)print('MOGA_input_size:', input.size())print('MOGA_output_size:', output.size())
七、实验总结
-
MogaNet 在多个视觉任务上表现出 强大的性能与参数效率。
-
相比 ViT,其计算量更低,收敛更快,泛化性更强。
-
相比 ConvNet,能够更好地建模中阶交互,从而提升表现。