当前位置: 首页 > news >正文

CBGSDataset类-带类别平衡采样的数据集封装器

class CBGSDataset(object):
    """带类别平衡采样的数据集封装器,实现论文《Class-balanced Grouping and Sampling for Point Cloud 3D Object Detection》
    (https://arxiv.org/abs/1908.09492) 提出的方法。

    通过类别平衡采样策略平衡不同类别场景的数量。

    参数:
        dataset (:obj:`CustomDataset`): 需要进行类别平衡采样的原始数据集。
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.CLASSES = dataset.CLASSES  # 获取原始数据集的类别列表
        self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}  # 构建类别名称到ID的映射
        self.sample_indices = self._get_sample_indices()  # 计算平衡采样后的索引
        
        # 如果原始数据集有flag属性(如训练/验证标记),则继承该属性
        if hasattr(self.dataset, 'flag'):
            self.flag = np.array(
                [self.dataset.flag[ind] for ind in self.sample_indices],
                dtype=np.uint8)

    def _get_sample_indices(self):
        """通过类别平衡策略生成采样索引列表。

        返回:
            list[int]: 平衡采样后的样本索引列表
        """
        # 初始化字典:记录每个类别对应的所有样本索引
        class_sample_idxs = {cat_id: [] for cat_id in self.cat2id.values()}
        
        # 遍历数据集,统计每个类别的样本索引
        for idx in range(len(self.dataset)):
            sample_cat_ids = self.dataset.get_cat_ids(idx)  # 获取当前样本包含的类别ID
            for cat_id in sample_cat_ids:
                class_sample_idxs[cat_id].append(idx)
        
        # 计算总样本数(考虑一个样本可能属于多个类别)
        duplicated_samples = sum([len(v) for _, v in class_sample_idxs.items()])
        
        # 计算当前每个类别的分布比例
        class_distribution = {
            k: len(v) / duplicated_samples
            for k, v in class_sample_idxs.items()
        }

        sample_indices = []
        frac = 1.0 / len(self.CLASSES)  # 目标分布:均匀分布
        
        # 计算每个类别的采样比率
        ratios = [frac / v for v in class_distribution.values()]
        
        # 按比率进行随机采样
        for cls_inds, ratio in zip(list(class_sample_idxs.values()), ratios):
            sample_indices += np.random.choice(
                cls_inds, 
                int(len(cls_inds) * ratio),  # 按比率调整采样数量
                replace=True  # 允许重复采样(用于过采样)
            ).tolist()
            
        return sample_indices

核心功能说明

1. 问题背景
  • 在3D点云目标检测任务中(如自动驾驶场景),数据通常存在严重的类别不平衡问题(例如"汽车"样本远多于"行人")。

  • 直接训练会导致模型对高频类别过拟合,低频类别检测效果差。

2. 解决方案
  • 过采样(Oversampling):对稀有类别(如行人)的样本进行重复采样。

  • 欠采样(Undersampling):对高频类别(如汽车)的样本进行随机丢弃。

  • 最终使每个类别的样本贡献度相等

3. 算法关键步骤
  1. 统计原始分布

    • 遍历数据集,记录每个类别出现的所有样本索引。

  2. 计算平衡比率

    • 目标分布:若共有N个类别,则每个类别占比应为1/N

    • 对每个类别计算采样比率:ratio = (目标比例) / (当前比例)

  3. 执行重采样

    • 使用np.random.choice按比率随机选择样本,允许重复(replace=True)。


使用示例

假设原始数据分布:

  • 汽车:1000个样本

  • 行人:100个样本

  • 自行车:50个样本

经过CBGSDataset处理后:

  • 每个类别的目标比例:33.3%(3个类别)

  • 重采样后每个类别约383个样本(通过过采样/欠采样实现)


与FastBEV的关系

  • 在BEV(鸟瞰图)感知任务中,类别平衡能显著提升小物体检测效果(如行人、自行车)。

  • 可配合FastBEV的多相机特征融合模块使用,改善3D检测性能。

相关文章:

  • C++-FFmpeg-(5)-1-ffmpeg原理-ffmpeg编码接口-AVFrame-AVPacket-最简单demo
  • 有一个变量 在有些线程没有加锁 有些线程加锁了,那我在这些加锁的线程中能起到对应的作用吗
  • openEuler24.03 LTS下安装Spark
  • 使用 Google ML Kit 实现图片文字识别(提取美国驾照信息)
  • 爬虫抓包工具和PyExeJs模块
  • 领域大模型
  • flink iceberg写数据到hdfs,hive同步读取
  • 【C++游戏引擎开发】数学计算库GLM(线性代数)、CGAL(几何计算)的安装与使用指南
  • 【AI学习】AI Agent(人工智能体)
  • 蓝桥杯 C/C++ 组历届真题合集速刷(一)
  • GeoGPT:重新定义地理信息智能的下一代AI助手
  • 用PointNet++训练自己的数据集(语义分割模型semseg)
  • WEB安全--XSS--DOM破坏
  • 优选算法第八讲:链表
  • HOW - 如何测试 React 代码
  • unity urp 分层调酒思路解析
  • Nacos 服务发现的流程是怎样的?客户端如何获取最新的服务实例列表?
  • 鸿蒙开发_ARKTS快速入门_语法说明_渲染控制---纯血鸿蒙HarmonyOS5.0工作笔记012
  • 【JavaScript】十六、事件捕获和事件冒泡
  • TIM定时器
  • 临港新片区:发布再保险、国际航运、生物医药3个领域数据出境操作指引
  • 洞天寻隐·学林纪丨玉洞桃源:仇英青绿山水画中的洞天与身体
  • 苹果用户,安卓来“偷心”
  • 视觉周刊|劳动开创未来
  • 长线游、县域游、主题游等持续升温,假期文旅市场供需两旺
  • “五一”假期第四天,全社会跨区域人员流动量预计超2.7亿人次