当前位置: 首页 > news >正文

分割模型Maskformer

MaskFormer

背景

语义分割:任务是为图像中每一个像素分配一个类别标签,传统方式通常视为逐像素分类,模型会输出一个与输入图像尺寸相同的特征图,每个位置是一个类别概率向量。

实例分割:任务不仅需要区分类别,还要区分同一类别的不同个体,传统上,这类任务通常使用掩码分类,即模型先检测出物体框,再为每个框预测一个二进制掩码

本文提出:

掩码分类(mask classification)本身就足够通用,可以用完全相同的模型、损失函数和训练流程,以统一的方式同时解决语义级和实例级的分割问题。

语义分割=实例分割+实例分类
在这里插入图片描述

逐像素分类(左边)

  • 模型最终输出是 [H x W x 类别数]
  • 使用“逐像素分类损失”

掩码分类(右边):

  • 掩码分类预测一组二进制掩码,并为每个掩码分配一个类。
  • 每个像素的二进制掩码损失和分类损失
模型结构

在这里插入图片描述
MaskFormer包含三个模块:

  • 像素级模块:其职责是处理输入图像,提取图像特征,并生成高分辨率,精细的像素级特征表示,这特征是后续生成掩码的基础
  • Transformer模块:它接收来自像素级模块的信息以及一组可学习的查询(query),通过自注意力和交叉注意力机制,输出 N 个全局的、抽象的特征向量。每个向量都编码了图像中某个潜在物体或区域的全局信息。
  • 分割模块:这是一个轻量级的预测头,它将 Transformer 模块输出的每个抽象特征向量,分别转换为:
    • 一个类别概率分布(pip_ipi):预测这个向量所代表的区域属于哪个类别
    • 一个二进制掩码(mjm_jmj):预测这个区域在图像中具体的像素级位置

在推理时,模型输出的就是这N个对,再通过简单的规则,就能够组装成最终的语义分割图或者实例分割图

像素级模块
  • Backbone:用于提取图像特征,其输出通常是一个空间分辨率较低,但通道数丰富,语义信息强的特征图
  • 像素解码器(Pixel Decoder):负责将Backbone输出的低分辨率特征图逐步上采样,恢复到与原图相同的大小(H x W),这个过程的输出不再是简单的特征图,而是被称作 “逐像素嵌入”,它为每个像素位置都赋予了一个特征向量
Transformer 模块
  • 输入
    • 由Backbone提取的,包含丰富视觉信息的特征图
    • 可学习查询向量(Queries):N 个可训练的向量,可以理解为模型需要寻找的“N种不同目标或者区域的模板”
  • 过程:通过Transformer解码器的交叉注意力机制,每个“查询”都会主动地去“查询”和“收集”整个图像特征 F中与自己相关的信息。
  • 输出:经过多层计算后,每个查询向量都变成了一个信息丰富的“每片段嵌入” Q,它编码了某个特定目标的全局信息
分割模块:从抽象嵌入到具体预测

此模块负责将Transformer输出的抽象嵌入Q解码为具体的类别和掩码预测,它包含两条并行的通路:

  • 类别预测通路

    • 一个线性分类器(全连接层)接一个softmax函数,直接作用在每个片段嵌入 Q上,输出一个 K+1维的概率分布 pip_ipi,表示这个片段属于各个类别(包括“无对象”)的概率。
  • 掩码预测通路

    • 步骤一(生成掩码嵌入): 用一个小的MLP将每个全局的片段嵌入 Q转换为一个“掩码嵌入” E_mask。这个掩码嵌入可以看作是该目标掩码的特征编码。
    • 步骤2(生成掩码本身): 这是非常巧妙的一步。掩码的生成是通过计算掩码嵌入 E_mask 与像素级模块输出的高分辨率逐像素嵌入 E_pixel的点积(相似度)来实现的。
损失函数:
  • 分类损失:标准的交叉熵损失,用于优化类别预测的准确性
  • 掩码损失: 用于优化预测掩码的形状准确性
    • Focal Loss:对难分类的像素点(例如边界)给予更高的关注,优化掩码的细节
    • Dice Loss:接优化预测掩码和真实掩码之间的重叠面积(交并比),非常适用于评估分割效果
http://www.dtcms.com/a/391497.html

相关文章:

  • C# TCP的方式 实现上传文件
  • 高压消解罐:难溶物质消解的首选工具
  • JavaScript 字符串截取最后一位的几种方法
  • MobileNetV3训练自定义数据集并通过C++进行推理模型部署
  • nvshmem源码学习(一)ibgda视角的整体流程
  • Redis群集的三种模式
  • 鸿蒙(南向/北向)
  • Spring IoCDI 快速入门
  • MySQL的C语言驱动核心——`mysql_real_connect()` 函数
  • C++线程池学习 Day06
  • React 样式CSS的定义 多种定义方式 前端基础
  • react+anddesign组件Tabs实现后台管理系统自定义页签头
  • Midscene 低代码实现Android自动化
  • ADB使用指南
  • FunCaptcha如何查找sitekey参数
  • 大模型如何让机器人实现“从冰箱里拿一瓶可乐”?
  • Python实现液体蒸发优化算法 (Evaporation Rate Water Cycle Algorithm, ER-WCA)(附完整代码)
  • MySQL 数据库的「超级钥匙」—`mysql_real_connect`
  • LeetCode 每日一题 3484. 设计电子表格
  • RAGAS深度解析:引领RAG评估新时代的开源技术革命
  • aave v3.4 利率计算详解
  • rook-ceph CRD资源配置时效问题
  • MySQL学习笔记-进阶篇
  • Rust 关键字
  • 排版使用latex排版还是word排版更容易通过mdpi remote sensing的审稿?
  • Qt QML ToolTip弹出方向控制问题探讨
  • [Windows] PDFQFZ(PDF加盖骑缝章) v1.31
  • 四网络层IP-子网掩码-路由表-真题
  • 安装QT6.9.2
  • 使用 NodePort