LLM - Building a Visual Detection and Segmentation API Service with the Grounded SAM 2 Model
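Follow my CSDN blog: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/150272191
Disclaimer: This article is based on personal knowledge and public materials and is intended for academic exchange only. Discussion is welcome; reposting is not permitted.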
GitHub:https://github.com/IDEA-Research/Grounded-SAM-2
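Grounded-SAM-2 is a joint image segmentation and prediction system built on Grounding DINO and SAM 2 (Segment Anything Model 2). This article uses the FastAPI library together with the Grounded-SAM-2 source code to build an API service for the algorithm.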
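References:
- Configuring a PyTorch development environment with Docker
- Configuring and running inference with the open-source visual segmentation framework Grounded SAM2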
Starting directory:
/root
mkdir workspace
cd workspace
1. Configure the Environment
1.1 Proxy Acceleration
Use a free online GitHub proxy, listed at https://ghproxy.link/:
# https://ghfast.top
git clone https://ghfast.top/https://github.com/hiyouga/LLaMA-Factory.git # example
Note: free proxies can stop working at any time, so check which ones are currently available.
Use the free online Hugging Face mirror, see https://hf-mirror.com/:
export HF_ENDPOINT=https://hf-mirror.com
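Both proxies can also be applied from Python, for example in a setup script. This is a minimal sketch that assumes the ghfast.top prefix used above, which may stop working at any time; `via_github_proxy` and `use_hf_mirror` are hypothetical helper names. HF_ENDPOINT has to be set before `huggingface_hub` is imported, because the library reads it into its constants at import time.

```python
import os

# Proxy prefix used in this article; free proxies change often, so treat it as a default only.
DEFAULT_GH_PROXY = "https://ghfast.top/"


def via_github_proxy(url, proxy=DEFAULT_GH_PROXY):
    """Prefix a github.com URL with the proxy; leave other URLs unchanged."""
    if url.startswith(proxy):
        return url  # already proxied, do not prefix twice
    if url.startswith("https://github.com/"):
        return proxy + url
    return url


def use_hf_mirror(endpoint="https://hf-mirror.com"):
    """Point huggingface_hub at the mirror unless HF_ENDPOINT is already set."""
    os.environ.setdefault("HF_ENDPOINT", endpoint)


if __name__ == "__main__":
    print(via_github_proxy("https://github.com/IDEA-Research/Grounded-SAM-2.git"))
    use_hf_mirror()
    print(os.environ["HF_ENDPOINT"])
```

Using `setdefault` means a value already exported in the shell takes precedence over the hard-coded mirror.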
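1.2 MicroMamba (Conda)
Because Conda carries compliance risks, MicroMamba is recommended instead. The GitHub download URLs in the Mamba installation script, mamba_install.sh, must be changed to go through the GitHub proxy.
Configure Mamba. Note that the initialization block is written to /root/.zshrc. The log is as follows: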
vim mamba_install.sh
bash mamba_install.sh

Micromamba binary folder? [~/.local/bin]
Init shell (zsh)? [Y/n]
Configure conda-forge? [Y/n]
Prefix location? [~/micromamba]
Running `shell init`, which:
 - modifies RC file: "/root/.zshrc"
 - generates config for root prefix: "/root/micromamba"
 - sets mamba executable to: "/root/.local/bin/micromamba"
The following has been added in your "/root/.zshrc" file

# >>> mamba initialize >>>
# !! Contents within this block are managed by 'micromamba shell init' !!
export MAMBA_EXE='/root/.local/bin/micromamba';
export MAMBA_ROOT_PREFIX='/root/micromamba';
__mamba_setup="$("$MAMBA_EXE" shell hook --shell zsh --root-prefix "$MAMBA_ROOT_PREFIX" 2> /dev/null)"
if [ $? -eq 0 ]; then
    eval "$__mamba_setup"
else
    alias micromamba="$MAMBA_EXE"  # Fallback on help from micromamba activate
fi
unset __mamba_setup
# <<< mamba initialize <<<

Please restart your shell to activate micromamba or run the following:\n
  source ~/.bashrc (or ~/.zshrc, ~/.xonshrc, ~/.config/fish/config.fish, ...)
Build the Python environment with Mamba:
source ~/.zshrc
micromamba create -n g-sam2 python=3.11  # create an environment with the specified packages
micromamba activate g-sam2  # activate the environment
python --version  # check the version
which python
# micromamba activate g-sam2
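To confirm that a script really runs inside the g-sam2 environment with Python 3.11, the interpreter's `sys.prefix` and `sys.version_info` can be compared against the expected values. This is a minimal sketch that assumes micromamba's default layout, where named environments live under `<root prefix>/envs/<name>`; `check_env` is a hypothetical helper.

```python
import os
import sys


def check_env(expected_name="g-sam2", expected_version=(3, 11),
              prefix=None, version_info=None):
    """Return a list of problems; an empty list means the environment looks right."""
    # Default to the running interpreter's prefix and version.
    prefix = sys.prefix if prefix is None else prefix
    version_info = sys.version_info if version_info is None else version_info
    problems = []
    # micromamba places named environments at <root_prefix>/envs/<name>.
    env_name = os.path.basename(os.path.normpath(prefix))
    if env_name != expected_name:
        problems.append(f"active environment is {env_name!r}, expected {expected_name!r}")
    if tuple(version_info[:2]) != tuple(expected_version):
        problems.append(
            f"Python {version_info[0]}.{version_info[1]}, "
            f"expected {expected_version[0]}.{expected_version[1]}"
        )
    return problems


if __name__ == "__main__":
    print(check_env() or "environment OK")
```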
1.3 Configure PyTorch
Check the CUDA version reported by NVCC (NVIDIA CUDA Compiler); here it is 11.8:
nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
Therefore, install the matching PyTorch build. Note that the default build targets CUDA 12.8, so the CUDA 11.8 build must be specified explicitly:
pip install torch==2.7.1+cu118 torchvision==0.22.1+cu118 torchaudio==2.7.1+cu118 --index-url https://download.pytorch.org/whl/cu118 --force-reinstall
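The choice of `cu118` follows directly from the `release 11.8` line in the nvcc output. The sketch below shows that mapping. The match matters because the GroundingDINO C++/CUDA ops are compiled against the local toolkit, and PyTorch's extension builder checks the toolkit against `torch.version.cuda`. The helper names are hypothetical. PyTorch publishes wheels for only a few CUDA versions per release, so a derived URL is only valid when that build exists.

```python
import re

# Sample output copied from the nvcc --version log above.
SAMPLE_NVCC_OUTPUT = """nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0"""


def cuda_release(nvcc_output):
    """Extract (major, minor) from the 'release X.Y' part of nvcc --version."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        raise ValueError("no 'release X.Y' found in nvcc output")
    return int(match.group(1)), int(match.group(2))


def pytorch_index_url(nvcc_output):
    """Build the cuXYZ wheel index URL, e.g. release 11.8 -> .../whl/cu118."""
    major, minor = cuda_release(nvcc_output)
    return f"https://download.pytorch.org/whl/cu{major}{minor}"


if __name__ == "__main__":
    # On a real machine, feed it subprocess.run(["nvcc", "--version"], ...).stdout instead.
    print(pytorch_index_url(SAMPLE_NVCC_OUTPUT))
```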
Test whether PyTorch works:
import torch
print(torch.__version__) # 2.7.1+cu118
print(torch.cuda.is_available()) # True
exit()
1.4 Configure Environment Variables
The environment variables are as follows:
WORK_DIR="/root/workspace"
HF_ENDPOINT="https://hf-mirror.com"
TORCH_HOME="$WORK_DIR/torch_home/"
HF_HOME="$WORK_DIR/huggingface/"
HUGGINGFACE_TOKEN="hf_xxxxxx"
MODELSCOPE_CACHE="$WORK_DIR/modelscope_models/"
MODELSCOPE_API_TOKEN="ms-xxxxxxx"
CUDA_HOME="/usr/local/cuda"
PATH="/usr/local/cuda/bin:$PATH"
LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH"
OMP_NUM_THREADS=64
HYDRA_FULL_ERROR=1
TORCH_CUDA_ARCH_LIST=9.0
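Several of these variables reference earlier ones (`$WORK_DIR`, `$PATH`), so their order matters. The parser below is a simplified sketch of how a shell expands such `KEY="value"` lines: an unset variable expands to an empty string, and quoting and escaping rules are not handled. It uses only `string.Template`; `load_env_lines` is a hypothetical helper. In a shell rc file these lines also need `export` (or `set -a`) to reach child processes such as Python, unless the variable is already exported, like PATH.

```python
import os
from string import Template


class _ShellEnv(dict):
    """Mapping where, like in the shell, an unset variable expands to ''."""

    def __missing__(self, key):
        return ""


def load_env_lines(text, base_env=None):
    """Parse KEY="value" lines, expanding $VAR / ${VAR} from earlier lines and base_env."""
    env = _ShellEnv(os.environ if base_env is None else base_env)
    result = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue  # skip blanks, comments and non-assignments
        key, _, value = line.partition("=")
        key = key.strip()
        value = value.strip().strip("'\"")  # drop surrounding quotes only
        value = Template(value).safe_substitute(env)
        env[key] = value  # later lines can reference this key, e.g. $WORK_DIR
        result[key] = value
    return result


SAMPLE = """WORK_DIR="/root/workspace"
TORCH_HOME="$WORK_DIR/torch_home/"
HF_HOME="$WORK_DIR/huggingface/"
PATH="/usr/local/cuda/bin:$PATH"
TORCH_CUDA_ARCH_LIST=9.0"""

if __name__ == "__main__":
    for name, val in load_env_lines(SAMPLE, base_env={"PATH": "/usr/bin"}).items():
        print(f"{name}={val}")
```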
Custom configuration:
- Hugging Face Access Tokens: see https://huggingface.co/settings/tokens
- ModelScope Access Tokens: see https://modelscope.cn/my/myaccesstoken
Verify TORCH_CUDA_ARCH_LIST:
# Make sure a compatible PyTorch build is used
import torch
print(torch.cuda.get_device_capability())  # should return (9, 0), matching TORCH_CUDA_ARCH_LIST=9.0
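TORCH_CUDA_ARCH_LIST tells PyTorch's extension builder which GPU architectures to compile the custom ops for, so it should include the capability printed above. The sketch below is a simplified check that requires an exact match. It accepts only numeric entries separated by `;` or spaces, and it ignores both PTX JIT compilation and CUDA's binary compatibility across minor versions of the same major architecture. The helper names are hypothetical.

```python
import os


def parse_arch_list(value):
    """Parse '8.6;9.0+PTX' or '8.6 9.0' into {(8, 6), (9, 0)} (numeric entries only)."""
    archs = set()
    for token in value.replace(";", " ").split():
        token = token.removesuffix("+PTX")  # +PTX only adds JIT-compilable code
        major, minor = token.split(".")
        archs.add((int(major), int(minor)))
    return archs


def arch_list_covers(value, capability):
    """True if the device capability, e.g. (9, 0), appears in the list exactly."""
    return tuple(capability) in parse_arch_list(value)


if __name__ == "__main__":
    # With PyTorch available, pass torch.cuda.get_device_capability() instead of (9, 0).
    print(arch_list_covers(os.environ.get("TORCH_CUDA_ARCH_LIST", "9.0"), (9, 0)))
```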
Verify CUDA_HOME:
ll /usr/local/cuda
/usr/local/cuda -> /etc/alternatives/cuda
2. Configure the Project
Download the Grounded-SAM-2 project:
git clone https://ghfast.top/https://github.com/IDEA-Research/Grounded-SAM-2.git
Download the model checkpoints it depends on, namely SAM 2 and Grounding DINO:
cd Grounded-SAM-2
cd checkpoints
bash download_ckpts.sh
cd ../gdino_checkpoints
bash download_ckpts.sh
cd ..
Note: change BASE_URL in gdino_checkpoints/download_ckpts.sh so that it goes through the proxy, which speeds up the download.
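The BASE_URL edit can also be scripted. This is a sketch that assumes the variable is assigned on its own line in the form `BASE_URL="https://github.com/..."`; check the actual script before relying on it. `proxify_base_url` and `patch_script` are hypothetical helpers, and the rewrite is idempotent, so running it twice does not add the prefix twice.

```python
import re
from pathlib import Path

# Matches an assignment such as BASE_URL="https://github.com/... (quotes optional).
_BASE_URL_RE = re.compile(r"^(\s*BASE_URL=)([\"']?)(https://github\.com/)", re.MULTILINE)


def proxify_base_url(script, proxy="https://ghfast.top/"):
    """Insert the proxy prefix in front of a github.com BASE_URL."""
    # Already-proxied lines start with the proxy after the quote, so they no longer match.
    return _BASE_URL_RE.sub(lambda m: f"{m.group(1)}{m.group(2)}{proxy}{m.group(3)}", script)


def patch_script(path, proxy="https://ghfast.top/"):
    """Rewrite the download script in place, e.g. gdino_checkpoints/download_ckpts.sh."""
    p = Path(path)
    p.write_text(proxify_base_url(p.read_text(), proxy))
```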
Install the dependencies:
cd grounding_dino
pip install -r requirements.txt
cd ..
Install the local project packages:
# The default is transformers==4.55.0; downgrade to 4.49.0
pip install transformers==4.49.0
# Install iopath
pip install iopath
# Install hydra-core
pip install hydra-core --upgrade --pre
# Install the project packages (repo root and grounding_dino) in editable mode
pip install --no-build-isolation -e .
pip install --no-build-isolation -e grounding_dino
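After the installs, the Python standard library's `importlib.metadata` can confirm the pinned versions. The expected values below are the ones named in this article. `check_packages` is a hypothetical helper, and packages without a pin are only checked for presence.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names and pins taken from the install steps above;
# None means "only check that it is installed".
EXPECTED = {"transformers": "4.49.0", "iopath": None, "hydra-core": None}


def check_packages(expected=EXPECTED):
    """Map each distribution name to 'ok', 'missing', or a mismatch message."""
    report = {}
    for name, pin in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            report[name] = "missing"
            continue
        if pin is not None and installed != pin:
            report[name] = f"{installed} (expected {pin})"
        else:
            report[name] = "ok"
    return report


if __name__ == "__main__":
    for name, status in check_packages().items():
        print(f"{name}: {status}")
```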
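For the problem introduced by the Transformers upgrade, see: GitHub - Some problem with OmDetTurboProcessor.post_process_grounded_object_detection()
Problem encountered: `UserWarning: Failed to load custom C++ ops. Running on CPU mode Only!`. The cause is misconfigured CUDA environment variables; set them as follows: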
CUDA_HOME="/usr/local/cuda"
PATH="/usr/local/cuda/bin:$PATH"
LD_LIBRARY_PATH="/usr/local/cuda/lib64:$LD_LIBRARY_PATH"
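These three variables can be sanity-checked before reinstalling. The sketch below uses a hypothetical `check_cuda_env` helper. It checks that CUDA_HOME contains `bin/nvcc` and that the `bin` and `lib64` directories are on PATH and LD_LIBRARY_PATH. The warning itself is printed when GroundingDINO's compiled extension cannot be loaded. So after fixing the variables, the `pip install --no-build-isolation -e grounding_dino` step presumably has to be re-run to rebuild it.

```python
import os
from pathlib import Path


def check_cuda_env(environ=None):
    """Return a list of problems with CUDA_HOME / PATH / LD_LIBRARY_PATH (empty = OK)."""
    env = os.environ if environ is None else environ
    cuda_home = env.get("CUDA_HOME")
    if not cuda_home:
        return ["CUDA_HOME is not set"]
    home = Path(cuda_home)
    problems = []
    if not (home / "bin" / "nvcc").exists():
        problems.append(f"{home / 'bin' / 'nvcc'} not found")
    # Compare whole path entries, not substrings.
    if str(home / "bin") not in env.get("PATH", "").split(os.pathsep):
        problems.append(f"{home / 'bin'} is not on PATH")
    if str(home / "lib64") not in env.get("LD_LIBRARY_PATH", "").split(os.pathsep):
        problems.append(f"{home / 'lib64'} is not on LD_LIBRARY_PATH")
    return problems


if __name__ == "__main__":
    print(check_cuda_env() or "CUDA environment OK")
```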
Test the demo scripts:
# runs normally
python3 grounded_sam2_local_demo.py
# runs normally
python3 grounded_sam2_tracking_demo.py
In the image detection demo (grounded_sam2_local_demo.py), add NMS logic:
from torchvision.ops import box_convert, nms

NMS_THRESHOLD = 0.5

# process the box prompt for SAM 2
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy")

# NMS logic
nms_indices = nms(input_boxes, confidences, NMS_THRESHOLD)
input_boxes = input_boxes[nms_indices].numpy()
confidences = confidences[nms_indices]
labels = [labels[i] for i in nms_indices]
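`torchvision.ops.nms` keeps the highest-scoring box and discards every remaining box whose IoU with an already kept box exceeds the threshold. The pure-Python sketch below reproduces that greedy rule, together with the normalized cxcywh to pixel xyxy conversion done before it, to make `NMS_THRESHOLD = 0.5` concrete. It is illustrative only (quadratic time), and the service should keep using torchvision. As in the snippet above, the suppression is class-agnostic, so boxes with different labels can suppress each other. `torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)` is the per-class variant if that is not wanted.

```python
def cxcywh_to_xyxy(box, w, h):
    """Normalized (cx, cy, bw, bh) -> pixel (x1, y1, x2, y2), like the scaling + box_convert."""
    cx, cy, bw, bh = box[0] * w, box[1] * h, box[2] * w, box[3] * h
    return (cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2)


def iou(a, b):
    """Intersection over union of two xyxy boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_threshold):
    """Indices of kept boxes, in decreasing score order (same contract as torchvision nms)."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap any kept box by more than the threshold.
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep


if __name__ == "__main__":
    boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
    print(greedy_nms(boxes, [0.9, 0.8, 0.7], 0.5))
```

The first two boxes overlap with IoU 81/119 (about 0.68), so the lower-scoring one is suppressed at a threshold of 0.5.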
3. Build the Service
Implement the ImgPredictor class for visual detection and segmentation:
import os
import sys
import tempfile

import cv2
from loguru import logger
import numpy as np
import pycocotools.mask as mask_util
import requests
import torch
from torchvision.ops import box_convert, nms

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
from root_dir import ROOT_DIR
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


class ImgPredictor:
    def __init__(self):
        # Model configuration
        self.SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
        self.SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
        self.GROUNDING_DINO_CONFIG = (
            "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
        )
        self.GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"
        self.BOX_THRESHOLD = 0.35
        self.TEXT_THRESHOLD = 0.25
        self.NMS_THRESHOLD = 0.5
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize models
        self._init_models()
        # Run a warmup prediction to avoid a slow first request
        self._warmup()

    def _init_models(self):
        """Initialize SAM2 and Grounding DINO models"""
        # Build SAM2 image predictor
        sam2_model = build_sam2(self.SAM2_MODEL_CONFIG, self.SAM2_CHECKPOINT, device=self.DEVICE)
        self.sam2_predictor = SAM2ImagePredictor(sam2_model)
        # Build Grounding DINO model
        self.grounding_model = load_model(
            model_config_path=self.GROUNDING_DINO_CONFIG,
            model_checkpoint_path=self.GROUNDING_DINO_CHECKPOINT,
            device=self.DEVICE,
        )

    def _warmup(self):
        """Run a warmup prediction to avoid a slow first request"""
        try:
            # Use the same test image and prompt as the demo
            temp_img_path = os.path.join(ROOT_DIR, "notebooks/images/truck.jpg")
            text_prompt = "car. tire."
            # If the demo image does not exist, create a temporary gray test image
            if not os.path.exists(temp_img_path):
                temp_img = np.zeros((480, 640, 3), dtype=np.uint8)
                temp_img[:] = (128, 128, 128)  # Gray image
                with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
                    cv2.imwrite(tmp_file.name, temp_img)
                    temp_img_path = tmp_file.name
                temp_file_created = True
            else:
                temp_file_created = False
            logger.info(f"Warmup prediction: {temp_img_path}, {text_prompt}")
            try:
                # Load image and set it for the SAM2 predictor
                image_source, image = load_image(temp_img_path)
                self.sam2_predictor.set_image(image_source)
                # Run Grounding DINO prediction with the demo prompt
                boxes, confidences, labels = predict(
                    model=self.grounding_model,
                    image=image,
                    caption=text_prompt,
                    box_threshold=self.BOX_THRESHOLD,
                    text_threshold=self.TEXT_THRESHOLD,
                    device=self.DEVICE,
                )
                # If boxes are found, also run SAM2 prediction
                if len(boxes) > 0:
                    h, w, _ = image_source.shape
                    boxes = boxes * torch.Tensor([w, h, w, h])
                    input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
                    self.sam2_predictor.predict(
                        point_coords=None,
                        point_labels=None,
                        box=input_boxes,
                        multimask_output=False,
                    )
            finally:
                # Clean up the temporary file only if we created it
                if temp_file_created and os.path.exists(temp_img_path):
                    try:
                        os.unlink(temp_img_path)
                    except OSError:
                        pass
        except Exception as e:
            # Warmup failed, but continue; just log the error
            logger.warning(f"Warmup prediction failed: {e}")

    def _download_image(self, img_url: str) -> str:
        """Download image from URL and save it to a temporary file"""
        response = requests.get(img_url, timeout=30)
        response.raise_for_status()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
            tmp_file.write(response.content)
            return tmp_file.name

    def _single_mask_to_rle(self, mask):
        """Convert a single mask to RLE format"""
        rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
        rle["counts"] = rle["counts"].decode("utf-8")
        return rle

    def _rle_to_mask(self, rle):
        """Convert RLE format to a mask"""
        return mask_util.decode(rle)

    def predict(self, img_url: str, prompts: str):
        """Predict objects in an image using Grounding DINO + SAM2

        Args:
            img_url: URL of the image to process
            prompts: Text prompts for object detection (should be lowercased and end with a dot)

        Returns:
            dict: Results containing annotations with bboxes, masks, and scores
        """
        temp_img_path = None
        try:
            # Download and load the image, then set it for the SAM2 predictor
            temp_img_path = self._download_image(img_url)
            image_source, image = load_image(temp_img_path)
            self.sam2_predictor.set_image(image_source)
            # Run Grounding DINO prediction
            boxes, confidences, labels = predict(
                model=self.grounding_model,
                image=image,
                caption=prompts,
                box_threshold=self.BOX_THRESHOLD,
                text_threshold=self.TEXT_THRESHOLD,
                device=self.DEVICE,
            )
            h, w, _ = image_source.shape
            # No objects detected: return empty results
            if len(boxes) == 0:
                return {
                    "image_url": img_url,
                    "annotations": [],
                    "box_format": "xyxy",
                    "img_width": w,
                    "img_height": h,
                    "prompts": prompts,
                }
            # Process the box prompt for SAM 2
            boxes = boxes * torch.Tensor([w, h, w, h])
            input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy")
            # NMS
            nms_indices = nms(input_boxes, confidences, self.NMS_THRESHOLD)
            input_boxes = input_boxes[nms_indices].numpy()
            confidences = confidences[nms_indices]
            labels = [labels[i] for i in nms_indices.tolist()]
            # Run SAM2 prediction
            masks, scores, _ = self.sam2_predictor.predict(
                point_coords=None,
                point_labels=None,
                box=input_boxes,
                multimask_output=False,
            )
            # Post-process masks: (N, 1, H, W) -> (N, H, W)
            if masks.ndim == 4:
                masks = masks.squeeze(1)
            # Convert to JSON-serializable types; reshape(-1) flattens (N, 1) scores to (N,)
            confidences = confidences.numpy().tolist()
            input_boxes = input_boxes.tolist()
            scores = scores.reshape(-1).tolist()
            # Convert masks to RLE format
            mask_rles = [self._single_mask_to_rle(mask) for mask in masks]
            return {
                "image_url": img_url,
                "annotations": [
                    {
                        "class_name": class_name,
                        "bbox": box,
                        "segmentation": mask_rle,
                        "score": score,
                        "confidence": confidence,
                    }
                    for class_name, box, mask_rle, score, confidence in zip(
                        labels, input_boxes, mask_rles, scores, confidences
                    )
                ],
                "box_format": "xyxy",
                "img_width": w,
                "img_height": h,
                "prompts": prompts,
            }
        except Exception as e:
            raise RuntimeError(f"Prediction failed: {e}") from e
        finally:
            # Clean up the temporary file
            if temp_img_path and os.path.exists(temp_img_path):
                try:
                    os.unlink(temp_img_path)
                except OSError:
                    pass


def main():
    predictor = ImgPredictor()
    result = predictor.predict("xxx", "page. ")
    print(result)


if __name__ == "__main__":
    main()
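The service returns each mask as COCO RLE, produced by `pycocotools.mask.encode` on a Fortran-ordered (column-major) array with the counts compressed into a string. The pure-Python sketch below shows the uncompressed form of the same run-length idea, which can help when inspecting the `segmentation` field. It is illustrative only, and real clients should decode the service output with `pycocotools.mask.decode`, as `_rle_to_mask` does. The helper names are hypothetical.

```python
def rle_encode(mask):
    """Encode a binary mask (list of rows) as uncompressed COCO-style RLE.

    Pixels are visited column by column (Fortran order), and counts alternate
    between runs of 0s and runs of 1s, always starting with a run of 0s.
    """
    h = len(mask)
    w = len(mask[0]) if h else 0
    counts, current, run = [], 0, 0
    for x in range(w):
        for y in range(h):
            value = 1 if mask[y][x] else 0
            if value == current:
                run += 1
            else:
                counts.append(run)
                current, run = value, 1
    counts.append(run)
    return {"size": [h, w], "counts": counts}


def rle_decode(rle):
    """Invert rle_encode back to a list of rows."""
    h, w = rle["size"]
    flat, value = [], 0
    for count in rle["counts"]:
        flat.extend([value] * count)
        value ^= 1  # runs alternate 0 -> 1 -> 0 ...
    # flat is column-major: pixel (y, x) sits at index x * h + y
    return [[flat[x * h + y] for x in range(w)] for y in range(h)]


def rle_area(rle):
    """Number of foreground pixels: the sum of the runs of 1s (odd positions)."""
    return sum(rle["counts"][1::2])


if __name__ == "__main__":
    mask = [[0, 1, 1],
            [0, 1, 0]]
    rle = rle_encode(mask)
    print(rle, rle_area(rle))
```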
Start the Grounded-SAM-2 API service. Note that `--workers 4` starts four uvicorn worker processes; if `api/app.py` creates an ImgPredictor in each process, both models are loaded once per worker, so GPU memory usage grows with the worker count:
#!/bin/bash

# mamba: this script runs under bash, so use the bash shell hook
export MAMBA_EXE='/root/.local/bin/micromamba';
export MAMBA_ROOT_PREFIX='/root/micromamba';
__mamba_setup="$("$MAMBA_EXE" shell hook --shell bash --root-prefix "$MAMBA_ROOT_PREFIX" 2> /dev/null)"
if [ $? -eq 0 ]; then
    eval "$__mamba_setup"
else
    alias micromamba="$MAMBA_EXE"  # Fallback on help from micromamba activate
fi
unset __mamba_setup

# Activate the micromamba environment
micromamba activate g-sam2

# Start the API service
cd /root/workspace/Grounded-SAM-2/
python -m uvicorn api.app:app --host 0.0.0.0 --port 9001 --workers 4
Check whether the Grounded-SAM-2 API service is healthy (prints 0 if the health endpoint returns HTTP 200, otherwise -1):
#!/bin/bash

url="http://127.0.0.1:9001/api/v1.0/health"
http_code=$(curl -s -o /dev/null -w "%{http_code}" --max-time 5 --location "$url")
if [ $? -eq 0 ] && [ "$http_code" = "200" ]; then
    echo 0
else
    echo -1
fi
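The same check can be written with only the Python standard library, for example for use inside a Python-based supervisor. This sketch mirrors the shell script: it returns 0 only for HTTP 200 within the timeout and -1 otherwise. `health_check` is a hypothetical helper. Redirects are followed as with curl's `--location`. Unlike curl, it deliberately ignores any HTTP(S)_PROXY variables (such as those used for the download proxies) so that a local check never leaves the machine.

```python
import http.client
import urllib.request

HEALTH_URL = "http://127.0.0.1:9001/api/v1.0/health"

# ProxyHandler({}) disables environment proxies for this opener.
_OPENER = urllib.request.build_opener(urllib.request.ProxyHandler({}))


def health_check(url=HEALTH_URL, timeout=5.0):
    """Return 0 if the endpoint answers HTTP 200 within `timeout` seconds, else -1."""
    try:
        with _OPENER.open(url, timeout=timeout) as response:
            return 0 if response.status == 200 else -1
    except (OSError, http.client.HTTPException):
        # URLError/HTTPError (4xx/5xx), refused connections and timeouts are OSError subclasses
        return -1
```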
Other
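The MicroMamba installation script (mamba_install.sh), with the GitHub release URLs routed through the ghfast.top proxy: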
#!/bin/sh

set -eu

# Detect the shell from which the script was called
parent=$(ps -o comm $PPID |tail -1)
parent=${parent#-}  # remove the leading dash that login shells have
case "$parent" in
  # shells supported by `micromamba shell init`
  bash|fish|xonsh|zsh)
    shell=$parent
    ;;
  *)
    # use the login shell (basename of $SHELL) as a fallback
    shell=${SHELL##*/}
    ;;
esac

# Parsing arguments
if [ -t 0 ] ; then
  printf "Micromamba binary folder? [~/.local/bin] "
  read BIN_FOLDER
  printf "Init shell ($shell)? [Y/n] "
  read INIT_YES
  printf "Configure conda-forge? [Y/n] "
  read CONDA_FORGE_YES
fi

# Fallbacks
BIN_FOLDER="${BIN_FOLDER:-${HOME}/.local/bin}"
INIT_YES="${INIT_YES:-yes}"
CONDA_FORGE_YES="${CONDA_FORGE_YES:-yes}"

# Prefix location is relevant only if we want to call `micromamba shell init`
case "$INIT_YES" in
  y|Y|yes)
    if [ -t 0 ]; then
      printf "Prefix location? [~/micromamba] "
      read PREFIX_LOCATION
    fi
    ;;
esac
PREFIX_LOCATION="${PREFIX_LOCATION:-${HOME}/micromamba}"

# Computing artifact location
case "$(uname)" in
  Linux)  PLATFORM="linux" ;;
  Darwin) PLATFORM="osx" ;;
  *NT*)   PLATFORM="win" ;;
esac

ARCH="$(uname -m)"
case "$ARCH" in
  aarch64|ppc64le|arm64) ;;  # pass
  *) ARCH="64" ;;
esac

case "$PLATFORM-$ARCH" in
  linux-aarch64|linux-ppc64le|linux-64|osx-arm64|osx-64|win-64) ;;  # pass
  *)
    echo "Failed to detect your OS" >&2
    exit 1
    ;;
esac

if [ "${VERSION:-}" = "" ]; then
  RELEASE_URL="https://ghfast.top/https://github.com/mamba-org/micromamba-releases/releases/latest/download/micromamba-${PLATFORM}-${ARCH}"
else
  RELEASE_URL="https://ghfast.top/https://github.com/mamba-org/micromamba-releases/releases/download/${VERSION}/micromamba-${PLATFORM}-${ARCH}"
fi

# Downloading artifact
mkdir -p "${BIN_FOLDER}"
if hash curl >/dev/null 2>&1; then
  curl "${RELEASE_URL}" -o "${BIN_FOLDER}/micromamba" -fsSL --compressed ${CURL_OPTS:-}
elif hash wget >/dev/null 2>&1; then
  wget ${WGET_OPTS:-} -qO "${BIN_FOLDER}/micromamba" "${RELEASE_URL}"
else
  echo "Neither curl nor wget was found" >&2
  exit 1
fi
chmod +x "${BIN_FOLDER}/micromamba"

# Initializing shell
case "$INIT_YES" in
  y|Y|yes)
    case $("${BIN_FOLDER}/micromamba" --version) in
      1.*|0.*) shell_arg=-s; prefix_arg=-p ;;
      *) shell_arg=--shell; prefix_arg=--root-prefix ;;
    esac
    "${BIN_FOLDER}/micromamba" shell init $shell_arg "$shell" $prefix_arg "$PREFIX_LOCATION"
    echo "Please restart your shell to activate micromamba or run the following:\n"
    echo "  source ~/.bashrc (or ~/.zshrc, ~/.xonshrc, ~/.config/fish/config.fish, ...)"
    ;;
  *)
    echo "You can initialize your shell later by running:"
    echo "  micromamba shell init"
    ;;
esac

# Initializing conda-forge
case "$CONDA_FORGE_YES" in
  y|Y|yes)
    "${BIN_FOLDER}/micromamba" config append channels conda-forge
    "${BIN_FOLDER}/micromamba" config append channels nodefaults
    "${BIN_FOLDER}/micromamba" config set channel_priority strict
    ;;
esac
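The platform and architecture detection in the script maps `uname` output to a release asset name. Below is a Python sketch of the same mapping, useful for checking which URL the proxied script will fetch. `micromamba_release_url` is a hypothetical helper, and its `system`/`machine` arguments correspond to `uname` and `uname -m`.

```python
import platform

SUPPORTED = {"linux-aarch64", "linux-ppc64le", "linux-64", "osx-arm64", "osx-64", "win-64"}
RELEASES = "https://github.com/mamba-org/micromamba-releases/releases"


def micromamba_release_url(system, machine, version=None, proxy="https://ghfast.top/"):
    """Mirror the installer: `uname` -> PLATFORM, `uname -m` -> ARCH, then RELEASE_URL."""
    if system == "Linux":
        plat = "linux"
    elif system == "Darwin":
        plat = "osx"
    elif "NT" in system:  # e.g. MSYS/MinGW uname output such as MINGW64_NT-10.0
        plat = "win"
    else:
        plat = "unknown"
    # Only these machine names are kept; everything else (e.g. x86_64) becomes "64".
    arch = machine if machine in ("aarch64", "ppc64le", "arm64") else "64"
    target = f"{plat}-{arch}"
    if target not in SUPPORTED:
        raise RuntimeError("Failed to detect your OS")
    path = "latest/download" if not version else f"download/{version}"
    return f"{proxy}{RELEASES}/{path}/micromamba-{target}"


if __name__ == "__main__":
    # platform.system()/machine() match uname on Linux and macOS (not on native Windows).
    print(micromamba_release_url(platform.system(), platform.machine()))
```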