当前位置: 首页 > news >正文

Mask2Former,分割新范式


Mask2Former前向传播详解 (基于您的描述)

整个前向传播过程是一个迭代式的优化流程,以可学习的查询(Query)为核心,在每一层解码器中不断与图像特征交互,并细化自身的预测。

核心组件
  • query: 可学习的查询嵌入。
    • Shape: (2, 27, 192)
  • pos: 对应查询的位置编码。
    • Shape: (2, 27, 192)
  • 高分辨率特征: 来自像素解码器的、用于生成精细掩码的特征图。
    • Shape: (2, 192, 224, 224)
  • 中间层特征: 来自像素解码器的、用于注意力计算的特征图,已展平。
    • Shape: (2, 192, 784)

单层解码器迭代流程

以下流程会在每个解码器层中重复执行,每一层的输出查询会作为下一层的输入查询。

第 1 步:并行的初始预测

在进行注意力计算之前,当前层的输入query会首先生成三个并行的初始预测,用于指导后续操作和最终输出。

  1. 类别预测 (output_class):

    • 过程: query 通过一个专用的线性层。
    • 目的: 表示27个查询与18个类别的专精(匹配)程度。
    • 输出Shape: (2, 27, 18)
  2. 高分辨率掩码预测 (output_mask):

    • 过程: query 首先通过一个MLP进行变换,然后其输出与高分辨率特征(2, 192, 224, 224))进行矩阵乘法。
    • 目的: 计算每个像素与27个“专家”(查询)的相似度,或者说,像素属于哪个专家的比例。
    • 输出Shape: (2, 27, 224, 224)
  3. 注意力掩码生成 (mask):

    • 过程: 上一步生成的高分辨率output_mask下采样到与中间特征层匹配的分辨率,然后展平 (flatten)
    • 目的: 生成一个低分辨率的掩码,用于指导后续的交叉注意力计算。
    • 输出Shape: (2, 27, 784)

第 2 步:掩码交叉注意力 (Masked Cross-Attention)

  • 目的: 让查询有选择性地从图像的关键区域中提取特征。
  • 过程:
    1. Q (Query): 由 query 加上其位置编码 pos 生成。
    2. K (Key), V (Value): 由中间层特征(2, 192, 784))加上其对应的位置编码生成。
    3. 执行交叉注意力操作,在计算原始注意力分数后,加上之前预测的 mask(2, 27, 784))来调整注意力分布。
    4. 归一化后的注意力权重(Shape: (2, 27, 784))乘以 V
    5. 结果与原始query进行残差连接。
  • 输出: 吸收了图像特征的查询嵌入,Shape为 (2, 27, 192)

第 3 步:自注意力 (Self-Attention)

  • 目的: 使27个查询(专家)之间交换信息,协同工作。
  • 过程:
    1. 将上一步交叉注意力的输出作为输入。
    2. 执行标准的多头自注意力。
  • 输出: 经过内部信息交互后,得到本层最终的输出查询,Shape为 (2, 27, 192)

这个输出查询将作为新的 query,进入下一个解码器层,重复上述第1至第3步的完整流程。


最终合成(推理阶段)

在经过所有解码器层的迭代后,我们取最后一层生成的output_classoutput_mask来生成最终的分割结果。

  1. 获取最终预测:

    • output_class: 代表各个专家专精各个类别的比例(或分数)。
      • Shape: (2, 27, 18)
    • output_mask: 代表每个像素属于各个专家的比例(或分数)。
      • Shape: (2, 27, 224, 224)
  2. 融合生成分割分数:

    • 过程: 将output_classoutput_mask进行矩阵乘法。这在实现上通常通过torch.einsum来高效完成。
    • 目的: 将每个专家的类别判断(权重)应用到它所标识的空间区域上,然后将所有专家的贡献累加起来。
    • 输出: 得到各个类别的分割分数图。
      • Shape: (2, 18, 224, 224)
  3. 最终决策:

    • 过程: 在最后一个维度(类别维度,即18个类别分数)上,为每个像素执行**argmax**操作。
    • 目的: 为每个像素选出得分最高的那个类别作为其最终的分类。
  • 最终输出:
    • 一张整数类型的语义分割图。
    • Shape: (2, 224, 224)

