当前位置: 首页 > news >正文

[sam2图像分割] mask_decoder | TwoWayTransformer

第六章:掩膜解码器

欢迎回来

在第五章:提示编码器中,我们看到了SAM-2如何理解你的具体指令(如点击或方框)。现在SAM-2已经能"看"到视觉世界(多亏了图像编码器)并"听懂"你的命令(通过提示编码器),接下来呢?它需要真正执行你要求的任务:绘制分割掩膜!

SAM-2的"艺术家"

想象你有一位技艺高超的艺术家,需要绘制一个特定对象。你向他们展示场景的照片(来自图像编码器的图像特征),并明确告诉他们要画什么以及在哪里画(来自提示编码器的提示嵌入)。艺术家随后综合所有信息,理解你的要求,并细致地绘制出目标对象。

掩膜解码器就是SAM2基础模型中的这位艺术家。它的主要工作是结合图像的高级视觉理解和你的提示指令,智能地**"绘制"实际的分割掩膜**,围绕你指定的对象。它还会给你一个评分,告诉你它对自己的绘制结果有多自信!

解决的问题

掩膜解码器解决的核心问题是**基于复杂的视觉信息和用户提供的提示,生成精确的像素级分割掩膜**。

它是连接AI对图像的内部理解与你的高级请求的桥梁,将抽象的数字转化为对象的具体轮廓。它确保当你点击一只猫时,你会得到一个漂亮且准确的仅包含猫的掩膜,而不是背景或其他对象的一部分。

概念

让我们探索掩膜解码器背后的关键思想:

  1. 信息整合:掩膜解码器是终极的混合器。它接收两种主要成分:

    • 图像特征:来自图像编码器的整个图像的详细视觉理解
    • 提示嵌入:你的具体指令(如点击、方框或掩膜),由提示编码器转化为AI的语言
  2. Transformer"大脑":掩膜解码器的核心是一个强大的神经网络,称为Transformer,具体来说是一个TwoWayTransformer这个Transformer非常擅长让不同的信息(如图像特征和提示嵌入)“对话”。它找出你的提示与图像特征的关系,突出图像中与你的请求相关的部分。

  3. 特殊的绘制"令牌":掩膜解码器使用特殊的数值"绘制工具",称为掩膜令牌IoU令牌

    • 掩膜令牌:这些就像不同的画笔或风格。Transformer处理这些令牌以及你的提示和图像特征,学习用它们生成不同的掩膜预测。SAM-2通常可以为同一对象生成几个略有不同的掩膜,以处理模糊情况(例如重叠对象)。
    • IoU令牌:这是一个特殊令牌,模型用它预测每个生成掩膜的置信度分数。这个分数告诉你模型认为它的掩膜预测有多好(接近1.0的分数表示高置信度)。
  4. 掩膜上采样:Transformer在低分辨率表示上运行以提高效率。生成低分辨率掩膜后,掩膜解码器使用一系列上采样层(如ConvTranspose2d)将掩膜"放大"回输入图像的原始尺寸。这确保最终掩膜细节丰富且完美贴合。

掩膜解码器的使用方式

你不会直接在代码中调用MaskDecoder。相反,它是SAM2基础模型的重要内部组件,当你使用高级工具(如SAM2ImagePredictor)的predict()方法时会被激活

让我们回顾第一章:SAM2图像预测器中的SAM2ImagePredictor示例,看看掩膜解码器在哪里发挥作用:

from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.build_sam import build_sam2_hf
import numpy as np
from PIL import Image# 1. 加载核心SAM-2模型和预测器(如前几章所示)
sam_model = build_sam2_hf(model_id="facebook/sam2-hiera-large")
predictor = SAM2ImagePredictor(sam_model)# 2. 设置虚拟图像(图像编码器已内部处理)
my_image = np.zeros((256, 256, 3), dtype=np.uint8)
predictor.set_image(my_image)# 3. 提供提示(提示编码器已内部翻译)
point_coords = np.array([[128, 128]]) # 中心附近的点击
point_labels = np.array([1])         # 标签1表示前景# 此调用*内部*激活掩膜解码器!
masks, scores, low_res_masks_logits = predictor.predict(point_coords=point_coords,point_labels=point_labels,multimask_output=False # 我们要求一个主掩膜以简化
)print(f"预测掩膜形状:{masks.shape}")
print(f"掩膜的置信度分数(IoU):{scores.item():.2f}")
print("掩膜解码器已成功绘制掩膜并给出置信度分数")

