sam2分割空心物体
目录
sam2分割空心物体,并求轮廓距离
分割空心物体:
sam2分割空心物体,并求轮廓距离
import argparse
import gc  # noqa: F401  (kept from original; not used in this script)
import json
import os
import os.path as osp
import sys
import time
from glob import glob

import cv2
import numpy as np
import supervision as sv  # noqa: F401  (kept from original; not used here)
import torch  # noqa: F401  (SAM2 runs on torch)
from scipy.spatial import cKDTree

sys.path.append("./sam2")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def determine_model_cfg(model_path):
    """Map a SAM2 checkpoint filename to its matching hydra config path.

    Raises:
        ValueError: if the path contains no known model-size token.
    """
    if "large" in model_path:
        return "configs/samurai/sam2.1_hiera_l.yaml"
    elif "base_plus" in model_path:
        return "configs/samurai/sam2.1_hiera_b+.yaml"
    elif "small" in model_path:
        return "configs/samurai/sam2.1_hiera_s.yaml"
    elif "tiny" in model_path:
        return "configs/samurai/sam2.1_hiera_t.yaml"
    else:
        raise ValueError("Unknown model size in path!")


def top_red_points(img, top_k=5):
    """Return up to ``top_k`` (x, y) coordinates of red pixels in a BGR image.

    Red hue wraps around 0/180 in OpenCV HSV, so two hue bands are combined.
    Returns an empty list when no red pixel is present.

    NOTE(review): every in-range pixel gets the same mask value (255), so the
    score-based ordering is effectively arbitrary among red pixels — confirm
    this is acceptable for prompt selection.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower_red1, upper_red1 = np.array([0, 100, 100]), np.array([10, 255, 255])
    lower_red2, upper_red2 = np.array([170, 100, 100]), np.array([180, 255, 255])
    mask = cv2.bitwise_or(
        cv2.inRange(hsv, lower_red1, upper_red1),
        cv2.inRange(hsv, lower_red2, upper_red2),
    )
    red_pixels = np.column_stack(np.where(mask > 0))  # rows are (y, x)
    if red_pixels.size == 0:
        return []
    scores = mask[red_pixels[:, 0], red_pixels[:, 1]]
    sorted_idx = np.argsort(scores)[::-1]
    # np.where yields (y, x); emit OpenCV-style (x, y) tuples.
    return [(int(red_pixels[i][1]), int(red_pixels[i][0]))
            for i in sorted_idx[:top_k]]


def min_distance_point_to_contour(point, contours):
    """Return the minimum Euclidean distance from ``point`` (x, y) to any
    vertex of any contour in ``contours``.

    Args:
        point: (x, y) coordinate pair.
        contours: list of OpenCV contours, each shaped (N, 1, 2).

    Returns:
        float: the smallest vertex distance, or ``inf`` if ``contours`` is empty.
    """
    min_dist = float("inf")
    for cnt in contours:
        pts = cnt[:, 0, :]  # (N, 2)
        min_dist = min(min_dist, np.linalg.norm(pts - point, axis=1).min())
    return min_dist


def main(args):
    """Segment hollow objects in every .jpg of a fixed folder with SAM2,
    split the mask into outer/inner contours, and report wall distances."""
    model_cfg = determine_model_cfg(args.model_path)
    device = "cuda:0"
    sam2_image_predictor = SAM2ImagePredictor(
        build_sam2(model_cfg, args.model_path, device=device))
    start = time.time()  # kept from original; elapsed time is never printed

    dir_a = r"D:\data\jiezhi\lunkuo_test"
    files = glob(os.path.join(dir_a, "*.jpg"))
    for img_path in files:
        new_img_name = os.path.basename(img_path)
        # imdecode + fromfile handles non-ASCII Windows paths that
        # cv2.imread cannot open.
        frame = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), 1)
        h_o, w_o = frame.shape[:2]
        # BUG FIX: the original assigned the full height/width to the
        # "center" variables and then used (w, h) — a point outside the
        # image — as the negative prompt; use the actual image centre.
        h_center, w_center = h_o // 2, w_o // 2

        sam2_image_predictor.set_image(frame)
        positive_points = top_red_points(frame, top_k=5)
        if not positive_points:
            # No red prompt pixels -> nothing to segment in this image.
            continue
        negative_points = [[w_center, h_center]]  # centre lies in the hole
        point_coords = np.array(positive_points + negative_points)
        # BUG FIX: labels must match the actual number of prompts;
        # top_red_points may return fewer than top_k points, and the
        # original hard-coded six labels.
        point_labels = np.array(
            [1] * len(positive_points) + [0] * len(negative_points),
            dtype=np.int32)

        masks, scores, logits = sam2_image_predictor.predict(
            point_coords=point_coords, point_labels=point_labels,
            box=None, multimask_output=False)

        # Normalise the predictor output to a single 2-D mask.
        if masks.ndim == 2:
            mask = masks
            masks = masks[None]
        elif masks.ndim == 3:
            mask = masks[0]
        else:  # ndim == 4
            masks = masks.squeeze(1)
            mask = masks[0]  # BUG FIX: original left mask = None here

        non_zero_indices = np.argwhere(mask > 0)
        annotation = {'version': "5.3.1", 'flags': {}, 'imageData': None,
                      'imageHeight': h_o, 'imageWidth': w_o,
                      'imagePath': new_img_name, 'box_xie': []}
        if len(non_zero_indices) > 0:
            # RETR_CCOMP builds a two-level hierarchy: parent == -1 marks an
            # outer contour, anything else is a hole (inner contour).
            contours, hierarchy = cv2.findContours(
                mask.astype(np.uint8), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
            outer_contours, inner_contours = [], []
            if hierarchy is not None:
                for i, h in enumerate(hierarchy[0]):
                    if h[3] == -1:
                        outer_contours.append(contours[i])
                    else:
                        inner_contours.append(contours[i])
            # Drop tiny contours, then keep only the largest of each kind.
            outer_contours = [c for c in outer_contours if len(c) > 100]
            inner_contours = [c for c in inner_contours if len(c) > 100]
            if len(outer_contours) > 1:
                outer_contours = [max(outer_contours, key=len)]
            if len(inner_contours) > 1:
                inner_contours = [max(inner_contours, key=len)]
            if not outer_contours or not inner_contours:
                continue
            print(f"外轮廓数量: {len(outer_contours[0])}, 内轮廓数量: {len(inner_contours[0])}")

            out_dis = min_distance_point_to_contour((w_center, h_center), outer_contours)
            in_dis = min_distance_point_to_contour((w_center, h_center), inner_contours)
            # Average wall thickness: nearest-neighbour distance from every
            # inner-contour vertex to the outer contour via a KD-tree.
            outer_pts = outer_contours[0][:, 0, :]
            inner_pts = inner_contours[0][:, 0, :]
            tree = cKDTree(outer_pts)
            dists, _ = tree.query(inner_pts)
            avg_dist = np.mean(dists)
            print("avg_dist:", avg_dist, 'in_dis', in_dis, 'out_dis', out_dis)

            # BUG FIX: cvtColor requires a uint8 image; SAM masks are
            # typically bool/float — TODO confirm against predictor output.
            vis = cv2.cvtColor((mask > 0).astype(np.uint8) * 255,
                               cv2.COLOR_GRAY2BGR)
            cv2.putText(vis, f'{avg_dist:.1f}', (50, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
            cv2.drawContours(vis, outer_contours, -1, (0, 255, 0), 2)  # outer: green
            cv2.drawContours(vis, inner_contours, -1, (0, 0, 255), 2)  # inner: red
            cv2.imshow("contours", vis)
            cv2.waitKey(0)

            if 0:  # disabled: labelme-style JSON export
                points = non_zero_indices[:, [1, 0]].astype(np.float32)  # (x, y)
                new_shape = {"label": "xie", "points": points.tolist(),
                             "group_id": None, "description": "",
                             "shape_type": "rectangle", "flags": {}}
                annotation['box_xie'].append(new_shape)
                # BUG FIX: json_path was never defined in the original.
                json_path = osp.splitext(img_path)[0] + ".json"
                with open(json_path, 'w') as file:
                    json.dump(annotation, file, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--model_path", default=r"D:\data\models\sam2.1_hiera_large.pt",)
    parser.add_argument("--model_path", default=r"D:\data\models\sam2.1_hiera_small.pt",)
    parser.add_argument("--save_to_video", default=True, help="Save results to a video.")
    args = parser.parse_args()
    main(args)
分割空心物体:
import argparse
import gc  # noqa: F401  (kept from original; not used in this script)
import json
import os
import os.path as osp
import sys
import time
from glob import glob

import cv2
import numpy as np
import supervision as sv  # noqa: F401  (kept from original; not used here)
import torch  # noqa: F401  (SAM2 runs on torch)

sys.path.append("./sam2")
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def determine_model_cfg(model_path):
    """Map a SAM2 checkpoint filename to its matching hydra config path.

    Raises:
        ValueError: if the path contains no known model-size token.
    """
    if "large" in model_path:
        return "configs/samurai/sam2.1_hiera_l.yaml"
    elif "base_plus" in model_path:
        return "configs/samurai/sam2.1_hiera_b+.yaml"
    elif "small" in model_path:
        return "configs/samurai/sam2.1_hiera_s.yaml"
    elif "tiny" in model_path:
        return "configs/samurai/sam2.1_hiera_t.yaml"
    else:
        raise ValueError("Unknown model size in path!")


def top_red_points(img, top_k=5):
    """Return up to ``top_k`` (x, y) coordinates of red pixels in a BGR image.

    Red hue wraps around 0/180 in OpenCV HSV, so two hue bands are combined.
    Returns an empty list when no red pixel is present.

    NOTE(review): every in-range pixel gets the same mask value (255), so the
    score-based ordering is effectively arbitrary among red pixels — confirm
    this is acceptable for prompt selection.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower_red1, upper_red1 = np.array([0, 100, 100]), np.array([10, 255, 255])
    lower_red2, upper_red2 = np.array([170, 100, 100]), np.array([180, 255, 255])
    mask = cv2.bitwise_or(
        cv2.inRange(hsv, lower_red1, upper_red1),
        cv2.inRange(hsv, lower_red2, upper_red2),
    )
    red_pixels = np.column_stack(np.where(mask > 0))  # rows are (y, x)
    if red_pixels.size == 0:
        return []
    scores = mask[red_pixels[:, 0], red_pixels[:, 1]]
    sorted_idx = np.argsort(scores)[::-1]
    # np.where yields (y, x); emit OpenCV-style (x, y) tuples.
    return [(int(red_pixels[i][1]), int(red_pixels[i][0]))
            for i in sorted_idx[:top_k]]


