mmseg的decode_heads解析:理解语义分割解码器设计
引言
在语义分割任务中,解码器(Decoder) 的设计直接影响模型对特征图的上采样能力和细节恢复效果。作为开源语义分割框架 mmsegmentation 的核心组件,decode_heads 提供了多种经典和前沿的解码器实现。
本文将深入解析 mmseg 中常见的 decode_heads,包括其核心思想、结构设计、适用场景及代码实现,帮助读者理解语义分割解码器的技术脉络。
MMSEGMENTATION官方文档
深度学习pytorch之简单方法自定义9种卷积即插即用
1. 什么是decode_head?
在语义分割模型中,典型的架构为 Encoder-Decoder:
-
Encoder(如ResNet、Swin Transformer)负责提取多尺度特征。
-
Decoder(即 decode_head)负责将低分辨率特征图逐步上采样,恢复空间细节并生成最终分割掩码。
decode_head 的核心任务:融合多级特征,平衡全局语义与局部细节。
2. mmsegmentation中的decode_heads解析示例
以下按类别介绍常见解码器,结合代码与示意图说明其设计思想。
2.1 FCNHead:最基础的解码器
核心思想:直接对Encoder输出的特征图进行卷积和上采样。
结构解析:
# mmseg/models/decode_heads/fcn_head.py
class FCNHead(BaseDecodeHead):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.conv_seg = nn.Conv2d(self.channels, self.num_classes, kernel_size=1)
def forward(self, inputs):
output = self._transform_inputs(inputs) # 选择指定层特征
output = self.conv_seg(output)
output = resize(output, size=img_size, mode='bilinear')
return output
适用场景:简单分割任务(如二分类),计算资源有限时。
优缺点:
-
✅ 结构简单,计算量小。
-
❌ 无法融合多尺度信息,细节恢复能力弱。
2.2 PSPNet:金字塔池化模块
核心思想:出自PSPNet,通过金字塔池化(Pyramid Pooling Module, PPM) 捕获多尺度上下文信息。
结构解析:
#mmseg/models/decode_heads/psp_head.py
class PSPHead(BaseDecodeHead):
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super().__init__(**kwargs)
self.psp_modules = ModuleList([
PSPModule(pool_scale) for pool_scale in pool_scales
])
def forward(self, inputs):
x = self._transform_inputs(inputs)
psp_outs = [x] + [psp_module(x) for psp_module in self.psp_modules]
x = torch.cat(psp_outs, dim=1)
x = self.conv_seg(x)
return resize(x, size=img_size, mode='bilinear')
适用场景:需要全局上下文信息的任务(如场景解析)。
优缺点:
-
✅ 多尺度池化增强全局感知。
-
❌ 池化操作可能丢失局部细节。
2.3 DeepLab系列
2.3.1 DeepLabV3 & DeepLabV3+
核心思想:空洞空间金字塔池化(ASPP) + Decoder细化。
结构解析:
# mmseg/models/decode_heads/deeplabv3_head.py
class ASPPHead(BaseDecodeHead):
def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
super().__init__(**kwargs)
self.aspp_modules = ModuleList([
ASPPModule(dilation) for dilation in dilations
])
def forward(self, inputs):
x = self._transform_inputs(inputs)
aspp_outs = [module(x) for module in self.aspp_modules]
x = torch.cat(aspp_outs, dim=1)
x = self.conv_seg(x)
return resize(x, size=img_size, mode='bilinear')
出自DeepLabV3+,适用场景:复杂场景下的高精度分割(如Cityscapes)。
优缺点:
-
✅ ASPP有效扩大感受野,兼顾多尺度。
-
❌ 计算量较大,需高显存支持。
2.4 UPerNet:统一金字塔上下文融合
核心思想:通过特征金字塔网络(FPN) 融合多级特征。
结构解析:
# mmseg/models/decode_heads/uper_head.py
class UPerHead(BaseDecodeHead):
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super().__init__(**kwargs)
# 构建PPM和FPN
self.psp_modules = PPM(pool_scales)
self.fpn_modules = FPN(in_channels_list, out_channels)
def forward(self, inputs):
psp_out = self.psp_modules(inputs[-1]) # 使用最后层特征
fpn_outs = self.fpn_modules([psp_out] + inputs[:-1])
# 逐层上采样融合
...
适用场景:需要多层次特征融合的任务(如ADE20K)。
优缺点:
-
✅ 显式融合不同层特征,细节恢复能力强。
-
❌ 结构复杂,训练时间较长。
2.5 其他decode_heads
2.5.1 ANNHead(注意力引导网络)
核心思想:通过轴向注意力(Axial Attention) 增强长距离依赖建模。
适用场景:高分辨率图像分割(如医疗影像)。
2.5.2 CCHead(Contextual Contrast Head)
核心思想:引入上下文对比损失,增强类别边界区分度。
适用场景:类别边界模糊的任务(如遥感影像)。
3. 如何选择合适的decode_head?
仅对测试过的部分说明:
场景需求 | 推荐decode_head |
---|---|
简单任务,计算资源有限 | FCNHead |
需要全局上下文 | PSPHead/ASPPHead |
多层次特征融合需求高 | UPerHead |
高分辨率细节恢复 | ANNHead/PointRend |
类别边界模糊 | CCHead/BoundaryHead |
4. 自定义decode_head的实践建议
继承BaseDecodeHead:复用基础结构(如损失计算、上采样)。
from mmseg.models.decode_heads import BaseDecodeHead
class CustomHead(BaseDecodeHead):
def __init__(self, custom_param, **kwargs):
super().__init__(**kwargs)
self.custom_layer = nn.Conv2d(...)
特征融合创新:尝试跨层注意力、动态卷积等机制。
结合领域知识:如医学影像中引入形状先验。
5. 总结
mmsegmentation 的 decode_heads 提供了丰富的解码器设计范式,从基础的 FCNHead 到融合注意力机制的 ANNHead,开发者可根据任务需求灵活选择。未来趋势将更加注重轻量化设计与动态自适应特征融合。