当前位置: 首页 > news >正文

浅分析 PE3R 感知高效的三维重建

"近期,二维到三维感知技术的进步显著提升了对二维图像中三维场景的理解能力。然而,现有方法面临诸多关键挑战,包括跨场景泛化能力有限、感知精度欠佳以及重建速度缓慢。为克服这些局限,我们提出了感知高效三维重建框架(PE3R),旨在同时提升准确性与效率。PE3R采用前馈架构,实现了快速的三维语义场重建。该框架在多样化的场景与对象上展现出强大的零样本泛化能力,并显著提高了重建速度。在二维到三维开放词汇分割及三维重建上的大量实验验证了PE3R的有效性与多功能性。" PE3R的作者这样写

代码开源在 

GitHub - hujiecpp/PE3R: PE3R: Perception-Efficient 3D Reconstruction. Take 2 - 3 photos with your phone, upload them, wait a few minutes, and then start exploring your 3D world via text!PE3R: Perception-Efficient 3D Reconstruction. Take 2 - 3 photos with your phone, upload them, wait a few minutes, and then start exploring your 3D world via text! - hujiecpp/PE3Rhttps://github.com/hujiecpp/PE3R

论文地址

https://arxiv.org/abs/2503.07507https://t.co/ec3NSH0KoN

简单的梳理下论文背景和成果,后面会从代码分析模型结构

背景

PE3R 诞生的背景是现有方法如NeRF和3DGS依赖于场景特定的训练和语义提取,计算开销大,限制了实际应用的可扩展性。

研究空白

  • 现有方法在多场景泛化、感知精度和重建速度方面表现不佳。
  • 缺乏一种能够在不依赖3D数据的情况下高效进行3D语义重建的框架。

核心贡献

  • 提出了PE3R(Perception-Efficient 3D Reconstruction)框架,用于高效且准确的3D语义重建。
  • 通过仅使用2D图像实现3D场景重建,无需额外的3D数据(如相机参数或深度信息)。

技术架构

  • PE3R框架包含三个关键模块:像素嵌入消歧、语义场重建和全局视角感知。
  • 通过前馈机制实现快速的3D语义重建。

实现细节

  • 像素嵌入消歧模块通过跨视角、多层次的语义信息解决像素级别的语义歧义。
  • 语义场重建模块将语义信息直接嵌入到重建过程中,提升重建精度。
  • 全局视角感知模块通过全局语义对齐,减少单视角引入的噪声。

一个意想不到的细节

除了基本的 3D 重建,PE3R 还支持基于文本的查询功能,允许用户通过描述选择特定的 3D 对象,这在传统 3D 重建系统中并不常见。

结论

  • PE3R框架通过高效的3D语义重建,显著提升了2D到3D感知的速度和精度。
  • 该框架在不依赖场景特定训练或预校准3D数据的情况下,实现了零样本泛化,具有广泛的实际应用潜力。

下载代码安装必要依赖后,运行pe3r

上传官方测试用的 4 张测试图片

点击‘reconstruct’  后日志输出

渲染glb

查询Chair

尝试本地图片3D构建

构建结构

查找花盆

代码分析

.\PE3R\modules\pe3r\models.py 展示的模型结构

sys.path.append(os.path.abspath('./modules/ultralytics'))

from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
from modules.mast3r.model import AsymmetricMASt3R

# from modules.sam2.build_sam import build_sam2_video_predictor
from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
from modules.mobilesamv2 import sam_model_registry

from sam2.sam2_video_predictor import SAM2VideoPredictor

class Models:
    def __init__(self, device):
        # -- mast3r --
        # MAST3R_CKP = './checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
        MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
        self.mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)

        # -- sam2 --
        self.sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)

        # -- mobilesamv2 & sam1 --
        SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
        self.mobilesamv2 = sam_model_registry['sam_vit_h'](None)
        # image_encoder=sam_model_registry['sam_vit_h_encoder'](SAM1_ENCODER_CKP)
        sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
        image_encoder = sam1.vision_encoder

        prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
        self.mobilesamv2.prompt_encoder = prompt_encoder
        self.mobilesamv2.mask_decoder = mask_decoder
        self.mobilesamv2.image_encoder=image_encoder
        self.mobilesamv2.to(device=device)
        self.mobilesamv2.eval()

        # -- yolov8 --
        YOLO8_CKP='./checkpoints/ObjectAwareModel.pt'
        self.yolov8 = ObjectAwareModel(YOLO8_CKP)

        # -- siglip --
        self.siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
        self.siglip_tokenizer = AutoTokenizer.from_pretrained("google/siglip-large-patch16-256")
        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")

