SAM2学习笔记
1. Prompt
是否可以请您参考PyTorch的文档格式和文档风格,使用Markdown格式为 amg.generate函数编写一段相应的文档说明呢?
2. Documentation
2.1 SAM2AutomaticMaskGenerator.generate
SAM2AutomaticMaskGenerator.generate(image)
使用SAM2模型对整张图像进行全自动实例分割,无需任何提示。该方法在图像上生成点网格作为提示,然后过滤低质量和重复的掩码。
Parameters
- image (np.ndarray) – 待分割的输入图像,格式为
HWC(高度×宽度×通道),数据类型为uint8,通道顺序为RGB。
Returns
包含所有检测到的实例掩码的列表,列表中的每个元素为一个字典,包含以下键值对:
-
segmentation (np.ndarray or dict) – 分割掩码。
- 当
output_mode='binary_mask'时:形状为(H, W)的二值数组 - 当
output_mode='coco_rle'或'uncompressed_rle'时:包含 RLE 编码的字典
- 当
-
bbox (list[float]) – 掩码的边界框,格式为
[x, y, width, height](XYWH) -
area (int) – 掩码区域的像素数量
-
predicted_iou (float) – 模型预测的掩码质量分数,范围
[0, 1],已通过pred_iou_thresh阈值过滤 -
point_coords (list[list[float]]) – 用于生成该掩码的提示点坐标
[[x, y]] -
stability_score (float) – 掩码稳定性分数,范围
[0, 1],已通过stability_score_thresh阈值过滤 -
crop_box (list[float]) – 生成掩码时使用的图像裁剪区域,格式为
[x, y, width, height](XYWH)
Return type: list[dict[str, Any]]
Shape
- Input:
(H, W, 3)其中H和W可以是任意尺寸 - Output segmentation:
(H, W)与输入图像尺寸相同
Examples
基础用法
import cv2
import numpy as np
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator# 构建模型
sam2 = build_sam2(config_file="configs/sam2.1/sam2.1_hiera_t.yaml",ckpt_path="checkpoints/sam2.1_hiera_tiny.pt"
)# 初始化自动掩码生成器
mask_generator = SAM2AutomaticMaskGenerator(sam2)# 读取图像(注意:cv2.IMREAD_COLOR_RGB 需要 OpenCV >= 4.5)
image = cv2.imread("image.jpg", cv2.IMREAD_COLOR_RGB)# 生成掩码
masks = mask_generator.generate(image)print(f"检测到 {len(masks)} 个实例")
for i, mask_data in enumerate(masks):print(f"实例 {i}: 面积={mask_data['area']}, IoU={mask_data['predicted_iou']:.3f}")
自定义参数
# 使用更严格的阈值以减少噪声
mask_generator = SAM2AutomaticMaskGenerator(sam2,points_per_side=32, # 点网格密度(默认 32)pred_iou_thresh=0.88, # 提高 IoU 阈值(默认 0.8)stability_score_thresh=0.95, # 提高稳定性阈值(默认 0.95)crop_n_layers=1, # 使用多尺度裁剪(默认 0)min_mask_region_area=100, # 过滤小于 100 像素的区域
)masks = mask_generator.generate(image)
可视化结果
import matplotlib.pyplot as pltdef show_anns(anns):"""可视化所有掩码"""if len(anns) == 0:returnsorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))img[:, :, 3] = 0for ann in sorted_anns:m = ann['segmentation']color_mask = np.concatenate([np.random.random(3), [0.35]])img[m] = color_maskplt.imshow(img)plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
Notes
⚠️ 通道顺序
输入图像必须为 RGB 格式。如果使用cv2.imread()默认读取(BGR 格式),需要先转换:image_bgr = cv2.imread("image.jpg") image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)或使用 OpenCV 4.5+ 的
cv2.IMREAD_COLOR_RGB标志直接读取 RGB。
💡 性能优化
- 增大
points_per_side可以检测更多小目标,但会显著增加计算时间- 启用
crop_n_layers > 0可提高多尺度目标的检测,但会增加 2ⁿ 倍计算量- 调高
pred_iou_thresh和stability_score_thresh可减少低质量掩码
📊 内存消耗
对于大分辨率图像(如 4K),使用output_mode='binary_mask'可能消耗大量内存。建议:
- 使用
output_mode='coco_rle'进行压缩存储- 或在生成后立即处理并释放掩码
