当前位置: 首页 > news >正文

samurai 点选分割 box分割

samurai 点选分割 box分割 合并兼容了,源代码

# coding=utf-8
import argparse
import json
import os
import time
import numpy as np
import torch
import cv2
import syscurrent_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_dir)
print('current_dir', current_dir)
paths = [current_dir, current_dir + '/../']
paths.append(os.path.join(current_dir, 'sam2'))
for path in paths:sys.path.insert(0, path)os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictorclass PointSelector:"""用于交互式选择前景和背景点的辅助类"""def __init__(self, image):self.image_original = image.copy()self.image = self.image_original.copy()self.temp_image = self.image.copy()self.points = []  # 存储点坐标 [(x, y, label)],label=1表示前景,0表示背景self.done = Falsedef draw_points(self, event, x, y, flags, param):"""鼠标事件回调函数,处理点的选择和绘制"""self.temp_image = self.image.copy()# 绘制十字准星cv2.line(self.temp_image, (x, 0), (x, self.temp_image.shape[0]), (0, 255, 0), 1)cv2.line(self.temp_image, (0, y), (self.temp_image.shape[1], y), (0, 255, 0), 1)# 绘制已选择的点for point_x, point_y, label in self.points:color = (0, 255, 0) if label == 1 else (0, 0, 255)  # 绿色前景,红色背景cv2.circle(self.temp_image, (point_x, point_y), 5, color, -1)cv2.circle(self.temp_image, (point_x, point_y), 8, color, 2)if event == cv2.EVENT_LBUTTONDOWN:# 左键添加前景点self.points.append((x, y, 1))cv2.circle(self.image, (x, y), 5, (0, 255, 0), -1)cv2.circle(self.image, (x, y), 8, (0, 255, 0), 2)print(f"Added foreground point at ({x}, {y})")elif event == cv2.EVENT_RBUTTONDOWN:# 右键添加背景点self.points.append((x, y, 0))cv2.circle(self.image, (x, y), 5, (0, 0, 255), -1)cv2.circle(self.image, (x, y), 8, (0, 0, 255), 2)print(f"Added background point at ({x}, {y})")elif event == cv2.EVENT_MBUTTONDOWN:# 中键完成选择self.done = Truedef run(self):"""运行点选择界面,返回选择的点坐标和标签"""cv2.namedWindow('Select Points (L:foreground, R:background, M:done)')cv2.setMouseCallback('Select Points (L:foreground, R:background, M:done)', self.draw_points)print("Instructions:")print("- Left click: Add foreground point (object)")print("- Right click: Add background point (not object)")print("- Middle click: Finish point selection")print("- Backspace: Remove last point")print("- Esc: Cancel selection")while True:cv2.imshow('Select Points (L:foreground, R:background, M:done)', self.temp_image)key = cv2.waitKey(1) & 0xFFif key == 27:  # Esc键退出self.points = []breakelif key == 13 or self.done:  # 回车键或中键完成breakelif key == 8 and self.points:  # Backspace删除最后一个点self.points.pop()self.image = self.image_original.copy()for point_x, point_y, label in self.points:color = (0, 255, 0) if label == 1 else (0, 0, 255)cv2.circle(self.image, (point_x, point_y), 5, color, -1)cv2.circle(self.image, (point_x, point_y), 8, color, 2)cv2.destroyAllWindows()# 分离坐标和标签if self.points:points_coords = np.array([[x, y] for x, y, label in self.points])points_labels = np.array([label for x, y, label in self.points])return points_coords, points_labelselse:return None, Noneclass ROISelector:"""用于交互式选择ROI区域的辅助类"""def __init__(self, image):self.image_original = image.copy()self.image = self.image_original.copy()self.temp_image = self.image.copy()self.drawing = Falseself.ix, self.iy = -1, -1self.roi = Noneself.done = Falsedef draw_roi(self, event, x, y, flags, param):"""鼠标事件回调函数,处理ROI的选择和绘制"""self.temp_image = self.image.copy()# 绘制十字准星cv2.line(self.temp_image, (x, 0), (x, self.temp_image.shape[0]), (0, 255, 0), 1)cv2.line(self.temp_image, (0, y), (self.temp_image.shape[1], y), (0, 255, 0), 1)# 绘制已选择的ROIif self.roi is not None:x1, y1, w, h = self.roicv2.rectangle(self.temp_image, (x1, y1), (x1 + w, y1 + h), (255, 0, 0), 2)if event == cv2.EVENT_LBUTTONDOWN:self.drawing = Trueself.ix, self.iy = x, yself.image = self.image_original.copy()self.roi = Noneelif event == cv2.EVENT_MOUSEMOVE:if self.drawing:self.temp_image = self.image.copy()cv2.rectangle(self.temp_image, (self.ix, self.iy), (x, y), (255, 0, 0), 2)elif event == cv2.EVENT_LBUTTONUP:self.drawing = Falsecv2.rectangle(self.image, (self.ix, self.iy), (x, y), (255, 0, 0), 2)self.roi = (min(self.ix, x), min(self.iy, y), abs(x - self.ix), abs(y - self.iy))print("ROI selected:", self.roi)elif event == cv2.EVENT_RBUTTONDOWN:# 右键完成选择self.done = Truedef run(self):"""运行ROI选择界面,返回选择的ROI坐标"""cv2.namedWindow('Select ROI (L:drag, R:done)')cv2.setMouseCallback('Select ROI (L:drag, R:done)', self.draw_roi)print("Instructions:")print("- Left click and drag: Select ROI")print("- Right click: Finish selection")print("- Esc: Cancel selection")while True:cv2.imshow('Select ROI (L:drag, R:done)', self.temp_image)key = cv2.waitKey(1) & 0xFFif key == 27:  # Esc键退出self.roi = Nonebreakelif key == 13 or self.done:  # 回车键或右键完成breakcv2.destroyAllWindows()if self.roi:# 转换为[x1, y1, x2, y2]格式x, y, w, h = self.roireturn np.array([[x, y, x + w, y + h]])else:return Noneclass SAM2Segmenter:"""基于SAM2的图像分割工具类,支持点选和ROI两种模式"""def __init__(self, model_path, device="cuda:0"):self.model_path = model_pathself.device = deviceself.mode = "point"  # 默认点选模式self.model_cfg = self.determine_model_cfg()self.predictor = self.initialize_predictor()def determine_model_cfg(self):"""根据模型路径确定配置文件"""if "large" in self.model_path:return "configs/samurai/sam2.1_hiera_l.yaml"elif "base_plus" in self.model_path:return "configs/samurai/sam2.1_hiera_b+.yaml"elif "small" in self.model_path:return "configs/samurai/sam2.1_hiera_s.yaml"elif "tiny" in self.model_path:return "configs/samurai/sam2.1_hiera_t.yaml"else:raise ValueError("Unknown model size in path!")def initialize_predictor(self):"""初始化SAM2预测器"""return SAM2ImagePredictor(build_sam2(self.model_cfg, self.model_path, device=self.device))def set_mode(self, mode):"""设置分割模式:point或roi"""if mode in ["point", "roi"]:self.mode = modeprint(f"Mode set to: {mode}")else:raise ValueError("Mode must be 'point' or 'roi'")def load_image(self, img_path):"""加载图像并返回图像数据和相关信息"""self.img_path = img_pathself.new_img_name = os.path.basename(img_path)self.json_path = os.path.splitext(img_path)[0] + f"_{self.mode}_outline.json"# 读取图像(支持中文路径)self.frame = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), 1)self.height, self.width = self.frame.shape[:2]return self.framedef select_input(self):"""根据模式启动相应的选择界面"""if self.mode == "point":point_selector = PointSelector(self.frame)self.points_coords, self.points_labels = point_selector.run()self.box = Nonereturn self.points_coords is not Noneelse:  # roi模式roi_selector = ROISelector(self.frame)self.box = roi_selector.run()self.points_coords, self.points_labels = None, Nonereturn self.box is not Nonedef perform_segmentation(self):"""执行图像分割"""self.predictor.set_image(self.frame)if self.mode == "point":# 点选模式:使用点提示point_coords = torch.Tensor(self.points_coords).unsqueeze(0) if self.points_coords is not None else Nonepoint_labels = torch.Tensor(self.points_labels).unsqueeze(0) if self.points_labels is not None else Nonebox = Noneelse:# ROI模式:使用框提示point_coords = Nonepoint_labels = Nonebox = torch.Tensor(self.box) if self.box is not None else None# 预测掩码masks, scores, _ = self.predictor.predict(point_coords=point_coords,point_labels=point_labels,box=box,multimask_output=True  # 输出多个mask供选择)# 选择分数最高的maskbest_mask_idx = np.argmax(scores)self.mask = masks[best_mask_idx]self.score = scores[best_mask_idx]print(f"Best mask score: {self.score:.3f}")return self.mask, self.scoredef create_annotation(self):"""创建标注数据并保存为JSON文件"""# 找到mask的轮廓mask_uint8 = (self.mask * 255).astype(np.uint8)contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)# 构建标注数据annotation = {'version': "5.3.1",'flags': {},'imageData': None,'imageHeight': self.height,'imageWidth': self.width,'imagePath': self.new_img_name,'mode': self.mode,'mask_score': float(self.score),'shapes': []}# 添加输入信息if self.mode == "point":annotation['points'] = self.points_coords.tolist() if self.points_coords is not None else []annotation['point_labels'] = self.points_labels.tolist() if self.points_labels is not None else []else:annotation['roi'] = self.box.tolist() if self.box is not None else []# 处理每个轮廓for contour in contours:if cv2.contourArea(contour) > 100:  # 过滤小面积轮廓# 多边形近似epsilon = 0.002 * cv2.arcLength(contour, True)approx = cv2.approxPolyDP(contour, epsilon, True)# 转换为点列表points = approx.reshape(-1, 2).tolist()shape = {"label": "obj","points": points,"group_id": None,"description": "","shape_type": "polygon","flags": {}}annotation['shapes'].append(shape)# 在图像上绘制轮廓cv2.drawContours(self.frame, [approx], -1, (0, 255, 0), 2)# 保存标注文件with open(self.json_path, 'w') as file:json.dump(annotation, file, indent=4)print(f"Annotation saved to {self.json_path}")return annotationdef visualize_results(self):"""可视化分割结果"""# 创建带透明度的mask叠加color_mask = np.zeros_like(self.frame)color_mask[self.mask > 0.5] = [0, 255, 0]  # 绿色maskalpha = 0.4blended = cv2.addWeighted(self.frame, 1, color_mask, alpha, 0)# 根据模式绘制输入信息if self.mode == "point":# 绘制点for (x, y), label in zip(self.points_coords, self.points_labels):color = (0, 255, 0) if label == 1 else (0, 0, 255)cv2.circle(blended, (int(x), int(y)), 8, color, -1)cv2.circle(blended, (int(x), int(y)), 12, color, 2)else:# 绘制ROI框if self.box is not None:x1, y1, x2, y2 = self.box[0]cv2.rectangle(blended, (x1, y1), (x2, y2), (255, 0, 0), 2)# 显示结果cv2.imshow('Segmentation Result', blended)cv2.imshow('Original Image', self.frame)print("Press any key to exit...")cv2.waitKey(0)cv2.destroyAllWindows()def process_image(self):"""处理单张图像的完整流程"""start_time = time.time()# 选择输入模式if not self.select_input():print("No input selected. Exiting.")return# 执行分割self.perform_segmentation()# 创建并保存标注self.create_annotation()print(f"Total processing time: {time.time() - start_time:.2f} seconds")def release(self):"""释放所有占用的资源,特别是CUDA内存"""print("Releasing resources...")# 释放OpenCV窗口cv2.destroyAllWindows()# 释放图像和掩码数据self.frame = Noneself.mask = Noneself.points_coords = Noneself.points_labels = Noneself.box = None# 释放PyTorch相关资源if hasattr(self, 'predictor') and self.predictor is not None:# 清除预测器中的图像特征if hasattr(self.predictor, 'reset_image'):self.predictor.reset_image()# 释放模型占用的CUDA内存if hasattr(self.predictor, 'model'):if hasattr(self.predictor.model, 'to'):# 将模型移到CPU以释放GPU内存self.predictor.model = self.predictor.model.to('cpu')# 清理PyTorch缓存if torch.cuda.is_available():torch.cuda.empty_cache()torch.cuda.ipc_collect()# 强制垃圾回收import gcgc.collect()print("Resources released successfully")if __name__ == "__main__":parser = argparse.ArgumentParser()if sys.platform.startswith('win'):parser.add_argument("--model_path", default=r"D:\data\models\sam2.1_hiera_large.pt", help="checkpoint.")else:parser.add_argument("--model_path", default=r"/home/lbg/data/models/seg/sam2.1_hiera_large.pt",help="checkpoint.")parser.add_argument("--image_path", default=r"0904_1447_405.jpg", help="image to process.")parser.add_argument("--device", default="cuda:0", help="Device to use (cuda or cpu)")parser.add_argument("--mode", default="roi", choices=["point", "roi"], help="Segmentation mode: point or roi")args = parser.parse_args()val_release = 0# 创建分割器实例并处理图像segmenter = SAM2Segmenter(args.model_path, args.device)segmenter.set_mode(args.mode)segmenter.load_image(args.image_path)if sys.platform.startswith('win'):segmenter.process_image()else:# Linux环境下使用预设点if args.mode == "point":points_coords = np.array([[432, 167], [339, 264]])points_labels = np.array([1, 1])segmenter.points_coords = points_coordssegmenter.points_labels = points_labelssegmenter.box = Noneelse:# ROI模式使用预设框segmenter.box = np.array([[200, 100, 400, 300]])segmenter.points_coords, segmenter.points_labels = None, Nonesegmenter.perform_segmentation()segmenter.create_annotation()if sys.platform.startswith('win'):segmenter.visualize_results()if val_release:segmenter.release()