模型结构中的 Models 类初始化了多个组件,如 MASt3R、SAM2、MobileSAMv2、SAM1、YOLOv8 和 Siglip,共同处理复杂的场景。MASt3R是从checkpoint加载的非对称模型,用于3D重建或匹配。YOLOv8用于对象检测,Siglip则用于图像和文本特征提取。看来这个类整合了多个尖端模型,分别处理不同任务。

Models 类初始化了以下关键组件:

  • MASt3R:用于多视图立体视觉,估计图像对之间的深度和姿态。
  • 加载checkpoint(如 naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric),是一个多视图立体视觉模型。
  • 功能:估计图像对之间的深度和姿态,核心用于 3D 重建。
  • SAM2:视频分割模型,用于在图像序列中传播分割掩码。
  • 从预训练模型 facebook/sam2.1-hiera-large 加载,是视频分割预测器。
  • 功能:在图像序列中传播分割掩码,确保对象在不同帧中的一致性。
  • MobileSAMv2 和 SAM1:用于基于对象检测的精确图像分割。
  • MobileSAMv2 使用自定义掩码解码器(如 Prompt_guided_Mask_Decoder.pt)初始化,结合 SAM1 的视觉编码器。
  • 功能:基于 YOLOv8 的对象检测结果,进行精确的图像分割。
  • YOLOv8:对象检测模型,识别图像中的潜在对象。
  • checkpoint ObjectAwareModel.pt 加载,是 Ultralytics 的对象检测模型。
  • 功能:识别图像中的潜在对象,提供边界框供后续分割使用。
  • Siglip:从图像片段提取特征,支持基于文本的查询。
  • 从 google/siglip-large-patch16-256 加载,支持图像和文本特征提取。
  • 功能:从分割后的图像片段提取特征,支持基于文本的查询。

对比表:组件功能与作用

组件主要功能在 3D 重建中的作用
MASt3R多视图立体视觉估计深度和姿态,核心重建步骤
SAM2视频分割传播确保跨视图分割一致性
MobileSAMv2图像分割基于检测生成精确掩码
SAM1图像分割辅助提供初始帧的分割掩码
YOLOv8对象检测提供边界框,启动分割流程
Siglip特征提取支持对象分组和文本查询

技术细节与优势

  • 分割的准确性:结合 YOLOv8、SAM1 和 SAM2,确保对象在不同视图中的一致分割。
  • 特征的鲁棒性:Siglip 的特征提取支持跨视图的对象分组,SLERP 处理重叠掩码增强一致性。
  • 全局优化的复杂性:使用图优化技术(如最小生成树初始化)确保 3D 点的准确对齐。

demo.py 中的 3D 重建流程

demo.py 中的工作流程利用这些组件进行图像分割、特征提取和全局对齐,从而生成高质量的 3D 模型。demo.py文件使用Models类从一组图像进行3D重建