Mask2Former训练损失

Mask2Former的训练核心在于将分割任务统一为集合预测 (Set Prediction) 问题。其损失构造分为两个核心阶段:首先通过二分图匹配为每个查询预测动态分配监督目标,然后基于该分配计算具体损失。


第零步:真实标签的重构 (Ground Truth Reformation)

核心目的: Mask2Former的损失函数是为处理“实例”集合而设计的。为了用这个统一框架处理语义分割,必须在训练时将标准的像素级语义标签图,重构为一个“伪实例”集合。这并非“适配”,而是将不同任务统一到同一套输入范式下的必要步骤。

输入:

  • 一批标准的语义分割真实标签图 gt_masks
  • Shape: (B, H, W)
  • 内容: gt_masks[b, h, w] 的值是该像素的类别ID。

重构流程 (以批次中的单张图 mask 为例):

  1. 识别图中出现的语义类别:

    • 操作: torch.unique(mask)
    • 目的: 获取这张图上存在的所有类别ID的集合。
    • 示例: 对于一张包含天空(ID=1)和道路(ID=2)的图,此操作返回 tensor([0, 1, 2]) (0为背景)。
  2. 筛选有效类别ID:

    • 操作: 剔除背景类别ID(通常为0)。
    • 目的: 确定需要被监督的语义区域。
    • 示例: 从 tensor([0, 1, 2]) 筛选出 labels = tensor([1, 2])。我们现在知道这张图有 M=2 个需要学习的目标。
  3. 为每个类别生成二值掩码:

    • 操作: 遍历上一步得到的 labels。对于每个类别ID c,生成一个与原图等大的二值掩码,其中所有类别为 c 的像素值为1,其余为0。
    • 示例:
      • 为ID=1(天空)生成 sky_mask
      • 为ID=2(道路)生成 road_mask
    • 将这些掩码堆叠起来,得到一个形状为 (M, H, W) -> (2, H, W) 的掩码张量 binary_masks
  4. 构建最终的目标集合:

    • 将筛选出的类别ID labels 和生成的二值掩码 binary_masks 打包成一个字典。
    • 最终输出 target_dict: {'labels': tensor([1, 2]), 'masks': tensor(2, H, W)}
    • 这个字典就代表了这张图的“真实目标集合”,它将被送入匈牙利匹配器。

第一阶段:匈牙利匹配 (Hungarian Matching)

核心目的: 为模型的 N个查询(预测) 和图中的 M个目标(重构后的伪实例) 之间,建立一个最优的一对一匹配。这个匹配的直接结果是为M个查询指派了明确的、需要学习的目标,而其余N-M个查询则被指派了“无物体”这一特殊目标。

1. 构建代价矩阵 (Cost Matrix)

  • 为批次中的每张图独立构建一个代价矩阵 C
  • 形状: (查询数 N, 目标数 M) -> (27, M)
  • 元素 C[i, j]: 代表“将第 i 个查询”与“第 j 个目标”匹配的代价。代价越低,匹配越优。
  • 代价计算: C[i, j] 由三个子代价加权求和构成:
    • 分类代价: 基于查询 i 的类别预测 pred_logits 与目标 j 的真实类别 labels[j] 之间的差异。
    • 掩码代价: 基于查询 i 的掩码预测 pred_masks 与目标 j 的真实掩码 masks[j] 之间的像素级差异(如Focal Loss)。
    • Dice代价: 基于两个掩码在区域重叠度上的差异(Dice Loss)。

