当前位置: 首页 > news >正文

sam2 点选 分割图片 2025

sam2 点选 分割图片

轮廓越接近矩形,轮廓点越少

正方形的轮廓点个数会小于100,比如:

外轮廓数量: 42, 内轮廓数量: 28

# NOTE(review): this code was recovered from a web scrape in which the line
# breaks inside the source were lost; the formatting below is a careful
# reconstruction.  Interactive SAM2 segmentation tool: for every *.jpg in a
# fixed directory the user clicks positive/negative prompt points, SAM2
# predicts a mask, and the mask's inner/outer contours are measured and
# visualised.
import argparse
import gc
import json
import os
import os.path as osp
import sys
import time
from glob import glob

import cv2
import numpy as np
import torch
from scipy.spatial import cKDTree

# The sam2 package is expected to live next to this script.
sys.path.append("./sam2")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import supervision as sv


def determine_model_cfg(model_path):
    """Return the SAMURAI config file matching a SAM2 checkpoint path.

    The model size is inferred from keywords in the checkpoint filename.

    Raises:
        ValueError: if no known size keyword is present in ``model_path``.
    """
    if "large" in model_path:
        return "configs/samurai/sam2.1_hiera_l.yaml"
    elif "base_plus" in model_path:
        return "configs/samurai/sam2.1_hiera_b+.yaml"
    elif "small" in model_path:
        return "configs/samurai/sam2.1_hiera_s.yaml"
    elif "tiny" in model_path:
        return "configs/samurai/sam2.1_hiera_t.yaml"
    else:
        raise ValueError("Unknown model size in path!")


