【Block总结】PagFM,像素注意力引导融合模块|即插即用
论文总结:PIDNet: A Real-time Semantic Segmentation Network Inspired by PID
论文信息
- 标题: PIDNet: A Real-time Semantic Segmentation Network Inspired by PID
- 作者: 论文地址
- 研究领域: 实时语义分割
- 应用场景: 自动驾驶、机器人视觉、嵌入式设备等对实时性要求高的场景。
创新点
-
引入PID控制器思想:
- 首次将比例-积分-微分(PID)控制器的概念引入深度学习语义分割任务。
- 传统双分支网络(TBN)被视为比例-积分(PI)控制器,但PI控制器存在“过冲”问题。
- 通过增加一个导数(D)分支,模拟PID控制器,缓解过冲问题,提升分割性能。
-
提出PIDNet三分支架构:
- P分支(Proportional Branch): 专注于解析和保留高分辨率的细节信息。
- I分支(Integral Branch): 聚合全局上下文信息,捕获长距离依赖关系。
- D分支(Derivative Branch): 提取高频特征,用于精确预测边界。
-
设计高效模块:
- Pag模块(Pixel-attention-guided fusion module): 实现从I分支到P分支的信息传递,避免细节特征被上下文信息淹没。
- Bag模块(Boundary-attention-guided fusion module): 通过边界注意力机制优化P分支与I分支特征的融合,平衡细节与上下文信息。
- PAPPM模块(Parallel Aggregation PPM): 改进上下文聚合模块,通过并行结构加快计算,提升实时性能。
-
推理速度与精度的权衡:
- 在Cityscapes和CamVid数据集上,PIDNet在保持高推理速度的同时实现了更高的分割精度,优于现有实时语义分割模型。
方法
1. 模型整体架构
PIDNet是一种三分支架构的实时语义分割模型,分别通过P分支解析细节、I分支聚合上下文、D分支提取边界信息,同时利用Pag和Bag模块优化细节与上下文的融合,PAPPM模块加速上下文聚合。
-
三分支设计:
- P分支(Proportional Branch): 处理高分辨率的细节特征,聚焦于像素级信息的解析,用于保留目标的几何形状和边界。
- I分支(Integral Branch): 聚合局部和全局的上下文信息,用于捕获长距离依赖关系,解析全局语义。
- D分支(Derivative Branch): 预测边界信息,提取高频特征,通过检测边界信息辅助细节和上下文特征的融合,提升分割精度。
-
模块设计:
- Pag模块: 基于像素级注意力机制,有选择性地融合有用的语义信息,实现从I分支到P分支的信息传递,避免细节被上下文信息淹没。
- Bag模块: 在融合细节和上下文特征时,引入边界注意力,增强小目标和边界区域的细节特征。
- PAPPM模块: 用于快速聚合多尺度上下文信息,通过并行化结构减少计算时间,提升实时性。
2. 网络运行流程
- 输入图像经过主干网络(基于残差块的骨干网络),提取初步特征。
- 特征分别流向三个分支:
- P分支: 解析局部细节特征。
- I分支: 聚合全局上下文信息。
- D分支: 提取边界信息,用于辅助优化。
- 在输出阶段,利用Bag模块结合D分支的边界信息,将P分支和I分支的特征进行加权融合,生成最终的分割结果。
3. 损失函数设计
- 多任务损失:
- 语义损失(Cross-Entropy Loss, l₀): 优化全局分割任务。
- 边界损失(Weighted Binary Cross-Entropy Loss, l₁): 强调边界区域的重要性。
- 边界感知损失(Boundary-Awareness CE Loss, l₂): 协调语义分割与边界检测的功能。
- 总损失函数:
L o s s = λ 0 l 0 + λ 1 l 1 + λ 2 l 2 Loss = \lambda_0 l_0 + \lambda_1 l_1 + \lambda_2 l_2 Loss=λ0l0+λ1l1+λ2l2
4. PagFM模块详解
PagFM 是 PIDNet 中的 Pixel-attention-guided fusion module(像素注意力引导融合模块),其主要作用是实现从 I 分支到 P 分支的信息传递,同时避免细节特征被上下文信息淹没。通过像素级注意力机制,PagFM 有选择性地融合有用的语义信息,从而在细节和上下文信息之间找到平衡。
PagFM模块的作用:
-
像素级注意力机制:通过计算细节特征和上下文特征的相似性,动态调整两者的融合权重。避免上下文信息淹没细节信息。
-
信息传递:将 I 分支的上下文信息传递到 P 分支,同时保留 P 分支的细节特征。
-
高效设计:使用 1x1 卷积和降维操作,减少计算量。可选的通道注意力机制进一步增强特征融合能力。
实验与效果
1. 数据集与评估
- 数据集: Cityscapes、CamVid
- 评估指标: mIoU(平均交并比)、FPS(推理速度)
2. 消融实验
-
ADB和Bag模块的作用:
- 在双分支网络(BiSeNet和DDRNet)中引入ADB和Bag模块后,分割精度(mIoU)显著提升,但推理速度(FPS)略有下降。
- 说明ADB和Bag模块在提升精度方面效果显著,但需权衡计算代价。
-
Pag和Bag模块的作用:
- 在PIDNet-L中,结合Pag和Bag模块后,模型性能提升到最高(mIoU达到80.9%)。
- 验证了Pag模块在细节信息提取中的重要性,Bag模块在特征融合中的关键作用。
3. 性能对比
- Cityscapes数据集:
- PIDNet在mIoU和FPS上均优于现有实时语义分割模型(如BiSeNet、DDRNet)。
- CamVid数据集:
- PIDNet在小目标和边界区域的分割上表现尤为突出。
效果与应用
-
实时处理能力:
- 通过轻量化设计(如PAPPM和Bag模块),确保分割模型在嵌入式设备或资源有限的场景中也能实时运行。
-
边界优化与目标增强:
- 边界感知模块(D分支和Bag模块)在边界细节和小目标分割上表现尤为突出,适合对边界精确性要求高的任务。
-
即插即用模块:
- PIDNet作为一个即插即用模块,能够平衡细节与上下文信息:
- P分支: 提取高分辨率的细节特征,确保小目标和边界的准确性。
- I分支: 提供全局上下文信息,解析复杂场景和大目标的语义。
- D分支: 提取边界特征,强化小目标和边界的清晰度。
- PIDNet作为一个即插即用模块,能够平衡细节与上下文信息:
总结
PIDNet通过引入PID控制器的思想,设计了三分支架构(P、I、D分支)和高效模块(Pag、Bag、PAPPM),在保持高推理速度的同时实现了优异的分割精度。其在Cityscapes和CamVid等数据集上的表现表明,该模型适用于自动驾驶等对实时性和精度要求高的场景。
代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class PagFM(nn.Module):
def __init__(self, in_channels, mid_channels, after_relu=False, with_channel=False, BatchNorm=nn.BatchNorm2d):
super(PagFM, self).__init__()
self.with_channel = with_channel
self.after_relu = after_relu
self.f_x = nn.Sequential(
nn.Conv2d(in_channels, mid_channels,
kernel_size=1, bias=False),
BatchNorm(mid_channels)
)
self.f_y = nn.Sequential(
nn.Conv2d(in_channels, mid_channels,
kernel_size=1, bias=False),
BatchNorm(mid_channels)
)
if with_channel:
self.up = nn.Sequential(
nn.Conv2d(mid_channels, in_channels,
kernel_size=1, bias=False),
BatchNorm(in_channels)
)
if after_relu:
self.relu = nn.ReLU(inplace=True)
def forward(self, x, y):
input_size = x.size()
if self.after_relu:
y = self.relu(y)
x = self.relu(x)
y_q = self.f_y(y)
y_q = F.interpolate(y_q, size=[input_size[2], input_size[3]],
mode='bilinear', align_corners=False)
x_k = self.f_x(x)
if self.with_channel:
sim_map = torch.sigmoid(self.up(x_k * y_q))
else:
sim_map = torch.sigmoid(torch.sum(x_k * y_q, dim=1).unsqueeze(1))
y = F.interpolate(y, size=[input_size[2], input_size[3]],
mode='bilinear', align_corners=False)
x = (1 - sim_map) * x + sim_map * y
return x
if __name__ == "__main__":
# 定义输入张量大小(Batch、Channel、Height、Wight)
B, C, H, W = 16, 512, 40, 40
input_tensor = torch.randn(B,C,H,W) # 随机生成输入张量
input_tensor1 = torch.randn(B, C, H, W) # 随机生成输入张量
dim=C
# 创建 ARConv 实例
block = PagFM(in_channels=dim,mid_channels=dim//4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sablock = block.to(device)
print(sablock)
input_tensor = input_tensor.to(device)
input_tensor1 = input_tensor1.to(device)
# 执行前向传播
output = sablock(input_tensor,input_tensor1)
# 打印输入和输出的形状
print(f"Input: {input_tensor.shape}")
print(f"Output: {output.shape}")
代码详解
1. 模块功能概述
PagFM
(Position-wise Adaptive Gated Feature Modulation)是一种特征融合模块,通过计算两个输入特征图的空间相似性,动态调整它们的融合权重。其核心思想是:
- 特征对齐:通过插值对齐不同分辨率的特征图。
- 相似性建模:计算位置级别的相似度图(SimMap)。
- 自适应融合:根据相似度图加权融合输入特征。
2. 代码逐行解析
2.1 初始化方法 __init__
def __init__(self, in_channels, mid_channels, after_relu=False, with_channel=False, BatchNorm=nn.BatchNorm2d):
super(PagFM, self).__init__()
self.with_channel = with_channel
self.after_relu = after_relu
# 特征变换分支(对x和y分别进行通道降维)
self.f_x = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
BatchNorm(mid_channels)
)
self.f_y = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
BatchNorm(mid_channels)
)
# 通道维度的相似性建模(可选)
if with_channel:
self.up = nn.Sequential(
nn.Conv2d(mid_channels, in_channels, kernel_size=1, bias=False),
BatchNorm(in_channels)
)
# ReLU激活(可选)
if after_relu:
self.relu = nn.ReLU(inplace=True)
-
参数说明:
in_channels
:输入特征通道数。mid_channels
:中间降维后的通道数(通常为输入通道的1/4)。after_relu
:是否在特征变换前应用ReLU激活。with_channel
:是否启用通道维度的相似性建模。BatchNorm
:批归一化层类型。
-
子模块:
f_x
和f_y
:通过 1x1 卷积将输入特征降维到mid_channels
,减少计算量。up
(可选):将相似性图从mid_channels
恢复到in_channels
,用于通道级融合。relu
(可选):增强特征非线性。
2.2 前向传播 forward
def forward(self, x, y):
input_size = x.size()
# 可选:应用ReLU激活
if self.after_relu:
y = self.relu(y)
x = self.relu(x)
# 特征变换与对齐
y_q = self.f_y(y) # 对y降维
y_q = F.interpolate(y_q, size=[input_size[2], input_size[3]], mode='bilinear', align_corners=False)
x_k = self.f_x(x) # 对x降维
# 计算相似性图 (SimMap)
if self.with_channel:
# 通道级相似性:x_k * y_q → 卷积恢复通道 → Sigmoid
sim_map = torch.sigmoid(self.up(x_k * y_q))
else:
# 空间级相似性:通道求和 → 添加维度 → Sigmoid
sim_map = torch.sigmoid(torch.sum(x_k * y_q, dim=1).unsqueeze(1))
# 特征融合
y = F.interpolate(y, size=[input_size[2], input_size[3]], mode='bilinear', align_corners=False)
x = (1 - sim_map) * x + sim_map * y # 加权融合
return x
- 步骤解析:
- 特征对齐:通过双线性插值将
y
的分辨率调整到与x
相同。 - 相似性计算:
- 通道级(
with_channel=True
):逐通道计算相似度,保留通道信息。 - 空间级(
with_channel=False
):跨通道求和,生成空间注意力图。
- 通道级(
- 特征融合:根据相似度图动态融合
x
和y
,公式为x = (1 - sim_map) * x + sim_map * y
。
- 特征对齐:通过双线性插值将
3. 关键组件解析
3.1 1x1 卷积的作用
- 降维:减少特征通道数,降低计算复杂度。
- 特征变换:将输入特征映射到同一语义空间,便于相似性计算。
3.2 相似性图(SimMap)
- 生成方式:
- 空间级:跨通道求和,生成单通道的注意力图,关注空间位置的重要性。
- 通道级:保留通道维度,生成多通道的相似性图,增强通道间的交互。
- Sigmoid 激活:将相似度值压缩到 [0,1] 区间,作为加权系数。
3.3 双线性插值
- 功能:对齐不同分辨率的特征图(如
y
可能来自深层网络,分辨率较低)。 - 模式选择:
align_corners=False
避免边缘像素对齐问题,适用于特征图插值。
4. 示例与测试
if __name__ == "__main__":
B, C, H, W = 16, 512, 40, 40
input_tensor = torch.randn(B, C, H, W) # 主特征(如浅层特征)
input_tensor1 = torch.randn(B, C, H//2, W//2) # 辅助特征(如深层特征)
# 创建 PagFM 实例(中间通道为128)
block = PagFM(in_channels=C, mid_channels=C//4, with_channel=True)
output = block(input_tensor, input_tensor1)
print(f"Input shape: {input_tensor.shape}") # [16, 512, 40, 40]
print(f"Output shape: {output.shape}") # [16, 512, 40, 40]
- 输入:两个不同分辨率的特征图(
input_tensor1
被插值到与input_tensor
相同尺寸)。 - 输出:融合后的特征图,尺寸与
input_tensor
一致。