get_reconstructed_scene函数是核心过程。
流程包括使用YOLOv8检测图像中的对象,然后用SAM1和SAM2进行分割。MASt3R用于多视图立体重建,获取深度和姿态。
get_cog_feats函数使用SAM2初始化状态,并通过视频传播分割掩码。每个帧的掩码被裁剪、调整大小后通过Siglip提取特征。3D重建的魔力在于结合对象检测、分割、特征提取和多视图立体技术。全局优化确保所有视图一致对齐。

    • 图像加载与准备
      • 使用 Images 类加载输入图像列表,准备用于后续处理。
      • 如果图像少于 2 张,抛出错误,确保有足够视图进行重建。
    • 对象检测与分割
      • YOLOv8 检测:对图像运行 YOLOv8,获取对象边界框,设置置信度阈值为 0.25,IOU 阈值为 0.95。
      • SAM1 分割:基于 YOLOv8 的边界框,使用 MobileSAMv2 和自定义掩码解码器生成精确的分割掩码。
      • SAM2 传播:初始化 SAM2 状态,使用第一帧的 SAM1 掩码,之后通过 propagate_in_video 在序列中传播掩码。
      • NMS 过滤:使用非最大抑制(NMS)过滤重叠掩码,确保分割结果的唯一性。
    • 特征提取
      • 在 get_cog_feats 函数中,对每个帧的每个分割掩码:
        • 裁剪对应区域,填充为正方形,调整大小为 256x256。
        • 使用 Siglip 处理这些图像片段,提取特征向量。
      • 对于重叠的掩码,进行球面线性插值(SLERP)以合并特征,确保特征的一致性。
      • 最终生成 multi_view_clip_feats,每个对象 ID 对应一个特征向量,跨视图平均。
    • 多视图立体视觉
      • 使用 make_pairs 函数根据场景图类型(complete、swin 或 oneref)生成图像对。
      • 运行 MASt3R 推理,估计每对图像的深度和姿态,输出匹配和深度图。
    • 全局对齐
      • 使用 global_aligner 优化相机姿态和 3D 点:
        • 模式为 PointCloudOptimizer(多于 2 张图像)或 PairViewer(2 张图像)。
        • 利用分割图(cog_seg_maps 和 rev_cog_seg_maps)和特征(cog_feats)指导对齐。
      • 优化过程包括多次迭代(默认 300 次),使用线性或余弦调度调整学习率。
    • 3D 模型生成
      • 使用 get_3D_model_from_scene 将对齐后的点云或网格导出为 GLB 文件。
      • 支持选项如点云显示、天空掩码、深度清理和相机透明度。

    get_reconstructed_scene 方法分析

    其中 get_reconstructed_scene 函数展示了 3D 重建的详细步骤:

    def get_reconstructed_scene(outdir, pe3r, device, silent, filelist, schedule, niter, min_conf_thr,
                                as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                                scenegraph_type, winsize, refid):
        """
        from a list of images, run dust3r inference, global aligner.
        then run get_3D_model_from_scene
        """
        if len(filelist) < 2:
            raise gradio.Error("Please input at least 2 images.")
    
        images = Images(filelist=filelist, device=device)
        
        # try:
        cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, pe3r)
        imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
        # except Exception as e:
        #     rev_cog_seg_maps = []
        #     for tmp_img in images.np_images:
        #         rev_seg_map = -np.ones(tmp_img.shape[:2], dtype=np.int64)
        #         rev_cog_seg_maps.append(rev_seg_map)
        #     cog_seg_maps = rev_cog_seg_maps
        #     cog_feats = torch.zeros((1, 1024))
        #     imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
    
        if len(imgs) == 1:
            imgs = [imgs[0], copy.deepcopy(imgs[0])]
            imgs[1]['idx'] = 1
    
        if scenegraph_type == "swin":
            scenegraph_type = scenegraph_type + "-" + str(winsize)
        elif scenegraph_type == "oneref":
            scenegraph_type = scenegraph_type + "-" + str(refid)
    
        pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
        output = inference(pairs, pe3r.mast3r, device, batch_size=1, verbose=not silent)
        mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
        scene_1 = global_aligner(output, cog_seg_maps, rev_cog_seg_maps, cog_feats, device=device, mode=mode, verbose=not silent)
        lr = 0.01
        # if mode == GlobalAlignerMode.PointCloudOptimizer:
        loss = scene_1.compute_global_alignment(tune_flg=True, init='mst', niter=niter, schedule=schedule, lr=lr)
    
        try:
            import torchvision.transforms as tvf
            ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
            for i in range(len(imgs)):
                # print(imgs[i]['img'].shape, scene.imgs[i].shape, ImgNorm(scene.imgs[i])[None])
                imgs[i]['img'] = ImgNorm(scene_1.imgs[i])[None]
            pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
            output = inference(pairs, pe3r.mast3r, device, batch_size=1, verbose=not silent)
            mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
            scene = global_aligner(output, cog_seg_maps, rev_cog_seg_maps, cog_feats, device=device, mode=mode, verbose=not silent)
            ori_imgs = scene.ori_imgs
            lr = 0.01
            # if mode == GlobalAlignerMode.PointCloudOptimizer:
            loss = scene.compute_global_alignment(tune_flg=False, init='mst', niter=niter, schedule=schedule, lr=lr)
        except Exception as e:
            scene = scene_1
            scene.imgs = ori_imgs
            scene.ori_imgs = ori_imgs
            print(e)
    
    
        outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
                                          clean_depth, transparent_cams, cam_size)
    
        # also return rgb, depth and confidence imgs
        # depth is normalized with the max value for all images
        # we apply the jet colormap on the confidence maps
        rgbimg = scene.imgs
        depths = to_numpy(scene.get_depthmaps())
        confs = to_numpy([c for c in scene.im_conf])
        # confs = to_numpy([c for c in scene.conf_2])
        cmap = pl.get_cmap('jet')
        depths_max = max([d.max() for d in depths])
        depths = [d / depths_max for d in depths]
        confs_max = max([d.max() for d in confs])
        confs = [cmap(d / confs_max) for d in confs]
    
        imgs = []
        for i in range(len(rgbimg)):
            imgs.append(rgbimg[i])
            imgs.append(rgb(depths[i]))
            imgs.append(rgb(confs[i]))
    
        return scene, outfile, imgs

    函数输入与初始检查

    get_reconstructed_scene 函数接受多个参数,包括输出目录、模型实例、设备、静默模式、图像文件列表、调度方式、迭代次数、最小置信度阈值等。至少包含 2 张图像确保了有足够的多视图信息进行重建。

    图像加载与准备

    函数创建 Images 对象,加载输入图像,并调用 get_cog_feats 获取认知分割图(cog_seg_maps)、反向认知分割图(rev_cog_seg_maps)和认知特征(cog_feats)。这些步骤涉及对象检测和分割:

    • 对象检测:使用 YOLOv8 检测图像中的对象,提供边界框,置信度阈值为 0.25,IOU 阈值为 0.95。
    • 图像分割:通过 MobileSAMv2 和 SAM1 生成精确的分割掩码,结合 SAM2 在序列中传播这些掩码,确保跨视图的一致性。
    • 特征提取:对每个分割区域裁剪、填充为正方形(256x256),使用 Siglip 提取特征向量。对于重叠掩码,通过球面线性插值(SLERP)合并特征。

    如果 get_cog_feats 失败,函数会回退到默认分割图(全为 -1)和零特征向量,确保流程继续。

    单图像处理

    如果输入仅有一张图像,函数会复制该图像生成两张,确保可以进行对齐和重建。这是为了处理边缘情况,维持多视图立体视觉的必要性。

    场景图配置

    根据 scenegraph_type(complete、swin 或 oneref),函数调整场景图参数:

    • 如果为 “swin”,追加窗口大小;如果为 “oneref”,追加参考 ID。这些参数影响后续图像对的生成。

    图像对生成与 MASt3R 推理

    使用 make_pairs 函数根据场景图生成图像对,参数包括 scene_graph 类型、对称化(symmetrize=True)等。然后调用 inference 函数,使用 pe3r.mast3r(MASt3R 模型)估计每对图像的深度和姿态:

    • MASt3R 是一个多视图立体视觉模型,核心功能是生成 3D 点云和相机姿态。
    • 推理过程批次大小为 1,是否显示详细信息由 silent 控制。

    全局对齐优化

    函数使用 global_aligner 进行全局对齐,模式根据图像数量选择:

    • 如果图像多于 2 张,使用 PointCloudOptimizer 模式;否则使用 PairViewer 模式。
    • 对齐过程利用 cog_seg_maps、rev_cog_seg_maps 和 cog_feats,这些分割和特征信息指导 3D 点的分组和优化。
    • 调用 compute_global_alignment 进行优化,初始化为最小生成树(init='mst'),迭代次数为 niter(默认 300),学习率 lr=0.01,调度方式为 schedule(线性或余弦)。

    二次推理与对齐

    函数尝试二次优化:

    • 导入 torchvision 进行图像归一化(均值为 0.5,标准差为 0.5)。
    • 更新 imgs 中的图像数据,重复图像对生成、MASt3R 推理和全局对齐步骤。
    • 如果失败,回退到第一次对齐结果,保持 scene_1 并更新图像数据。

    3D 模型生成

    调用 get_3D_model_from_scene 生成最终 3D 模型:

    • 提取场景中的 RGB 图像、3D 点、掩码、焦距和相机姿态。
    • 支持后处理选项,如清理点云(clean_depth)、掩盖天空(mask_sky)。
    • 将结果导出为 GLB 文件,支持点云(as_pointcloud=True)或网格显示,相机大小由 cam_size 控制。

    返回结果

    函数返回场景对象、输出 GLB 文件路径,以及一组图像数组:

    • 包括 RGB 图像、归一化深度图(以最大值归一化)和置信度图(使用 jet 颜色映射)。
    • 这些图像用于可视化,深度和置信度图帮助用户评估重建质量。

    技术细节与优势

    • 分割的准确性:YOLOv8 提供初始检测,MobileSAMv2 和 SAM2 确保精确且一致的分割,减少背景噪声。
    • 特征的鲁棒性:Siglip 提取的特征支持跨视图对象分组,SLERP 处理重叠掩码增强一致性。
    • 全局优化的复杂性:使用图优化技术(如最小生成树初始化)确保 3D 点的准确对齐,迭代优化提升精度。

    下面分析就不一一展示代码实现了,有兴趣的可以直接下载代码对照 demo.py列举的分析查看

    除了 get_reconstructed_scene,其他方法如 _convert_scene_output_to_glb、get_3D_model_from_scene 等在后处理、分割、特征提取和用户交互中扮演关键角色。

    其他方法分析

    mask_to_box

    • 功能:从掩码生成边界框(左、上、右、下)。
    • 作用
      • 将分割掩码转换为边界框格式,便于后续裁剪和特征提取。
    • 关键步骤
      • 计算掩码中非零值的边界,生成 [left, top, right, bottom]。
      • 如果掩码为空,返回零边界框。
    • 为什么重要:为图像裁剪提供定位信息。

    pad_img

    • 功能:将图像填充为正方形,保持宽高比。
    • 作用
      • 标准化图像尺寸,适配 Siglip 的输入要求(256x256)。
    • 关键步骤
      • 创建最大边长的零矩阵,将图像居中填充。
    • 为什么重要:确保特征提取输入一致。

    get_cog_feats

    • 功能:提取图像序列的分割图和特征。
    • 作用
      • 生成认知分割图(cog_seg_maps)、反向分割图(rev_cog_seg_maps)和多视图特征(cog_feats)。
    • 关键步骤
      • 使用 SAM2 传播掩码,结合 SAM1 添加新掩码。
      • 对每帧分割区域裁剪、填充,提取 Siglip 特征。
      • 使用 SLERP 合并重叠特征,生成多视图特征。
    • 为什么重要:为全局对齐提供对象级信息。

    set_scenegraph_options

    • 功能:根据场景图类型调整 UI 参数。
    • 作用
      • 配置滑动窗口(swin)或参考帧(oneref)的参数。
    • 关键步骤
      • 根据图像数量动态设置窗口大小和参考 ID。
    • 为什么重要:优化图像对生成策略。

    get_mask_from_img_sam1

    • 功能:使用 YOLOv8 和 MobileSAMv2 从图像生成分割掩码。
    • 作用
      • 提供初始帧的精确分割结果。
    • 关键步骤
      • YOLOv8 检测边界框,MobileSAMv2 生成掩码。
      • 分批处理(每批 320 个),过滤小面积掩码(<0.2%),应用 NMS。
    • 为什么重要:为 SAM2 提供初始掩码,支持跨帧传播。

    get_seg_img

    • 功能:根据掩码和边界框从图像中裁剪分割区域。
    • 作用
      • 提取特定对象的图像片段,用于特征提取。
    • 关键步骤
      • 根据掩码面积与边界框面积的比率,决定背景填充方式(黑色或随机噪声)。
      • 裁剪图像,返回分割区域。
    • 为什么重要:为 Siglip 特征提取准备输入。

    mask_nms

    • 功能:对一组分割掩码执行非最大抑制(NMS),去除重叠掩码。
    • 作用
      • 确保每个对象只保留一个主要掩码,避免重复分割。
    • 关键步骤
      • 计算掩码之间的交集占比(IOU),如果超过阈值(默认 0.8),抑制较小的掩码。
      • 返回保留的掩码索引列表。
    • 为什么重要:在对象分割中防止冗余,提高后续特征提取和对齐的准确性。
    • 示例:从椅子照片中检测到多个重叠掩码,mask_nms 保留最大的一个。

     _convert_scene_output_to_glb

    • 功能:将重建的场景数据(RGB 图像、3D 点、掩码、焦距和相机姿态)转换为 GLB 文件格式,用于 3D 模型的导出和可视化。
    • 作用
      • 将 3D 点云或网格与颜色信息结合,生成可视化的 3D 模型。
      • 添加相机位置和方向,便于理解拍摄视角。
      • 支持点云(as_pointcloud=True)或网格(as_pointcloud=False)两种表示方式。
    • 关键步骤
      • 如果选择点云模式,合并所有点的坐标和颜色,创建 trimesh.PointCloud。
      • 如果选择网格模式,逐帧生成网格并合并为单一 trimesh.Trimesh。
      • 使用 trimesh.Scene 添加相机,应用坐标变换(绕 Y 轴旋转 180 度),导出为 GLB 文件。

    这是 3D 重建的最终输出步骤,将内部数据结构转换为用户可交互的格式。

    综上这些方法共同支持 PE3R 的 3D 重建流程:

    • 分割与特征提取:get_mask_from_img_sam1、get_cog_feats 等提供对象级信息。
    • 模型生成:_convert_scene_output_to_glb、get_3D_model_from_scene 完成 3D 输出。

    上面代码代码种涉及的两个模型,作者只提供了pt文件并没有提供训练方式

    下载地址在 https://github.com/hujiecpp/PE3R/releases/tag/checkpoints.

    再多聊几句这个 PromptModelPredictor(DetectionPredictor)

    PromptModelPredictor 类是基于 Ultralytics YOLO 框架实现的自定义检测预测器,定义在 PE3R 项目的某个模块中(可能与 ObjectAwareModel 相关)。它通过继承 DetectionPredictor,并重写 __init__、adjust_bboxes_to_image_border 和 postprocess 方法,实现了特定的对象检测和边界框处理功能。

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = 'segment'

    初始化预测器,设置任务类型为“分割”。后续检测任务奠定基础,确保与 YOLO 框架兼容。

    def adjust_bboxes_to_image_border(self, boxes, image_shape, threshold=20):    
        h, w = image_shape
        boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor(0, dtype=torch.float, device=boxes.device), boxes[:, 0])  # x1
        boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor(0, dtype=torch.float, device=boxes.device), boxes[:, 1])  # y1
        boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor(w, dtype=torch.float, device=boxes.device), boxes[:, 2])  # x2
        boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor(h, dtype=torch.float, device=boxes.device), boxes[:, 3])  # y2
        return boxes
    • 调整边界框坐标,确保其不超出图像边界并避免过于靠近边缘。
    • 作用:将靠近图像边缘(小于 threshold=20 像素)的坐标设置为边界值(0 或图像宽高)。
      • boxes[:, 0](x1):如果小于 20,设为 0。
      • boxes[:, 1](y1):如果小于 20,设为 0。
      • boxes[:, 2](x2):如果大于宽度-20,设为宽度。
      • boxes[:, 3](y2):如果大于高度-20,设为高度。
    • 为什么重要:防止边界框超出图像范围或过于贴近边缘,确保后续分割或特征提取的有效性。
    def postprocess(self, preds, img, orig_imgs):
        p = ops.non_max_suppression(preds[0], self.args.conf, self.args.iou, agnostic=self.args.agnostic_nms, max_det=self.args.max_det, nc=len(self.model.names), classes=self.args.classes)
        results = []
        if len(p) == 0 or len(p[0]) == 0:
            print("No object detected.")
            return results
        full_box = torch.zeros_like(p[0][0])
        full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
        full_box = full_box.view(1, -1)
        self.adjust_bboxes_to_image_border(p[0][:, :4], img.shape[2:]) 
        for i, pred in enumerate(p):
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            path = self.batch[0]
            img_path = path[i] if isinstance(path, list) else path
            if not len(pred): 
                results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
                continue
            if self.args.retina_masks:
                if not isinstance(orig_imgs, torch.Tensor):
                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            else:
                if not isinstance(orig_imgs, torch.Tensor):
                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(
                Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=torch.zeros_like(img)))
        return results
    • 对 YOLO 模型的预测结果进行后处理,返回检测结果。
    • 作用
      • 应用非最大抑制(NMS)过滤重叠边界框。
      • 调整边界框坐标,适配原始图像尺寸。
      • 返回 Results 对象,包含边界框信息(未生成实际掩码)。
    • 关键步骤
      • NMS:使用 ops.non_max_suppression 过滤预测结果,基于置信度(conf)、IOU(iou)和最大检测数(max_det)。
      • 边界检查:如果没有检测到对象,返回空结果并打印提示。
      • 边界框调整:调用 adjust_bboxes_to_image_border 修正坐标。
      • 坐标缩放:使用 ops.scale_boxes 将边界框从输入图像尺寸缩放到原始图像尺寸。
      • 结果封装:创建 Results 对象,包含图像、路径、类别名称和边界框(boxes),掩码默认为零。
    • 将原始检测结果转换为标准格式,为后续 SAM 分割提供输入。

    PE3R 的上下文中,PromptModelPredictor 可能被实例化为 Models 类中的 yolov8 组件(ObjectAwareModel)。其作用包括:

    • 提供初始检测:为 get_mask_from_img_sam1 提供边界框,启动精确分割流程。
    • 支持多视图一致性:通过检测对象位置,帮助 SAM2 在图像序列中传播掩码。
    • 集成到 3D 重建:边界框信息间接支持特征提取和全局对齐。

    http://www.dtcms.com/a/80703.html

    相关文章:

  1. LeetCode[242]有效的字母异位词
  2. 【Linux】Windows 客户端访问 Linux 服务器
  3. 解释什么是受控组件和非受控组件
  4. VSTO(C#)Excel开发11:自定义任务窗格与多个工作簿
  5. Chapter 4-15. Troubleshooting Congestion in Fibre Channel Fabrics
  6. 游戏盾是什么?如何为在线游戏保驾护航?
  7. 【Qt】QWidget属性2
  8. FastAPI WebSocket 无法获取真实 IP 错误记录
  9. Redis 跳表原理详解
  10. CSV文件格式
  11. 深度学习中的“刹车”:正则化如何防止模型“超速”
  12. 用Promise实现ajax的自动重试
  13. 【uniapp】记录tabBar不显示踩坑记录
  14. 大数据学习(75)-大数据组件总结
  15. S32K144外设实验(三):ADC单通道连续采样(中断)
  16. Android第五次面试总结(网络篇)
  17. Linux上位机开发实战(camera视频读取)
  18. 【DeepSeek 学C+】effective modern c+ 条款七 初始化
  19. 【c++】【STL】unordered_set 底层实现(简略版)
  20. k8s 配置imagePullSecrets仓库认证
  21. SpringMVC全局异常处理机制
  22. Android SDK下载安装配置
  23. 多无人车协同探索开源包启动文件介绍(下)
  24. 【FAQ】HarmonyOS SDK 闭源开放能力 —Push Kit(10)
  25. LVGL和其他图形库区别于联系
  26. Spring Boot Actuator 自定义健康检查(附Demo)
  27. AI安全、大模型安全研究(DeepSeek)
  28. 3. 轴指令(omron 机器自动化控制器)——>MC_SetPosition
  29. Python(数据结构概念,算法时间效率衡量,链表)
  30. Oracle GoldenGate (OGG) 安装、使用及常见故障处理