2. 运行匹配算法并输出 indices

  • 使用匈牙利算法在代价矩阵 C 上求解,找到总代价最小的 M 个匹配对。
  • 输出 indices: 一个元组列表,其中每个元素 (source_indices, target_indices) 包含了被匹配上的查询的索引和它们对应的目标的索引

第二阶段:损失构造 (Loss Computation)

有了 indices 这个“任务分配表”,就可以为所有查询计算损失。

1. 分类损失 (loss_ce)

  • 监督对象: 所有 N (27) 个查询
  • 目标构建: 这一步至关重要。
    1. 创建一个形状为 (N,) 的目标向量,初始时全部填充为“无物体”类别ID
    2. 使用 indices 进行填充:对于每一个匹配对 (查询索引i, 目标索引j),将目标向量中第 i 个位置的值,修改为目标 j 的真实类别ID labels[j]
    • 最终结果: 我们得到了一个为所有27个查询都分配了监督目标的向量。其中M个查询的目标是真实类别,N-M个查询的目标是“无物体”。
  • 损失计算: 使用交叉熵损失(或Focal Loss)计算模型的 pred_logits 与这个构建好的目标向量之间的差异。

2. 掩码损失 (Mask Loss)

  • 监督对象: 仅监督那些被 indices 匹配上的 M 个查询
  • 损失计算:
    • 对于每一个匹配对 (查询索引i, 目标索引j),提取出查询 i 预测的掩码和目标 j 的真实掩码。
    • 计算这两个掩码之间的Dice Loss二元交叉熵/Focal Loss

3. 深度监督与总损失

  • 上述的整个“重构->匹配->计算损失”流程,会在解码器的每一个中间层的输出上重复执行,这些损失被称为辅助损失。
  • 最终的总损失是最后一层损失与所有辅助损失的加权和,用于驱动整个模型的参数更新。

Q&A

  • 损失函数为什么这么设计,既然我们能通过前向传播得到分割图,那为什么不直接用分割结果与label计算损失反向传播呢,这样更直接,且预测与训练过程一致
    直接使用分割结果做损失会使专家之间职责划分不清晰。促进专家形成:这种机制鼓励不同的查询发展出不同的专长。因为一个查询如果能稳定地在某个模式(比如“车辆状物体”)上表现出色,它就更有可能在匹配中胜出,从而得到更一致的监督信号,形成正向循环。
http://www.dtcms.com/a/284717.html

相关文章:

  • Kafka 控制器(Controller)详解:架构、原理与实战
  • Python23 —— 标准库(time库)
  • c++列表初始化
  • Dijkstra 算法求解多种操作
  • Stone3D教程:免编码制作在线家居生活用品展示应用
  • 【初始Java】
  • mysql中where字段的类型转换
  • (转)Kubernetes基础介绍
  • SQL增查
  • Windows下odbc配置连接SQL Server
  • .Net将控制台的输出信息存入到日志文件按分钟生成日志文件
  • 【JavaEE进阶】使用云服务器搭建Linux环境
  • Java网络通信:UDP和TCP
  • 关于CDH以及HUE的介绍
  • vue-seo优化
  • Android构建流程与Transform任务
  • 题解:P13311 [GCJ 2012 Qualification] Speaking in Tongues
  • java面向对象-多态
  • 【前端】Power BI自动化指南:从API接入到Web嵌入
  • 旅游管理实训基地建设:筑牢文旅人才培养的实践基石
  • LeetCode热题100—— 238. 除自身以外数组的乘积
  • Pygame创建窗口教程 - 从入门到实践 | Python游戏开发指南
  • 小白学Python,网络爬虫篇(1)——requests库
  • java Integer怎么获取长度
  • 【Jmeter】报错:An error occured:Unknown arg
  • 3.PCL点云合并
  • 为什么选择Selenium自动化测试?
  • 接口黑洞?破!安全堡垒?筑!冰火炼狱?战!MES7114W终极掌控
  • 学习C++、QT---27(QT中实现记事本项目实现行列显示、优化保存文件的功能的讲解)
  • 三、CV_VGGnet