
SAM2Long Local Deployment: Video Segmentation, Green-Screen Keying, Ultra-Long Video Support

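Here is a great video segmentation tool: SAM2Long. It can cut the subject out of a video in one click and quickly output a green-screen video or a video with a transparent (alpha) channel. SAM2Long is an improved version of SAM2. It keeps several candidate segmentation pathways (a training-free memory tree), which makes it more robust on problems common in long videos, such as occlusion and objects that disappear and reappear.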

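Local Deployment

There are two ways to deploy: from source, or with the one-click package I have already prepared.

Deploying from source

Install the PyTorch environment (CUDA 12.8 build):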

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

Install the local SAM2 codebase (run this from the root of the SAM2Long repository):

pip install -e .

Download the model checkpoints

Linux/macOS: download with a script

Create the script file:

mkdir -p SAM2Long/checkpoints
nano SAM2Long/checkpoints/download_ckpts.sh

Contents of download_ckpts.sh:

#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Use either wget or curl to download the checkpoints
if command -v wget &> /dev/null; then
    CMD="wget"
elif command -v curl &> /dev/null; then
    CMD="curl -L -O"
else
    echo "Please install wget or curl to download the checkpoints."
    exit 1
fi

# Define the URLs for SAM 2 checkpoints
# SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
# sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
# sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
# sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
# sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"

# Download each of the four checkpoints using wget
# echo "Downloading sam2_hiera_tiny.pt checkpoint..."
# $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
# echo "Downloading sam2_hiera_small.pt checkpoint..."
# $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
# echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
# $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
# echo "Downloading sam2_hiera_large.pt checkpoint..."
# $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }

# Define the URLs for SAM 2.1 checkpoints
SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"

# SAM 2.1 checkpoints
echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
$CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
echo "Downloading sam2.1_hiera_small.pt checkpoint..."
$CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
$CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
echo "Downloading sam2.1_hiera_large.pt checkpoint..."
$CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
echo "All checkpoints are downloaded successfully."

Make it executable:

chmod +x SAM2Long/checkpoints/download_ckpts.sh

Run the script:

cd SAM2Long/checkpoints && ./download_ckpts.sh

Windows: download with a script

Create a batch file:

SAM2Long\checkpoints\download_ckpts.bat

Contents of the .bat script:

@echo off
rem Copyright (c) Meta Platforms, Inc. and affiliates.
rem All rights reserved.

rem This source code is licensed under the license found in the
rem LICENSE file in the root directory of this source tree.

rem Check for wget or curl
set "CMD="
where wget >nul 2>&1
if %ERRORLEVEL% equ 0 (
    set "CMD=wget -O"
) else (
    where curl >nul 2>&1
    rem Runtime errorlevel check, since percent variables inside a block expand before it runs
    if not errorlevel 1 (
        set "CMD=curl -L -o"
    ) else (
        rem Neither wget nor curl was found, so fall back to PowerShell Invoke-WebRequest
        echo Neither wget nor curl found. Trying PowerShell Invoke-WebRequest...
        set "CMD=powershell"
    )
)

rem Define the URLs for SAM 2.1 checkpoints
set SAM2p1_BASE_URL=https://dl.fbaipublicfiles.com/segment_anything_2/092824
set sam2p1_hiera_t_url=%SAM2p1_BASE_URL%/sam2.1_hiera_tiny.pt
set sam2p1_hiera_s_url=%SAM2p1_BASE_URL%/sam2.1_hiera_small.pt
set sam2p1_hiera_b_plus_url=%SAM2p1_BASE_URL%/sam2.1_hiera_base_plus.pt
set sam2p1_hiera_l_url=%SAM2p1_BASE_URL%/sam2.1_hiera_large.pt

rem SAM 2.1 checkpoints
call :download sam2.1_hiera_tiny.pt %sam2p1_hiera_t_url%
if errorlevel 1 exit /b 1
call :download sam2.1_hiera_small.pt %sam2p1_hiera_s_url%
if errorlevel 1 exit /b 1
call :download sam2.1_hiera_base_plus.pt %sam2p1_hiera_b_plus_url%
if errorlevel 1 exit /b 1
call :download sam2.1_hiera_large.pt %sam2p1_hiera_l_url%
if errorlevel 1 exit /b 1
echo All checkpoints are downloaded successfully.
exit /b 0

rem Subroutine: first argument is the output file name, second is the URL
:download
echo Downloading %~1 checkpoint...
if "%CMD%"=="wget -O" (
    wget -O %~1 %~2
) else if "%CMD%"=="curl -L -o" (
    curl -L -o %~1 %~2
) else (
    powershell -Command "Invoke-WebRequest -Uri '%~2' -OutFile '%~1'"
)
if errorlevel 1 (
    echo Failed to download checkpoint from %~2
    exit /b 1
)
exit /b 0

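Once the file is created, double-click it to run it. The checkpoints are saved in the folder that contains the script, SAM2Long\checkpoints.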
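If neither script suits your setup, the same four files can be fetched on any OS with a few lines of standard-library Python. This is a minimal sketch and is not part of the SAM2Long repository. The file name download_ckpts.py and the helpers checkpoint_urls() and download_all() are hypothetical. Only the base URL and the file names come from the scripts above.

```python
"""Cross-platform alternative to download_ckpts.sh / download_ckpts.bat (a sketch)."""
import sys
import urllib.request
from pathlib import Path

# Base URL and file names copied from the SAM 2.1 section of the scripts above
SAM2P1_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/092824"
CHECKPOINT_FILES = [
    "sam2.1_hiera_tiny.pt",
    "sam2.1_hiera_small.pt",
    "sam2.1_hiera_base_plus.pt",
    "sam2.1_hiera_large.pt",
]


def checkpoint_urls(names=CHECKPOINT_FILES):
    """Return (file name, URL) pairs for the requested checkpoints."""
    return [(name, f"{SAM2P1_BASE_URL}/{name}") for name in names]


def download_all(target_dir="SAM2Long/checkpoints", names=CHECKPOINT_FILES):
    """Download each checkpoint into target_dir, skipping files that already exist.

    Returns 0 on success and 1 on the first failed download.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    for name, url in checkpoint_urls(names):
        dest = target / name
        if dest.exists():
            print(f"Skipping {name}: already present")
            continue
        print(f"Downloading {name} checkpoint...")
        tmp = dest.with_name(dest.name + ".part")
        try:
            urllib.request.urlretrieve(url, tmp)
        except OSError as exc:  # URLError and HTTPError are subclasses of OSError
            print(f"Failed to download checkpoint from {url}: {exc}")
            return 1
        tmp.replace(dest)  # rename only after the transfer completed
    print("All checkpoints are downloaded successfully.")
    return 0


# Without arguments nothing is downloaded; pass "all" or specific file names
if __name__ == "__main__" and len(sys.argv) > 1:
    requested = CHECKPOINT_FILES if sys.argv[1:] == ["all"] else sys.argv[1:]
    sys.exit(download_all(names=requested))
```

Run python download_ckpts.py all to fetch everything into SAM2Long/checkpoints, or pass individual names such as sam2.1_hiera_tiny.pt. Each file is downloaded under a .part name first. An interrupted transfer therefore never leaves behind a truncated .pt file that a later run would wrongly skip.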

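Manual download

You can also download the models directly from the following four links:

Tiny model

https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt

Small model

https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt

Base+ model

https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt

Large model

https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

About the models: there are four sizes of pretrained model (Tiny < Small < Base+ < Large). Larger models are more accurate but slower and need more VRAM, so choose one that fits your hardware. Put the downloaded files in SAM2Long/checkpoints, which is where the scripts below expect them.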
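Each model size pairs a checkpoint file with a config name. The pairs below are the ones used by load_model() in the Gradio script later in this article, plus large, which that script comments out. The following is a small sketch with hypothetical helper names. It checks which sizes are actually ready to use before you pick one.

```python
from pathlib import Path

# (checkpoint file, config name) per size, as in load_model() of the Gradio script.
# Config names are resolved by the sam2 package itself, so only checkpoints are checked.
MODELS = {
    "tiny": ("sam2.1_hiera_tiny.pt", "configs/sam2.1/sam2.1_hiera_t.yaml"),
    "small": ("sam2.1_hiera_small.pt", "configs/sam2.1/sam2.1_hiera_s.yaml"),
    "base-plus": ("sam2.1_hiera_base_plus.pt", "configs/sam2.1/sam2.1_hiera_b+.yaml"),
    "large": ("sam2.1_hiera_large.pt", "configs/sam2.1/sam2.1_hiera_l.yaml"),
}


def resolve_model(size, checkpoint_dir="./checkpoints"):
    """Return (checkpoint path, config name), failing loudly on a typo or missing file."""
    if size not in MODELS:
        raise ValueError(f"unknown model size {size!r}; expected one of {sorted(MODELS)}")
    ckpt_name, cfg = MODELS[size]
    ckpt_path = Path(checkpoint_dir) / ckpt_name
    if not ckpt_path.is_file():
        raise FileNotFoundError(f"{ckpt_path} not found; download it first")
    return str(ckpt_path), cfg


def available_models(checkpoint_dir="./checkpoints"):
    """List the sizes whose checkpoint file is present in checkpoint_dir."""
    return [size for size, (name, _) in MODELS.items()
            if (Path(checkpoint_dir) / name).is_file()]
```

Raising on an unknown size, rather than silently returning None, turns a misspelled choice into an immediate, readable error.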

Install Gradio

pip install gradio
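Before launching anything, a quick check can confirm two things: that every module the Gradio script imports is installed, and that PyTorch can see the GPU. This sketch uses only importlib. The module list is read off the script's import lines, plus torchvision from the PyTorch step. Note that cv2 is the import name of opencv-python and PIL that of Pillow. The helper names are hypothetical.

```python
import importlib.util

# Top-level import names used by the Gradio script in this article, plus torchvision.
# sam2 is the package installed from the repository with "pip install -e .".
REQUIRED_MODULES = ["torch", "torchvision", "gradio", "spaces", "moviepy",
                    "cv2", "numpy", "matplotlib", "PIL", "sam2"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level modules that cannot be located (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def cuda_available():
    """Result of torch.cuda.is_available(), or None if PyTorch is not installed."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch
    return torch.cuda.is_available()


if __name__ == "__main__":
    missing = missing_modules()
    print("Missing modules:", ", ".join(missing) if missing else "none")
    print("CUDA available:", cuda_available())
```

The check only looks for top-level packages. It does not verify versions, for example whether the installed MoviePy still provides moviepy.editor.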

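That completes the installation. The Gradio script below also imports spaces and moviepy.editor, so install them too if they are missing. Note that moviepy.editor only exists in MoviePy 1.x; install it with, for example, pip install "moviepy<2". For running inference, you can refer to the following.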

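Basic invocation (using the DAVIS dataset as an example):

python ./tools/vos_inference.py \
  --sam2_cfg configs/sam2.1/sam2.1_hiera_b+.yaml \
  --sam2_checkpoint ./checkpoints/sam2.1_hiera_base_plus.pt \
  --base_video_dir /path/to/your/video_folder \
  --input_mask_dir /path/to/your/initial_masks \
  --output_mask_dir ./predictions \
  --num_pathway 3 \
  --iou_thre 0.1 \
  --uncertainty 2

Parameters:
--sam2_cfg: model config file
--sam2_checkpoint: model checkpoint file
--base_video_dir: folder containing the videos
--input_mask_dir: object masks for the first frame
--output_mask_dir: where the predictions are saved
--num_pathway: number of segmentation pathways kept in parallel (3 balances accuracy and speed)
--iou_thre: IoU threshold used to discard low-quality candidate masks
--uncertainty: tracking-stability setting (adjust between 1 and 3)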

For a first try, use --num_pathway 1 for a quick test. For real work, 3 gives the best results.
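If you drive inference from Python, for example to process several folders in a loop, you can build the argument list programmatically. That avoids shell-quoting and line-continuation mistakes. This is a sketch: build_vos_command() is a hypothetical helper, and it only reproduces the flags shown in the command above.

```python
import shlex
import sys


def build_vos_command(base_video_dir, input_mask_dir, output_mask_dir,
                      sam2_cfg="configs/sam2.1/sam2.1_hiera_b+.yaml",
                      sam2_checkpoint="./checkpoints/sam2.1_hiera_base_plus.pt",
                      num_pathway=3, iou_thre=0.1, uncertainty=2):
    """Argument list for tools/vos_inference.py using the flags from the command above."""
    return [
        sys.executable, "./tools/vos_inference.py",  # same interpreter as this script
        "--sam2_cfg", sam2_cfg,
        "--sam2_checkpoint", sam2_checkpoint,
        "--base_video_dir", base_video_dir,
        "--input_mask_dir", input_mask_dir,
        "--output_mask_dir", output_mask_dir,
        "--num_pathway", str(num_pathway),
        "--iou_thre", str(iou_thre),
        "--uncertainty", str(uncertainty),
    ]


if __name__ == "__main__":
    # Quick first pass with a single pathway, as suggested above
    cmd = build_vos_command("/path/to/videos", "/path/to/masks", "./predictions", num_pathway=1)
    print(shlex.join(cmd))
    # To actually run it from the SAM2Long repository root:
    # import subprocess; subprocess.run(cmd, check=True)
```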

Gradio UI source code:

import subprocess
import re
from typing import List, Tuple, Optional
import spaces

# Define the command to be executed
command = ["python", "setup.py", "build_ext", "--inplace"]

# Execute the command
result = subprocess.run(command, capture_output=True, text=True)

css = """
div#component-18, div#component-25, div#component-35, div#component-41{
    align-items: stretch!important;
}
"""

# Print the output and error (if any)
print("Output:\n", result.stdout)
print("Errors:\n", result.stderr)

# Check if the command was successful
if result.returncode == 0:
    print("Command executed successfully.")
else:
    print("Command failed with return code:", result.returncode)

import gradio as gr
from datetime import datetime
import os
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from sam2.build_sam import build_sam2_video_predictor

from moviepy.editor import ImageSequenceClip


def sparse_sampling(jpeg_images, original_fps, target_fps=6):
    # Calculate the frame interval for sampling based on the target fps
    frame_interval = int(original_fps // target_fps)

    # Sparse sample the jpeg_images by selecting every 'frame_interval' frame
    sampled_images = [jpeg_images[i] for i in range(0, len(jpeg_images), frame_interval)]

    return sampled_images


def get_video_fps(video_path):
    # Open the video file
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print("Error: Could not open video.")
        return None

    # Get the FPS of the video
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    return fps


def clear_points(image):
    # we clean all
    return [
        image,   # first_frame_path
        [],      # tracking_points
        [],      # trackings_input_label
        image,   # points_map
        # gr.State()     # stored_inference_state
    ]


def preprocess_video_in(video_path):

    # Generate a unique ID based on the current date and time
    unique_id = datetime.now().strftime('%Y%m%d%H%M%S')

    # Set directory with this ID to store video frames
    extracted_frames_output_dir = f'frames_{unique_id}'

    # Create the output directory
    os.makedirs(extracted_frames_output_dir, exist_ok=True)

    ### Process video frames ###
    # Open the video file
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print("Error: Could not open video.")
        return None

    # Get the frames per second (FPS) of the video
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Calculate the number of frames to process (60 seconds of video)
    max_frames = int(fps * 60)

    frame_number = 0
    first_frame = None

    while True:
        ret, frame = cap.read()
        if not ret or frame_number >= max_frames:
            break
        if frame_number % 6 == 0:
            # Format the frame filename as '00000.jpg'
            frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')

            # Save the frame as a JPEG file
            cv2.imwrite(frame_filename, frame)

            # Store the first frame
            if frame_number == 0:
                first_frame = frame_filename

        frame_number += 1

    # Release the video capture object
    cap.release()

    # scan all the JPEG frame names in this directory
    scanned_frames = [
        p for p in os.listdir(extracted_frames_output_dir)
        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
    ]
    scanned_frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
    # print(f"SCANNED_FRAMES: {scanned_frames}")

    return [
        first_frame,                  # first_frame_path
        [],                           # tracking_points
        [],                           # trackings_input_label
        first_frame,                  # input_first_frame_image
        first_frame,                  # points_map
        extracted_frames_output_dir,  # video_frames_dir
        scanned_frames,               # scanned_frames
        None,                         # stored_inference_state
        None,                         # stored_frame_names
        gr.update(open=False)         # video_in_drawer
    ]


def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
    print(f"You selected {evt.value} at {evt.index} from {evt.target}")

    tracking_points.append(evt.index)
    # tracking_points.value.append(evt.index)
    print(f"TRACKING POINT: {tracking_points}")

    if point_type == "include":
        trackings_input_label.append(1)
        # trackings_input_label.value.append(1)
    elif point_type == "exclude":
        trackings_input_label.append(0)
        # trackings_input_label.value.append(0)
    print(f"TRACKING INPUT LABEL: {trackings_input_label}")

    # Open the image and get its dimensions
    transparent_background = Image.open(input_first_frame_image).convert('RGBA')
    w, h = transparent_background.size

    # Define the circle radius as a fraction of the smaller dimension
    fraction = 0.02  # You can adjust this value as needed
    radius = int(fraction * min(w, h))

    # Create a transparent layer to draw on
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)

    for index, track in enumerate(tracking_points):
        # green dot for include points, red dot for exclude points
        if trackings_input_label[index] == 1:
            cv2.circle(transparent_layer, tuple(track), radius, (0, 255, 0, 255), -1)
        else:
            cv2.circle(transparent_layer, tuple(track), radius, (255, 0, 0, 255), -1)

    # Convert the transparent layer back to an image
    transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)

    return tracking_points, trackings_input_label, selected_point_map


def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def load_model(checkpoint):
    # Load model according to the user's choice (keys must match the Checkpoint dropdown choices)
    if checkpoint == "tiny":
        sam2_checkpoint = "./checkpoints/sam2.1_hiera_tiny.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        return [sam2_checkpoint, model_cfg]
    elif checkpoint == "small":
        sam2_checkpoint = "./checkpoints/sam2.1_hiera_small.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
        return [sam2_checkpoint, model_cfg]
    elif checkpoint == "base-plus":
        sam2_checkpoint = "./checkpoints/sam2.1_hiera_base_plus.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
        return [sam2_checkpoint, model_cfg]
    # elif checkpoint == "large":
    #     sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
    #     model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    #     return [sam2_checkpoint, model_cfg]


def get_mask_sam_process(
    stored_inference_state,
    input_first_frame_image,
    checkpoint,
    tracking_points,
    trackings_input_label,
    video_frames_dir,  # extracted_frames_output_dir defined in 'preprocess_video_in' function
    scanned_frames,
    working_frame: str = None,  # current frame being added points
    available_frames_to_check: List[str] = [],
    # progress=gr.Progress(track_tqdm=True)
):

    # get model and model config paths
    print(f"USER CHOSEN CHECKPOINT: {checkpoint}")
    sam2_checkpoint, model_cfg = load_model(checkpoint)
    print("MODEL LOADED")

    # set predictor
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
    print("PREDICTOR READY")

    # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
    # print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
    video_dir = video_frames_dir

    # scan all the JPEG frame names in this directory
    frame_names = scanned_frames

    # print(f"STORED INFERENCE STEP: {stored_inference_state}")
    if stored_inference_state is None:
        # Init SAM2 inference_state
        inference_state = predictor.init_state(video_path=video_dir)
        inference_state['num_pathway'] = 3
        inference_state['iou_thre'] = 0.3
        inference_state['uncertainty'] = 2
        print("NEW INFERENCE_STATE INITIATED")
    else:
        inference_state = stored_inference_state

    inference_state["device"] = 'cpu'

    # segment and track one object
    # predictor.reset_state(inference_state) # if any previous tracking, reset

    ### HANDLING WORKING FRAME
    # new_working_frame = None
    # Add new point
    ann_frame_idx = 0  # the frame index we interact with, 0 if it is the first frame
    if working_frame is None:
        working_frame = "00000.jpg"
    else:
        # Use a regular expression to find the integer
        match = re.search(r'frame_(\d+)', working_frame)
        if match:
            # Extract the integer from the match
            frame_number = int(match.group(1))
            ann_frame_idx = frame_number

    print(f"NEW_WORKING_FRAME PATH: {working_frame}")

    ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

    # Build the point prompts from the user's clicks
    points = np.array(tracking_points, dtype=np.float32)
    # for labels, `1` means positive click and `0` means negative click
    labels = np.array(trackings_input_label, np.int32)
    _, out_obj_ids, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

    # Create the plot
    plt.figure(figsize=(12, 8))
    plt.title(f"frame {ann_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
    show_points(points, labels, plt.gca())
    show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

    # Save the plot as a JPG file
    first_frame_output_filename = "output_first_frame.jpg"
    plt.savefig(first_frame_output_filename, format='jpg')
    plt.close()
    # torch.cuda.empty_cache()

    # Assuming available_frames_to_check.value is a list
    if working_frame not in available_frames_to_check:
        available_frames_to_check.append(working_frame)
        print(available_frames_to_check)

    # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
    return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)


@spaces.GPU
def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
    # use bfloat16 for the entire notebook
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    #### PROPAGATION ####
    sam2_checkpoint, model_cfg = load_model(checkpoint)
    # set predictor

    inference_state = stored_inference_state

    if torch.cuda.is_available():
        inference_state["device"] = 'cuda'
        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
    else:
        inference_state["device"] = 'cpu'
        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')

    frame_names = stored_frame_names
    video_dir = video_frames_dir

    # Define a directory to save the JPEG images
    frames_output_dir = "frames_output_images"
    os.makedirs(frames_output_dir, exist_ok=True)

    # Initialize a list to store file paths of saved images
    jpeg_images = []

    # run propagation throughout the video and collect the results in a dict
    video_segments = {}  # video_segments contains the per-frame segmentation results
    # for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    #     video_segments[out_frame_idx] = {
    #         out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
    #         for i, out_obj_id in enumerate(out_obj_ids)
    #     }

    out_obj_ids, out_mask_logits = predictor.propagate_in_video(inference_state, start_frame_idx=0, reverse=False,)
    print(out_obj_ids)
    for frame_idx in range(0, inference_state['num_frames']):
        video_segments[frame_idx] = {out_obj_ids[0]: (out_mask_logits[frame_idx] > 0.0).cpu().numpy()}
        # output_scores_per_object[object_id][frame_idx] = out_mask_logits[frame_idx].cpu().numpy()

    # render the segmentation results every few frames
    if vis_frame_type == "check":
        vis_frame_stride = 15
    elif vis_frame_type == "render":
        vis_frame_stride = 1

    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
        plt.figure(figsize=(6, 4))
        plt.title(f"frame {out_frame_idx}")
        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

        # Define the output filename and save the figure as a JPEG file
        output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
        plt.savefig(output_filename, format='jpg')

        # Close the plot
        plt.close()

        # Append the file path to the list
        jpeg_images.append(output_filename)

        if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
            available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")

    torch.cuda.empty_cache()
    print(f"JPEG_IMAGES: {jpeg_images}")

    if vis_frame_type == "check":
        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True)
    elif vis_frame_type == "render":
        # Create a video clip from the image sequence
        original_fps = get_video_fps(video_in)
        # sampled_images = sparse_sampling(jpeg_images, original_fps, target_fps=6)
        clip = ImageSequenceClip(jpeg_images, fps=original_fps // 6)
        # clip = ImageSequenceClip(jpeg_images, fps=fps)

        # Write the result to a file
        final_vid_output_path = "output_video.mp4"
        clip.write_videofile(
            final_vid_output_path,
            codec='libx264'
        )

        return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)


def update_ui(vis_frame_type):
    if vis_frame_type == "check":
        return gr.update(visible=True), gr.update(visible=False)
    elif vis_frame_type == "render":
        return gr.update(visible=False), gr.update(visible=True)


def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
    new_working_frame = None
    if working_frame is None:
        new_working_frame = os.path.join(video_frames_dir, scanned_frames[0])
    else:
        # Use a regular expression to find the integer
        match = re.search(r'frame_(\d+)', working_frame)
        if match:
            # Extract the integer from the match
            frame_number = int(match.group(1))
            ann_frame_idx = frame_number
            new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
    return gr.State([]), gr.State([]), new_working_frame, new_working_frame


def reset_propagation(first_frame_path, predictor, stored_inference_state):
    predictor.reset_state(stored_inference_state)
    # print(f"RESET State: {stored_inference_state} ")
    return first_frame_path, [], [], gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)


with gr.Blocks(css=css) as demo:
    first_frame_path = gr.State()
    tracking_points = gr.State([])
    trackings_input_label = gr.State([])
    video_frames_dir = gr.State()
    scanned_frames = gr.State()
    loaded_predictor = gr.State()
    stored_inference_state = gr.State()
    stored_frame_names = gr.State()
    available_frames_to_check = gr.State([])
    with gr.Column():
        gr.Markdown(
            """
            <h1 style="text-align: center;">🔥 SAM2Long Demo 🔥</h1>
            """
        )
        gr.Markdown(
            """
            This is a simple demo for video segmentation with [SAM2Long](https://github.com/Mark12Ding/SAM2Long).
            """
        )
        gr.Markdown(
            """
            ### 📋 Instructions:
            It is largely built on the [SAM2-Video-Predictor](https://huggingface.co/spaces/fffiloni/SAM2-Video-Predictor).
            1. **Upload your video** [MP4-24fps]
            2. With **'include' point type** selected, click on the object to mask on the first frame
            3. Switch to **'exclude' point type** if you want to specify an area to avoid
            4. **Get Mask!**
            5. **Check Propagation** every 15 frames
            6. **Propagate with "render"** to render the final masked video
            7. **Hit Reset** button if you want to refresh and start again

            *Note: Input video will be processed for up to 60 seconds only for demo purposes.*
            """
        )
        with gr.Row():

            with gr.Column():
                with gr.Group():
                    with gr.Group():
                        with gr.Row():
                            point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include", scale=2)
                            clear_points_btn = gr.Button("Clear Points", scale=1)

                    input_first_frame_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)

                    points_map = gr.Image(
                        label="Point n Click map",
                        type="filepath",
                        interactive=False
                    )

                    with gr.Group():
                        with gr.Row():
                            checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus"], value="tiny")
                            submit_btn = gr.Button("Get Mask", size="lg")

                with gr.Accordion("Your video IN", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Video IN", format="mp4")

                gr.HTML("""
                <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
                </a> to skip queue and avoid OOM errors from heavy public load
                """)

            with gr.Column():
                with gr.Group():
                    # with gr.Group():
                    # with gr.Row():
                    working_frame = gr.Dropdown(label="working frame ID", choices=["frame_0.jpg"], value="frame_0.jpg", visible=False, allow_custom_value=False, interactive=True)
                    #     change_current = gr.Button("change current", visible=False)
                    # working_frame = []
                    output_result = gr.Image(label="current working mask ref")
                with gr.Group():
                    with gr.Row():
                        vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
                        propagate_btn = gr.Button("Propagate", scale=2)

                reset_prpgt_brn = gr.Button("Reset", visible=False)
                output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
                output_video = gr.Video(visible=False)
                # output_result_mask = gr.Image()

    # When new video is uploaded
    video_in.upload(
        fn=preprocess_video_in,
        inputs=[video_in],
        outputs=[
            first_frame_path,
            tracking_points,  # update Tracking Points in the gr.State([]) object
            trackings_input_label,  # update Tracking Labels in the gr.State([]) object
            input_first_frame_image,  # hidden component used as ref when clearing points
            points_map,  # Image component where we add new tracking points
            video_frames_dir,  # Array where frames from video_in are deep stored
            scanned_frames,  # Scanned frames by SAM2
            stored_inference_state,  # Sam2 inference state
            stored_frame_names,
            video_in_drawer,  # Accordion to hide uploaded video player
        ],
        queue=False
    )

    # triggered when we click on image to add new points
    points_map.select(
        fn=get_point,
        inputs=[
            point_type,  # "include" or "exclude"
            tracking_points,  # get tracking_points values
            trackings_input_label,  # get tracking label values
            input_first_frame_image,  # gr.State() first frame path
        ],
        outputs=[
            tracking_points,  # updated with new points
            trackings_input_label,  # updated with corresponding labels
            points_map,  # updated image with points
        ],
        queue=False
    )

    # Clear every points clicked and added to the map
    clear_points_btn.click(
        fn=clear_points,
        inputs=input_first_frame_image,  # we get the untouched hidden image
        outputs=[
            first_frame_path,
            tracking_points,
            trackings_input_label,
            points_map,
            # stored_inference_state,
        ],
        queue=False
    )

    # change_current.click(
    #     fn = switch_working_frame,
    #     inputs = [working_frame, scanned_frames, video_frames_dir],
    #     outputs = [tracking_points, trackings_input_label, input_first_frame_image, points_map],
    #     queue=False
    # )

    submit_btn.click(
        fn=get_mask_sam_process,
        inputs=[
            stored_inference_state,
            input_first_frame_image,
            checkpoint,
            tracking_points,
            trackings_input_label,
            video_frames_dir,
            scanned_frames,
            working_frame,
            available_frames_to_check,
        ],
        outputs=[
            output_result,
            stored_frame_names,
            loaded_predictor,
            stored_inference_state,
            working_frame,
        ],
        queue=False
    )

    reset_prpgt_brn.click(
        fn=reset_propagation,
        inputs=[first_frame_path, loaded_predictor, stored_inference_state],
        outputs=[points_map, tracking_points, trackings_input_label, output_propagated, stored_inference_state, output_result, available_frames_to_check, input_first_frame_image, working_frame, reset_prpgt_brn],
        queue=False
    )

    propagate_btn.click(
        fn=update_ui,
        inputs=[vis_frame_type],
        outputs=[output_propagated, output_video],
        queue=False
    ).then(
        fn=propagate_to_all,
        inputs=[video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
        outputs=[output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn]
    )

demo.launch()
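Note that the script above does not feed every frame to SAM2Long. preprocess_video_in() writes only every 6th frame from the first 60 seconds, and the render step builds the video at original_fps // 6. The two hypothetical helpers below restate that arithmetic so you can predict how many frames a given clip will produce.

```python
def sampled_frame_indices(total_frames, fps, stride=6, max_seconds=60):
    """Frame indices that preprocess_video_in() saves: every `stride`-th frame
    among the first fps * max_seconds frames of the video."""
    max_frames = int(fps * max_seconds)
    return list(range(0, min(total_frames, max_frames), stride))


def render_fps(original_fps, stride=6):
    """Frame rate that propagate_to_all() passes to ImageSequenceClip in "render" mode."""
    return original_fps // stride


if __name__ == "__main__":
    frames = sampled_frame_indices(total_frames=11 * 24, fps=24)
    print(len(frames), frames[-1], render_fps(24))  # 44 258 4
```

For an 11-second clip at 24 fps, that means 44 of 264 frames are segmented and rendered at 4 fps. The output therefore keeps the original duration but plays less smoothly. This describes only the listing above.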

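Usage

Here is how to use it (Gradio UI version).

Upload the video you want to process.

Next, choose a point type: include marks a point on the object you want to keep, and exclude marks an area to leave out of the mask. The default, include, is usually what you want.

Then add key points by clicking on the image. Here I want to cut out the person, so I add a key point on the person's face.

Then click Get Mask below to preview the keying result first.

The preview on the right shows that the person's clothes and hair are not included, so more points are needed to refine the mask.

After adding two more points, check the result again.

This time the person's hair, body and clothes are all selected.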
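Each click becomes one point prompt. get_point() stores the clicked (x, y) with a label of 1 for include or 0 for exclude. get_mask_sam_process() then passes these to predictor.add_new_points() as arrays. The plain-Python sketch below shows that conversion; clicks_to_prompts() and the coordinates are hypothetical.

```python
def clicks_to_prompts(clicks):
    """Turn (x, y, point_type) clicks into SAM 2 point prompts.

    Follows the labelling in get_point(): "include" -> 1 (positive click),
    "exclude" -> 0 (negative click).
    """
    label_for = {"include": 1, "exclude": 0}
    points, labels = [], []
    for x, y, point_type in clicks:
        if point_type not in label_for:
            raise ValueError(f"unknown point type: {point_type!r}")
        points.append([float(x), float(y)])
        labels.append(label_for[point_type])
    return points, labels


if __name__ == "__main__":
    # Hypothetical clicks: the face first, then hair and clothes after the first preview
    clicks = [(320, 140, "include"), (300, 90, "include"), (330, 400, "include")]
    points, labels = clicks_to_prompts(clicks)
    print(points)  # [[320.0, 140.0], [300.0, 90.0], [330.0, 400.0]]
    print(labels)  # [1, 1, 1]
```

In the real script these lists become np.float32 and np.int32 arrays before the call. An exclude click simply adds a point with label 0, telling the model that location is background.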

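Then set the propagation type (Propagation level) to render. Choose the output type you need: greenscreen outputs a green-screen video, and transparent outputs a .mov video with an alpha channel. The Gradio source listed above has no output-type option. It renders an mp4 with the mask overlaid on the frames.

Once everything is selected, click generate and wait patiently.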
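Both output types reduce to simple per-pixel operations on a frame and its boolean mask. The NumPy sketch below assumes frame is an H×W×3 uint8 image and mask an H×W boolean array, for example one object's thresholded logits with the leading axis squeezed out. It is not the one-click package's code, and the helper names are hypothetical.

```python
import numpy as np

# Pure green; (0, 255, 0) is the same in RGB and in OpenCV's BGR channel order
GREEN = np.array([0, 255, 0], dtype=np.uint8)


def green_screen(frame, mask):
    """Keep the subject where mask is True and paint the background pure green."""
    return np.where(mask[..., None], frame, GREEN).astype(np.uint8)


def with_alpha(frame, mask):
    """Append the mask as an alpha channel: 255 (opaque) on the subject, 0 elsewhere."""
    alpha = mask.astype(np.uint8) * 255
    return np.dstack([frame, alpha])


if __name__ == "__main__":
    frame = np.full((2, 2, 3), 10, dtype=np.uint8)
    mask = np.array([[True, False], [False, False]])
    print(green_screen(frame, mask)[1, 1].tolist())  # [0, 255, 0]
    print(with_alpha(frame, mask).shape)             # (2, 2, 4)
```

Writing the RGBA frames into a .mov that preserves transparency additionally requires a codec with alpha support, which this sketch does not cover.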

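System requirements

Windows

An NVIDIA GPU with at least 8 GB of VRAM is required.

If you get errors when running it, install CUDA 12.8:

https://developer.nvidia.com/cuda-12-8-0-download-archive

RTX 50-series GPUs are supported.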
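To check the 8 GB requirement on your own machine, you can ask PyTorch for each CUDA device's total memory. This is a sketch with hypothetical helper names. The 5% tolerance is an assumption, because a card sold as 8 GB can report slightly less than 8 GiB.

```python
import importlib.util

GIB = 1024 ** 3


def meets_vram_requirement(total_bytes, required_gib=8, tolerance=0.05):
    """True if total_bytes reaches required_gib GiB, allowing an assumed 5% margin."""
    return total_bytes >= required_gib * GIB * (1 - tolerance)


def gpu_memory_bytes(device_index=0):
    """Total memory of a CUDA device as reported by PyTorch, or None without torch/CUDA."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_properties(device_index).total_memory


if __name__ == "__main__":
    total = gpu_memory_bytes()
    if total is None:
        print("No CUDA GPU visible to PyTorch")
    else:
        print(f"{total / GIB:.1f} GiB total, meets 8 GB requirement: {meets_vram_requirement(total)}")
```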

macOS

Not supported yet.

About speed

Test case: an 11-second video (24 fps), tiny model, green-screen output.

RTX 4090: about 60 seconds.

RTX 5090: about 45 seconds.
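These timings can be converted into a rough throughput figure. The conversion assumes that every frame of the clip is processed; the article does not state how the one-click package samples frames.

```latex
N = 11\,\text{s} \times 24\,\text{fps} = 264\ \text{frames}

\text{RTX 4090: } \frac{264}{60\,\text{s}} = 4.4\ \text{frames/s}, \qquad
\text{RTX 5090: } \frac{264}{45\,\text{s}} \approx 5.9\ \text{frames/s}

\frac{60\,\text{s}}{45\,\text{s}} \approx 1.33 \quad \text{(speed-up of the 5090 over the 4090)}
```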

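Getting the one-click package

👇🏻👇🏻👇🏻 Links below 👇🏻👇🏻👇🏻

Quark Drive:
https://pan.quark.cn/s/97f447f03235

Baidu Netdisk:
https://pan.baidu.com/s/1BFVfRCnnCdDOCpwDUTvcbQ?pwd=fbdf

Putting this together took real effort. If this article helped you, please give it a free like or a "Wow". Thanks for reading!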