def top_red_points(img, top_k=5):
    """Return up to ``top_k`` (x, y) coordinates of red pixels in a BGR image.

    Red is detected in HSV using the two hue bands that wrap around the 0/180
    boundary.  Pixels are ranked by their mask value; since ``cv2.inRange``
    yields a binary 0/255 mask, all red pixels score equally and the choice
    among them is effectively arbitrary.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # Red occupies two hue ranges in OpenCV's 0-180 hue space.
    lower_red1 = np.array([0, 100, 100])
    upper_red1 = np.array([10, 255, 255])
    lower_red2 = np.array([170, 100, 100])
    upper_red2 = np.array([180, 255, 255])
    mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
    mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    mask = cv2.bitwise_or(mask1, mask2)
    red_pixels = np.column_stack(np.where(mask > 0))  # rows are (y, x)
    if red_pixels.size == 0:
        return []
    scores = mask[red_pixels[:, 0], red_pixels[:, 1]]
    sorted_idx = np.argsort(scores)[::-1]
    result = []
    for idx in sorted_idx[:top_k]:
        y, x = red_pixels[idx]  # note the (y, x) order
        result.append((int(x), int(y)))
    return result


def min_distance_point_to_contour(point, contours):
    """Return the minimum Euclidean distance from ``point`` to any contour.

    Args:
        point: (x, y) coordinate pair.
        contours: list of OpenCV contours, each shaped (N, 1, 2).

    Returns:
        The smallest distance found, or ``inf`` when ``contours`` is empty.
    """
    min_dist = float('inf')
    for cnt in contours:
        pts = cnt[:, 0, :]  # shape (N, 2)
        dists = np.linalg.norm(pts - point, axis=1)
        min_dist = min(min_dist, dists.min())
    return min_dist


class MultiPointDrawer:
    """Collect positive/negative click prompts on an image via an OpenCV window.

    Left click adds a positive point (drawn red, label 1); right click adds a
    negative point (drawn blue, label 0).  Enter confirms the selection, Esc
    clears it and aborts.
    """

    def __init__(self, img):
        self.img = img.copy()
        self.display = img.copy()
        self.points = []   # (x, y)
        self.labels = []   # 1 = positive, 0 = negative
        cv2.namedWindow("Click Points (L = + , R = - , Enter = OK)")
        cv2.setMouseCallback("Click Points (L = + , R = - , Enter = OK)",
                             self.mouse_callback)

    def mouse_callback(self, event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            # Positive sample point (red dot).
            self.points.append((x, y))
            self.labels.append(1)
            cv2.circle(self.display, (x, y), 4, (0, 0, 255), -1)
        elif event == cv2.EVENT_RBUTTONDOWN:
            # Negative sample point (blue dot).
            self.points.append((x, y))
            self.labels.append(0)
            cv2.circle(self.display, (x, y), 4, (255, 0, 0), -1)

    def run(self):
        """Show the window until Enter (confirm) or Esc (clear and abort)."""
        while True:
            cv2.imshow("Click Points (L = + , R = - , Enter = OK)", self.display)
            key = cv2.waitKey(1)
            if key == 13:    # Enter
                break
            if key == 27:    # Esc: clear everything and exit
                self.points.clear()
                self.labels.clear()
                break
        cv2.destroyWindow("Click Points (L = + , R = - , Enter = OK)")
        return self.points, self.labels


def main(args):
    """Run interactive point-prompt segmentation over a directory of images."""
    model_cfg = determine_model_cfg(args.model_path)
    device = "cuda:0"  # NOTE(review): GPU device is hard-coded
    sam2_image_predictor = SAM2ImagePredictor(
        build_sam2(model_cfg, args.model_path, device=device))
    # sam2_image_predictor.set_image_size(1024)
    start = time.time()
    dir_a = r"D:\data\pred_res\1107_2107"
    files = glob(os.path.join(dir_a, "*.jpg"))
    for img_path in files:
        new_img_name = os.path.basename(img_path)
        # imdecode + fromfile handles non-ASCII (e.g. Chinese) paths on Windows.
        frame = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), 1)
        h_o, w_o = frame.shape[:2]
        # NOTE(review): despite the names, these hold the full height/width
        # rather than the image centre, so (w_center, h_center) below is the
        # bottom-right corner — confirm this is intended.
        h_center, w_center = frame.shape[:2]
        sam2_image_predictor.set_image(frame)

        drawer = MultiPointDrawer(frame)
        clicked_points, clicked_labels = drawer.run()
        if len(clicked_points) == 0:
            print("⚠ 未选择任何点,跳过此图")
            continue

        point_coords = np.array(clicked_points, dtype=np.float32)
        point_labels = np.array(clicked_labels, dtype=np.int32)
        masks, scores, logits = sam2_image_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=None,
            multimask_output=False,
        )

        # Normalise the prediction rank; ``mask`` is the single 2-D mask used
        # for everything below.
        mask = None
        if masks.ndim == 2:
            mask = masks
            masks = masks[None]
        elif masks.ndim == 3:
            mask = masks[0]
        elif masks.ndim == 4:
            masks = masks.squeeze(1)
            mask = masks[0]  # bug fix: original left ``mask`` as None here

        non_zero_indices = np.argwhere(mask > 0)
        vis = frame.copy()
        for (y, x) in non_zero_indices:
            cv2.circle(vis, (x, y), 1, (0, 255, 0), -1)  # small green dot
        cv2.imshow("Mask Points", vis)
        cv2.waitKey(0)

        annotation = {'version': "5.3.1", 'flags': {}, 'imageData': None,
                      'imageHeight': h_o, 'imageWidth': w_o,
                      'imagePath': new_img_name, 'box_xie': []}
        if len(non_zero_indices) > 0:
            # RETR_CCOMP builds a two-level hierarchy: outer contours and holes.
            contours, hierarchy = cv2.findContours(
                mask.astype(np.uint8), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
            # Hierarchy entry is [next, prev, child, parent]; parent == -1
            # means the contour has no parent, i.e. it is an outer contour.
            outer_contours = []
            inner_contours = []
            if hierarchy is not None:
                hierarchy = hierarchy[0]  # findContours returns (contours, hierarchy)
                for i, h in enumerate(hierarchy):
                    if h[3] == -1:
                        outer_contours.append(contours[i])
                    else:
                        inner_contours.append(contours[i])
            # Drop tiny contours, then keep only the one with the most points
            # of each kind.
            outer_contours = [cnt for cnt in outer_contours if len(cnt) > 10]
            inner_contours = [cnt for cnt in inner_contours if len(cnt) > 10]
            if len(outer_contours) > 1:
                outer_contours = [max(outer_contours, key=lambda x: len(x))]
            if len(inner_contours) > 1:
                inner_contours = [max(inner_contours, key=lambda x: len(x))]
            if len(outer_contours) == 0 or len(inner_contours) == 0:
                print('----none----')
                continue
            print(f"外轮廓数量: {len(outer_contours[0])}, 内轮廓数量: {len(inner_contours[0])}")
            out_dis = min_distance_point_to_contour((w_center, h_center), outer_contours)
            in_dis = min_distance_point_to_contour((w_center, h_center), inner_contours)

            # Average nearest-neighbour distance from inner to outer contour.
            outer_pts = outer_contours[0][:, 0, :]
            inner_pts = inner_contours[0][:, 0, :]
            tree = cKDTree(outer_pts)
            dists, _ = tree.query(inner_pts)
            avg_dist = np.mean(dists)
            print("avg_dist:", avg_dist, 'in_dis', in_dis, 'out_dis', out_dis)

            vis = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
            cv2.putText(vis, f'{avg_dist:.1f}', (50, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
            cv2.drawContours(vis, outer_contours, -1, (0, 255, 0), 2)  # outer: green
            cv2.drawContours(vis, inner_contours, -1, (0, 0, 255), 2)  # inner: red
            cv2.imshow("contours", vis)
            cv2.waitKey(0)
            if 0:
                # Dead code: would dump a labelme-style JSON annotation.
                # NOTE(review): ``json_path`` is never defined anywhere in
                # this file, so enabling this branch would raise NameError.
                # OpenCV coordinates are (x, y), argwhere gives (y, x).
                points = non_zero_indices[:, [1, 0]].astype(np.float32)
                new_shape = {"label": "xie", "points": points.tolist(),
                             "group_id": None, "description": "",
                             "shape_type": "rectangle", "flags": {}}
                annotation['box_xie'].append(new_shape)
                with open(json_path, 'w') as file:
                    json.dump(annotation, file, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--model_path", default=r"D:\data\models\sam2.1_hiera_large.pt",)
    parser.add_argument("--model_path", default=r"D:\data\models\sam2.1_hiera_small.pt",)
    # parser.add_argument("--model_path", default=r"D:\data\models\sam2.1_hiera_base_plus.pt",)
    parser.add_argument("--save_to_video", default=True, help="Save results to a video.")
    args = parser.parse_args()
    main(args)

http://www.dtcms.com/a/581215.html

相关文章:

  • 网站开发源程序重庆建筑人才网官网
  • 如何理解蒙特卡洛方法并用python进行模拟
  • 公司网站代码模板wordpress 虎嗅网
  • 在 Windows 中清理依赖node_modules并重新安装
  • 【数据结构】从零开始认识图论 --- 并查集与最小生成树算法
  • 使用 AWS WAF 防护 Stored XSS 攻击完整指南
  • 当爬虫遇到GraphQL:如何分析与查询这种新型API?
  • 游戏手柄遥控越疆协作机器人[一]
  • MATLAB实现决策树数值预测
  • Maven 多模块项目与 Spring Boot 结合指南
  • 搜索量最高的网站小白学编程应该从哪里开始学
  • 西安大型网站制作wordpress耗时显示
  • Kubernetes(k8s)
  • 如何提高 SaaS 产品的成功率?
  • ​技术融合新纪元:深度学习、大数据与云原生的跨界实践
  • 中国高分辨率单季稻种植分布数据集(2017-2023)
  • PDF工具箱/合并拆分pdf/提取图片
  • 如何在PDF文档中打钩?(福昕阅读器)打√
  • 新手怎么样学做网站企业网站建设规划的基本原则是什么
  • 【DIY】PCB练习记录2——51单片机核心板
  • Spring Boot 2.7.18(最终 2.x 系列版本)3 - 枚举规范定义:定义基础枚举接口;定义枚举工具类;示例枚举
  • aspnet东莞网站建设多少钱frontpage怎样做网站
  • uniapp 使用renderjs 封装 video-player 视频播放器, html5视频播放器-解决视频层级、覆盖、播放卡顿
  • 基于深度对比学习的分析化学结构注释TOP1匹配率提升研究
  • MFA MACOS 安装流程
  • Ubuntu 上部署 Microsoft SQL Server 详细教程
  • 网站上面怎么做链接微信网站合同
  • 关于网站建设与维护的参考文献phpcms 恢复网站
  • 圆角边框+阴影
  • Android14 init.environ.rc详解