Mask2Former,分割新范式
Mask2Former前向传播详解 (基于您的描述)
整个前向传播过程是一个迭代式的优化流程,以可学习的查询(Query)为核心,在每一层解码器中不断与图像特征交互,并细化自身的预测。
核心组件
query
: 可学习的查询嵌入。- Shape:
(2, 27, 192)
- Shape:
pos
: 对应查询的位置编码。- Shape:
(2, 27, 192)
- Shape:
- 高分辨率特征: 来自像素解码器的、用于生成精细掩码的特征图。
- Shape:
(2, 192, 224, 224)
- Shape:
- 中间层特征: 来自像素解码器的、用于注意力计算的特征图,已展平。
- Shape:
(2, 192, 784)
- Shape:
单层解码器迭代流程
以下流程会在每个解码器层中重复执行,每一层的输出查询会作为下一层的输入查询。
第 1 步:并行的初始预测
在进行注意力计算之前,当前层的输入query
会首先生成三个并行的初始预测,用于指导后续操作和最终输出。
-
类别预测 (
output_class
):- 过程:
query
通过一个专用的线性层。 - 目的: 表示27个查询与18个类别的专精(匹配)程度。
- 输出Shape:
(2, 27, 18)
- 过程:
-
高分辨率掩码预测 (
output_mask
):- 过程:
query
首先通过一个MLP进行变换,然后其输出与高分辨率特征((2, 192, 224, 224)
)进行矩阵乘法。 - 目的: 计算每个像素与27个“专家”(查询)的相似度,或者说,像素属于哪个专家的比例。
- 输出Shape:
(2, 27, 224, 224)
- 过程:
-
注意力掩码生成 (
mask
):- 过程: 上一步生成的高分辨率
output_mask
被下采样到与中间特征层匹配的分辨率,然后展平 (flatten)。 - 目的: 生成一个低分辨率的掩码,用于指导后续的交叉注意力计算。
- 输出Shape:
(2, 27, 784)
- 过程: 上一步生成的高分辨率
第 2 步:掩码交叉注意力 (Masked Cross-Attention)
- 目的: 让查询有选择性地从图像的关键区域中提取特征。
- 过程:
- Q (Query): 由
query
加上其位置编码pos
生成。 - K (Key), V (Value): 由中间层特征(
(2, 192, 784)
)加上其对应的位置编码生成。 - 执行交叉注意力操作,在计算原始注意力分数后,加上之前预测的
mask
((2, 27, 784)
)来调整注意力分布。 - 归一化后的注意力权重(Shape:
(2, 27, 784)
)乘以V
。 - 结果与原始
query
进行残差连接。
- Q (Query): 由
- 输出: 吸收了图像特征的查询嵌入,Shape为
(2, 27, 192)
。
第 3 步:自注意力 (Self-Attention)
- 目的: 使27个查询(专家)之间交换信息,协同工作。
- 过程:
- 将上一步交叉注意力的输出作为输入。
- 执行标准的多头自注意力。
- 输出: 经过内部信息交互后,得到本层最终的输出查询,Shape为
(2, 27, 192)
。
这个输出查询将作为新的 query
,进入下一个解码器层,重复上述第1至第3步的完整流程。
最终合成(推理阶段)
在经过所有解码器层的迭代后,我们取最后一层生成的output_class
和output_mask
来生成最终的分割结果。
-
获取最终预测:
output_class
: 代表各个专家专精各个类别的比例(或分数)。- Shape:
(2, 27, 18)
- Shape:
output_mask
: 代表每个像素属于各个专家的比例(或分数)。- Shape:
(2, 27, 224, 224)
- Shape:
-
融合生成分割分数:
- 过程: 将
output_class
与output_mask
进行矩阵乘法。这在实现上通常通过torch.einsum
来高效完成。 - 目的: 将每个专家的类别判断(权重)应用到它所标识的空间区域上,然后将所有专家的贡献累加起来。
- 输出: 得到各个类别的分割分数图。
- Shape:
(2, 18, 224, 224)
- Shape:
- 过程: 将
-
最终决策:
- 过程: 在最后一个维度(类别维度,即18个类别分数)上,为每个像素执行**
argmax
**操作。 - 目的: 为每个像素选出得分最高的那个类别作为其最终的分类。
- 过程: 在最后一个维度(类别维度,即18个类别分数)上,为每个像素执行**
- 最终输出:
- 一张整数类型的语义分割图。
- Shape:
(2, 224, 224)
Mask2Former训练损失
Mask2Former的训练核心在于将分割任务统一为集合预测 (Set Prediction) 问题。其损失构造分为两个核心阶段:首先通过二分图匹配为每个查询预测动态分配监督目标,然后基于该分配计算具体损失。
第零步:真实标签的重构 (Ground Truth Reformation)
核心目的: Mask2Former的损失函数是为处理“实例”集合而设计的。为了用这个统一框架处理语义分割,必须在训练时将标准的像素级语义标签图,重构为一个“伪实例”集合。这并非“适配”,而是将不同任务统一到同一套输入范式下的必要步骤。
输入:
- 一批标准的语义分割真实标签图
gt_masks
。 - Shape:
(B, H, W)
- 内容:
gt_masks[b, h, w]
的值是该像素的类别ID。
重构流程 (以批次中的单张图 mask
为例):
-
识别图中出现的语义类别:
- 操作:
torch.unique(mask)
- 目的: 获取这张图上存在的所有类别ID的集合。
- 示例: 对于一张包含天空(ID=1)和道路(ID=2)的图,此操作返回
tensor([0, 1, 2])
(0为背景)。
- 操作:
-
筛选有效类别ID:
- 操作: 剔除背景类别ID(通常为0)。
- 目的: 确定需要被监督的语义区域。
- 示例: 从
tensor([0, 1, 2])
筛选出labels = tensor([1, 2])
。我们现在知道这张图有 M=2 个需要学习的目标。
-
为每个类别生成二值掩码:
- 操作: 遍历上一步得到的
labels
。对于每个类别IDc
,生成一个与原图等大的二值掩码,其中所有类别为c
的像素值为1,其余为0。 - 示例:
- 为ID=1(天空)生成
sky_mask
。 - 为ID=2(道路)生成
road_mask
。
- 为ID=1(天空)生成
- 将这些掩码堆叠起来,得到一个形状为
(M, H, W)
->(2, H, W)
的掩码张量binary_masks
。
- 操作: 遍历上一步得到的
-
构建最终的目标集合:
- 将筛选出的类别ID
labels
和生成的二值掩码binary_masks
打包成一个字典。 - 最终输出
target_dict
:{'labels': tensor([1, 2]), 'masks': tensor(2, H, W)}
- 这个字典就代表了这张图的“真实目标集合”,它将被送入匈牙利匹配器。
- 将筛选出的类别ID
第一阶段:匈牙利匹配 (Hungarian Matching)
核心目的: 为模型的 N个查询(预测) 和图中的 M个目标(重构后的伪实例) 之间,建立一个最优的一对一匹配。这个匹配的直接结果是为M个查询指派了明确的、需要学习的目标,而其余N-M个查询则被指派了“无物体”这一特殊目标。
1. 构建代价矩阵 (Cost Matrix)
- 为批次中的每张图独立构建一个代价矩阵
C
。 - 形状:
(查询数 N, 目标数 M)
->(27, M)
。 - 元素
C[i, j]
: 代表“将第i
个查询”与“第j
个目标”匹配的代价。代价越低,匹配越优。 - 代价计算:
C[i, j]
由三个子代价加权求和构成:- 分类代价: 基于查询
i
的类别预测pred_logits
与目标j
的真实类别labels[j]
之间的差异。 - 掩码代价: 基于查询
i
的掩码预测pred_masks
与目标j
的真实掩码masks[j]
之间的像素级差异(如Focal Loss)。 - Dice代价: 基于两个掩码在区域重叠度上的差异(Dice Loss)。
- 分类代价: 基于查询
2. 运行匹配算法并输出 indices
- 使用匈牙利算法在代价矩阵
C
上求解,找到总代价最小的M
个匹配对。 - 输出
indices
: 一个元组列表,其中每个元素(source_indices, target_indices)
包含了被匹配上的查询的索引和它们对应的目标的索引。
第二阶段:损失构造 (Loss Computation)
有了 indices
这个“任务分配表”,就可以为所有查询计算损失。
1. 分类损失 (loss_ce
)
- 监督对象: 所有 N (27) 个查询。
- 目标构建: 这一步至关重要。
- 创建一个形状为
(N,)
的目标向量,初始时全部填充为“无物体”类别ID。 - 使用
indices
进行填充:对于每一个匹配对(查询索引i, 目标索引j)
,将目标向量中第i
个位置的值,修改为目标j
的真实类别IDlabels[j]
。
- 最终结果: 我们得到了一个为所有27个查询都分配了监督目标的向量。其中M个查询的目标是真实类别,N-M个查询的目标是“无物体”。
- 创建一个形状为
- 损失计算: 使用交叉熵损失(或Focal Loss)计算模型的
pred_logits
与这个构建好的目标向量之间的差异。
2. 掩码损失 (Mask Loss)
- 监督对象: 仅监督那些被
indices
匹配上的 M 个查询。 - 损失计算:
- 对于每一个匹配对
(查询索引i, 目标索引j)
,提取出查询i
预测的掩码和目标j
的真实掩码。 - 计算这两个掩码之间的Dice Loss和二元交叉熵/Focal Loss。
- 对于每一个匹配对
3. 深度监督与总损失
- 上述的整个“重构->匹配->计算损失”流程,会在解码器的每一个中间层的输出上重复执行,这些损失被称为辅助损失。
- 最终的总损失是最后一层损失与所有辅助损失的加权和,用于驱动整个模型的参数更新。
Q&A
- 损失函数为什么这么设计,既然我们能通过前向传播得到分割图,那为什么不直接用分割结果与label计算损失反向传播呢,这样更直接,且预测与训练过程一致
直接使用分割结果做损失会使专家之间职责划分不清晰。促进专家形成:这种机制鼓励不同的查询发展出不同的专长。因为一个查询如果能稳定地在某个模式(比如“车辆状物体”)上表现出色,它就更有可能在匹配中胜出,从而得到更一致的监督信号,形成正向循环。