def main(args):
    """Segment hollow objects in every .jpg of a fixed folder with SAM2 and
    visualise the outer/inner contours of the resulting mask."""
    model_cfg = determine_model_cfg(args.model_path)
    device = "cuda:0"
    sam2_image_predictor = SAM2ImagePredictor(
        build_sam2(model_cfg, args.model_path, device=device))
    start = time.time()  # kept from original; elapsed time is never printed

    dir_a = r"D:\data\jiezhi\lunkuo_test"
    files = glob(os.path.join(dir_a, "*.jpg"))
    for img_path in files:
        new_img_name = os.path.basename(img_path)
        # imdecode + fromfile handles non-ASCII Windows paths that
        # cv2.imread cannot open.
        frame = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), 1)
        h_o, w_o = frame.shape[:2]
        # BUG FIX: the original assigned the full height/width to the
        # "center" variables and then used (w, h) — a point outside the
        # image — as the negative prompt; use the actual image centre.
        h_center, w_center = h_o // 2, w_o // 2

        sam2_image_predictor.set_image(frame)
        positive_points = top_red_points(frame, top_k=3)
        if not positive_points:
            # No red prompt pixels -> nothing to segment in this image.
            continue
        negative_points = [[w_center, h_center]]  # centre lies in the hole
        point_coords = np.array(positive_points + negative_points)
        # BUG FIX: labels must match the actual number of prompts;
        # top_red_points may return fewer than top_k points, and the
        # original hard-coded four labels.
        point_labels = np.array(
            [1] * len(positive_points) + [0] * len(negative_points),
            dtype=np.int32)

        masks, scores, logits = sam2_image_predictor.predict(
            point_coords=point_coords, point_labels=point_labels,
            box=None, multimask_output=False)

        # Normalise the predictor output to a single 2-D mask.
        if masks.ndim == 2:
            mask = masks
            masks = masks[None]
        elif masks.ndim == 3:
            mask = masks[0]
        else:  # ndim == 4
            masks = masks.squeeze(1)
            mask = masks[0]  # BUG FIX: original left mask = None here

        non_zero_indices = np.argwhere(mask > 0)
        annotation = {'version': "5.3.1", 'flags': {}, 'imageData': None,
                      'imageHeight': h_o, 'imageWidth': w_o,
                      'imagePath': new_img_name, 'box_xie': []}
        if len(non_zero_indices) > 0:
            # RETR_CCOMP builds a two-level hierarchy: parent == -1 marks an
            # outer contour, anything else is a hole (inner contour).
            contours, hierarchy = cv2.findContours(
                mask.astype(np.uint8), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
            outer_contours, inner_contours = [], []
            if hierarchy is not None:
                for i, h in enumerate(hierarchy[0]):
                    if h[3] == -1:
                        outer_contours.append(contours[i])
                    else:
                        inner_contours.append(contours[i])
            print(f"外轮廓数量: {len(outer_contours)}, 内轮廓数量: {len(inner_contours)}")

            # BUG FIX: cvtColor requires a uint8 image; SAM masks are
            # typically bool/float — TODO confirm against predictor output.
            vis = cv2.cvtColor((mask > 0).astype(np.uint8) * 255,
                               cv2.COLOR_GRAY2BGR)
            cv2.drawContours(vis, outer_contours, -1, (0, 255, 0), 2)  # outer: green
            cv2.drawContours(vis, inner_contours, -1, (0, 0, 255), 2)  # inner: red
            cv2.imshow("contours", vis)
            cv2.waitKey(0)

            if 0:  # disabled: labelme-style JSON export
                points = non_zero_indices[:, [1, 0]].astype(np.float32)  # (x, y)
                new_shape = {"label": "xie", "points": points.tolist(),
                             "group_id": None, "description": "",
                             "shape_type": "rectangle", "flags": {}}
                annotation['box_xie'].append(new_shape)
                # BUG FIX: json_path was never defined in the original.
                json_path = osp.splitext(img_path)[0] + ".json"
                with open(json_path, 'w') as file:
                    json.dump(annotation, file, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path",
                        default=r"D:\data\models\sam2.1_hiera_large.pt",
                        help="Path to the model checkpoint.")
    parser.add_argument("--save_to_video", default=True, help="Save results to a video.")
    args = parser.parse_args()
    main(args)
