【SAM】eval_coco.py说明
eval_coco.py - SAM模型COCO数据集评估工具
概述
eval_coco.py 是一个用于在COCO数据集上评估Segment Anything Model (SAM)性能的工具脚本。该脚本提供了多种提示类型(prompt types)的评估功能,并计算IoU(Intersection over Union)和Dice系数等标准分割指标。
特性
- 多种提示类型支持:支持边界框(box)、单点(point)和多点(point_multi)提示
- 自动化评估流程:自动加载COCO标注,批量处理图像
- 丰富的可视化:生成评估结果可视化图表和IoU分布直方图
- 详细的指标统计:计算平均IoU、Dice系数及其标准差
- 结果持久化:自动保存评估结果到JSON文件
快速开始
# 直接运行脚本
python eval_coco.py
配置参数
在脚本的main()函数中修改以下参数:
MODEL_TYPE = "vit_l" # 模型类型: "vit_b", "vit_l", "vit_h"
CHECKPOINT_PATH = "checkpoints/sam_vit_l_0b3195.pth" # 模型权重路径
COCO_ROOT = "data/COCO" # COCO数据集根目录
DEVICE = "cuda" # 设备: "cuda" 或 "cpu"
API 参考
核心函数
calculate_iou(pred_mask, gt_mask)
计算预测mask与ground truth mask之间的IoU。
参数:
pred_mask(np.ndarray) - 预测的二值mask,shape为(H, W)gt_mask(np.ndarray) - Ground truth二值mask,shape为(H, W)
返回:
iou(float) - IoU值,范围[0, 1]
calculate_dice(pred_mask, gt_mask)
计算Dice系数(F1分数)。
参数:
pred_mask(np.ndarray) - 预测的二值maskgt_mask(np.ndarray) - Ground truth二值mask
返回:
dice(float) - Dice系数,范围[0, 1]
evaluate_on_coco(predictor, coco, image_ids, data_dir, prompt_type, max_instances_per_image)
在COCO数据集上执行评估。
参数:
predictor(SamPredictor) - SAM预测器实例coco(COCO) - COCO数据集API实例image_ids(List[int]) - 要评估的图像ID列表data_dir(str) - 图像目录路径prompt_type(str, optional) - 提示类型,可选值:"box","point","point_multi"。默认为"box"max_instances_per_image(int, optional) - 每张图像评估的最大实例数。默认为10
返回:
results(Dict) - 包含以下键的字典:avg_iou: 平均IoUavg_dice: 平均Dice系数all_ious: 所有IoU值的列表all_dices: 所有Dice值的列表num_instances: 评估的实例总数sample_results: 用于可视化的示例结果
数据集组织
期望的COCO数据集目录结构:
COCO/
├── annotations_trainval2017/
│ └── annotations/
│ ├── instances_train2017.json
│ └── instances_val2017.json
├── train2017/
│ └── *.jpg
└── val2017/└── *.jpg
输出文件
脚本执行后会生成以下文件:
| 文件名 | 描述 |
|---|---|
coco_evaluation_results.json | 详细的评估结果,包含所有指标 |
coco_eval_box.png | 边界框提示的可视化结果 |
coco_eval_point.png | 单点提示的可视化结果 |
coco_eval_point_multi.png | 多点提示的可视化结果 |
coco_iou_distribution.png | IoU分布直方图 |
示例输出
================================================================================
COCO数据集评估结果汇总
================================================================================
Model Checkpoint Input Size Prompt Type IoU Dice Instances
--------------------------------------------------------------------------------
vit_l sam_vit_l_0b3195 1024 box 0.8234 0.8956 487
vit_l sam_vit_l_0b3195 1024 point 0.6789 0.7612 487
vit_l sam_vit_l_0b3195 1024 point_multi 0.7456 0.8123 487
================================================================================
性能优化建议
- 批处理:当前版本逐个处理实例。可以修改为批量处理以提高效率
- GPU利用:确保CUDA可用并设置
DEVICE="cuda" - 采样策略:通过调整
max_instances_per_image参数平衡评估速度和准确性 - 缓存机制:可以添加图像embedding缓存以避免重复计算
扩展功能
自定义提示策略
可以通过修改evaluate_single_instance函数添加新的提示类型:
def evaluate_single_instance(predictor, image, gt_mask, prompt_type):if prompt_type == "custom":# 实现自定义提示逻辑points, labels = generate_custom_prompts(gt_mask)masks, scores, _ = predictor.predict(point_coords=points,point_labels=labels,multimask_output=False)# ...
添加新的评估指标
在评估循环中添加新的指标计算:
# 在 evaluate_on_coco 函数中
pixel_accuracy = calculate_pixel_accuracy(pred_mask, gt_mask)
all_pixel_accuracies.append(pixel_accuracy)
常见问题
Q: 如何处理COCO的crowd标注?
A: 当前版本跳过了iscrowd=1的标注。如需包含,修改getAnnIds调用参数。