说明:当你调用predictor.predict()并传入point_coordspoint_labels时,SAM2ImagePredictor会将这些与预计算的图像特征一起发送给SAM2Base Model

SAM2Base Model内部,掩膜解码器接收这些输入,处理它们,生成原始掩膜预测(low_res_masks_logits),并估计其质量(scores)。

SAM2ImagePredictor随后对这些原始输出进行后处理,给你最终的干净masksscores

幕后揭秘:掩膜解码器的工作原理

让我们揭开层层迷雾,了解掩膜解码器如何协调这一"绘制"过程。

工作流程

SAM2ImagePredictor要求SAM2Base Model生成掩膜时,以下是掩膜解码器内部的简化操作:

  1. 输入到达掩膜解码器接收处理后的image_embeddings(来自图像编码器)、image_pe(图像的位置编码)以及sparse_prompt_embeddingsdense_prompt_embeddings(来自提示编码器)。
  2. 准备Transformer输入:它为内部的TwoWayTransformer准备一组"令牌"。这些令牌包括iou_token(用于置信度预测)、mask_tokens(用于生成掩膜)和你的sparse_prompt_embeddings。它还将dense_prompt_embeddingsimage_pe直接添加到image_embeddings中以丰富它们。
  3. Transformer交互:所有这些输入被送入TwoWayTransformer。这个强大的组件执行多轮"注意力"计算,让提示信息影响图像特征,反之亦然。它有效地在图像特征中找到提示描述的对象。Transformer输出精炼后的令牌和图像特征。
  4. 生成掩膜系数:从精炼的mask_tokens(由Transformer输出)中,掩膜解码器使用小型神经网络(output_hypernetworks_mlps)生成"超网络系数"。可以将其视为绘制掩膜的具体笔触指令。
  5. 上采样与合并:同时,掩膜解码器使用output_upscaling网络获取精炼的图像特征并提高其分辨率。这些上采样后的图像特征就像一张高分辨率画布。生成的"超网络系数"随后与这张高分辨率画布结合,"绘制"出详细的low_res_masks
  6. 预测置信度:精炼的iou_token(也由Transformer输出)被发送到iou_prediction_head(一个小型MLP)。这个头预测iou_predictions,即每个生成掩膜的置信度分数。
  7. 输出掩膜解码器随后将这些原始masksiou_predictions返回给SAM2Base Model,后者将它们传递给SAM2ImagePredictor进行最终调整和清理,然后呈现给你。

以下是此流程的简化序列图:

在这里插入图片描述

代码

在这里插入图片描述