http://www.dtcms.com/a/389854.html

相关文章:

  • 计算机架构的总线协议中的等待状态是什么?
  • C++:入门基础(1)
  • ACD智能分配:服务延续和专属客服设置
  • 自监督学习分割
  • 抛弃自定义模态框:原生Dialog的实力
  • LangGraph 简单入门介绍
  • Docker 部署 DzzOffice:服务器 IP 转发功能是否需要开启
  • 无人机避障——卡内基梅隆大学(CMU)CERLAB 无人机自主框架复现
  • 正点原子zynq_FPGA-初识ZYNQ
  • Vue3中对比ref,reactive,shallowRef,shallowReactive
  • 通过Freemark渲染数据到Word里并生成压缩包
  • Vue 项目中使用 AbortController:解决请求取消、超时与内存泄漏问题
  • 设置管家婆服务器开机自动启动
  • ubuntu20 安装 ros2 foxy
  • 二分查找(二分查找算法)
  • 贪心算法应用:超图匹配问题详解
  • Hadoop3.3.5搭建指南(双NN版本)
  • 如何正确写Controller?参数校验、异常处理
  • 线性代数:LU与Cholesky分解
  • 饮用水在线监测设备:实时、精准地捕捉水体中的关键参数,为供水安全提供全方位保障
  • 【环境搭建】Conda安装教程
  • Java与机器学习的结合:库与应用!
  • DHCP基本原理及实验(ENSP配置)
  • 高系分十一:软件需求工程
  • MCP Server Chart AntV 项目解析
  • 2025药物市场调研分析案例(模板资源分享)
  • 飞网出口网关:安全便捷地访问受限资源
  • 大模型训练的三大显存优化策略
  • 动态加载js链接、异步传参加载组件、有趣打印
  • 【Python】Python异常、模块与包