让我们看看sam2/modeling/sam/mask_decoder.pysam2/modeling/sam/transformer.py文件中的关键部分,了解这些步骤如何实现。

  1. 掩膜解码器初始化(__init__
    掩膜解码器创建时(作为SAM2Base Model的一部分),它会设置其主要组件:

    # 摘自sam2/modeling/sam/mask_decoder.py(简化版)
    class MaskDecoder(nn.Module):def __init__(self,transformer_dim: int,transformer: nn.Module, # 这是TwoWayTransformer!num_multimask_outputs: int = 3,# ... 其他参数 ...) -> None:super().__init__()self.transformer = transformer # 核心逻辑组合器self.iou_token = nn.Embedding(1, transformer_dim) # 用于置信度分数的特殊令牌self.num_mask_tokens = num_multimask_outputs + 1self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) # 用于绘制掩膜的特殊令牌self.output_upscaling = nn.Sequential( # 提高掩膜分辨率的网络nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),LayerNorm2d(transformer_dim // 4), # 帮助稳定训练nn.GELU(), # 激活函数nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),nn.GELU(),)self.output_hypernetworks_mlps = nn.ModuleList( # 生成掩膜系数的网络[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens)])self.iou_prediction_head = MLP( # 预测置信度分数的网络transformer_dim, 256, self.num_mask_tokens, 3 # 示例维度)
    

    说明掩膜解码器初始化其核心self.transformer(混合器)。它还设置self.iou_tokenself.mask_tokens,这些是学习到的特殊数值,作为Transformer生成置信度分数和各种掩膜输出的提示

    • output_upscaling是一个小型神经网络,用于放大最终掩膜;
    • output_hypernetworks_mlps生成细粒度掩膜细节;
    • iou_prediction_head计算置信度。
  2. MaskDecoder.predict_masks(核心逻辑)
    这是掩膜生成真正发生的主要内部方法。由MaskDecoder.forward方法调用。

    # 摘自sam2/modeling/sam/mask_decoder.py(简化版)
    # 在MaskDecoder类内部
    def predict_masks(self,image_embeddings: torch.Tensor,       # 来自图像编码器image_pe: torch.Tensor,               # 图像的位置编码sparse_prompt_embeddings: torch.Tensor, # 来自提示编码器(点/框)dense_prompt_embeddings: torch.Tensor,  # 来自提示编码器(掩膜输入)# ... 其他参数 ...
    ) -> Tuple[torch.Tensor, torch.Tensor]:# 1. 为Transformer的查询输入组合所有"令牌"# 包括IoU令牌、掩膜令牌和用户的稀疏提示output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)# 扩展令牌以匹配批次大小output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)# 2. 为Transformer的键/值输入准备图像特征和位置编码b, c, h, w = image_embeddings.shape # 获取图像尺寸src = image_embeddings + dense_prompt_embeddings # 将图像特征与密集提示结合pos_src = image_pe # 使用图像位置编码# 3. 重塑并运行核心TwoWayTransformer以混合所有信息# Transformer期望B x N_tokens x C,所以将HxW展平为N_tokenssrc_flat = src.flatten(2).permute(0, 2, 1) # B x (H*W) x Cpos_src_flat = pos_src.flatten(2).permute(0, 2, 1) # B x (H*W) x Chs, src_out = self.transformer(src_flat, pos_src_flat, tokens) # hs是令牌特征,src_out是注意力后的图像特征# 4. 分离Transformer输出的IoU和掩膜令牌iou_token_out = hs[:, 0, :] # 'hs'中的第一个令牌通常是IoU令牌mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # 其余令牌是掩膜令牌# 5. 上采样精炼的图像特征(src_out)并生成掩膜预测# 将src_out重塑回类似图像的特征(B, C, H, W)src_out = src_out.transpose(1, 2).view(b, c, h, w)upscaled_embedding = self.output_upscaling(src_out) # 提高分辨率hyper_in_list = []for i in range(self.num_mask_tokens):# 每个掩膜令牌(来自mask_tokens_out)生成绘制掩膜的系数hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))hyper_in = torch.stack(hyper_in_list, dim=1) # 堆叠这些系数# 6. 将系数与上采样图像特征结合以"绘制"掩膜masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)# 7. 使用IoU令牌预测置信度分数(IoU)iou_pred = self.iou_prediction_head(iou_token_out)return masks, iou_pred # 返回原始掩膜及其置信度分数
    

    说明:此方法是核心。

它接收丰富的image_embeddingsdense_prompt_embeddings,将它们结合并重塑以供self.transformer使用。

  • 它还将你的sparse_prompt_embeddings与特殊的iou_tokenmask_tokens结合为Transformer的查询。
  • Transformer(self.transformer)随后混合所有这些信息,生成精炼的iou_token_outmask_tokens_out
  • mask_tokens_outself.output_hypernetworks_mlps用于生成系数,随后与upscaled_embedding(来自self.output_upscaling)结合以"绘制"masks
  • 最后,iou_token_outself.iou_prediction_head用于预测iou_pred(置信度分数)。

在这里插入图片描述

  1. TwoWayTransformer.forward(掩膜解码器内部)
    TwoWayTransformer是掩膜解码器内部的智能混合器。以下是其主forward方法的简化版:

    # 摘自sam2/modeling/sam/transformer.py(简化版)
    class TwoWayTransformer(nn.Module):# ... 初始化设置注意力层 ...def forward(self,image_embedding: Tensor, # 精炼的图像特征(来自图像编码器+密集提示)image_pe: Tensor,        # 图像的位置编码point_embedding: Tensor, # 所有组合令牌(IoU、掩膜、稀疏提示)) -> Tuple[Tensor, Tensor]:# 将image_embedding和image_pe从Bx(H*W)xC重塑为Bx(H*W)xC(已展平)# 'queries'是掩膜和提示令牌,'keys'是图像特征queries = point_embeddingkeys = image_embedding# 应用一系列Transformer块# 每层是一个TwoWayAttentionBlock,允许queries和keys交互for layer in self.layers:queries, keys = layer(queries=queries,     # 令牌keys=keys,           # 图像特征query_pe=point_embedding, # 令牌的位置信息key_pe=image_pe,     # 图像特征的位置信息)# 应用从令牌到图像特征的最终注意力层# ...(进一步注意力和归一化)...return queries, keys # 返回精炼的令牌和图像特征
    

    说明TwoWayTransformer接收处理后的image_embedding(来自图像编码器加上密集提示)和point_embedding(你的稀疏提示加上掩膜和IoU令牌)。

  • 它随后通过多个TwoWayAttentionBlock层传递这些输入。这些块包含自注意力(令牌与令牌对话)和交叉注意力(令牌与图像对话,图像与令牌对话),实现深度双向交互。
  • 此过程精炼queries(令牌)和keys(图像特征),使它们准备好用于掩膜解码器中的最终掩膜生成步骤。

总结

掩膜解码器是SAM-2的关键"艺术家",负责将所有信息结合起来精确绘制所需的分割掩膜

通过使用强大的Transformer架构智能结合图像编码器的视觉特征和提示编码器的指令丰富嵌入,它生成准确的掩膜预测及置信度分数。这是将请求转化为具体分割对象的最后一步。

现在我们已经了解SAM-2如何看、听和绘 单张图像,接下来探索它如何在视频中跨时间记忆对象。下一章,我们将深入SAM-2中的"记忆"概念!

下一章:记忆编码器

http://www.dtcms.com/a/532508.html

相关文章:

  • 京东面试题解析:SSO、Token与Redis交互、Dubbo负载均衡等
  • 网站建设哪家效益快做百度推广网站排名
  • RabbitMQ -- 高级特性
  • 克隆网站后台asp.net 网站数据库
  • 零基础新手小白快速了解掌握服务集群与自动化运维(十S四)储存服务-Ceph储存
  • 土壤侵蚀相关
  • 花卉网站建设规划书平台推广计划书模板范文
  • 如何使用C#编写DbContext与数据库连接
  • 从一到无穷大 #52:Lakehouse 不适用时序?打破范式 —— Catalog 架构选型复盘
  • 机器学习 (1) 监督学习
  • 从哪里找网络推广公司网站优化 毕业设计
  • Java如何将数据写入到PDF文件
  • 开发板网络配置
  • 14天备考软考-day1: 计组、操作系统(仅自用)
  • 企业网站模板包含什么有什么软件可以做网站
  • .gitignore 不生效问题——删除错误追踪的文件
  • 深度学习优化器详解
  • 做企业公示的数字证书网站wordpress有识图接口吗
  • 中国商标注册申请官网百度蜘蛛池自动收录seo
  • GitHub 热榜项目 - 日榜(2025-10-26)
  • 数据分析:指标拆解、异动归因类题目
  • 做网站需要那些软件设计建网站
  • Gorm(十二)乐观锁和悲观锁
  • neo4j图数据库笔记
  • 网页网站设计公司有哪些网站排名有什么用
  • 泉州做网站优化哪家好微信推广平台哪里找
  • 如何制作收费网站百度收录个人网站是什么怎么做
  • VsCode + Wsl:终极开发环境搭建指南
  • 深度学习——Logistic回归中的梯度下降法
  • 中国住房和城乡建设网网站学习网站大全