当前位置: 首页 > news >正文

【SAM2代码解析】数据集处理3--混合数据加载器(DataLoader)

在这里插入图片描述

前情提要—trainer

展示了在训练过程中,数据是如何流动的

1)trainer的初始化

trainer = instantiate(cfg.trainer, _recursive_=False)
  • 传入的参数:
    在这里插入图片描述

    • data:training.dataset.sam2_datasets.TorchTrainMixedDataset
      在这里插入图片描述
    • model: training.model.sam2.SAM2Train
    • checkpoint:…
    • mode: train_only
    • optim: torch.optim.AdamW
    • loss: training.loss_fns.MultiStepMultiMasksAndIous
  • 配置信息赋值给实例变量
    在这里插入图片描述

  • 初始化其他配置信息…

  • self._setup_dataloaders() 设置数据加载器

    • 根据mode的格式实例化train_dataset,如下图所示(接7.1-1)的初始化模块):
      在这里插入图片描述

2)trainer.run

dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))-----接7.1-2)的get_loader方法

7. sam2_dataset.py

7.1 MixedDataLoader 类​

1)初始化

  • 传入参数
    在这里插入图片描述
    这里的传入参数就是前面trainer初始化中,data的配置信息。
  • 初始化信息
    • 属性赋值
    • 设置数据集的周期(??没看懂这样的意义)
    • sam允许训练时使用多个数据集
      • 计算每个数据集的采样概率,概率为子数据集量/全部数据集量,若只有一个数据集,那么采样概率为1

2)get_loader

    def get_loader(self, epoch) -> Iterable:# 初始化数据加载器列表dataloaders = []# 遍历数据集和批次大小for d_idx, (dataset, batch_size) in enumerate(zip(self.datasets, self.batch_sizes)):# 处理每个数据集# 如果每个周期的阶段数 self.phases_per_epoch 大于 1,则处理数据集的分块和设置周期。if self.phases_per_epoch > 1:# Major epoch that looops over entire dataset# len(main_epoch) == phases_per_epoch * len(epoch)# 计算主周期和局部阶段main_epoch = epoch // self.phases_per_epoch# Phase with in the main epochlocal_phase = epoch % self.phases_per_epoch# Start of new data-epoch or job is resumed after preemtion.if local_phase == 0 or self.chunks[d_idx] is None:# set seed for dataset epoch# If using RepeatFactorWrapper, this step currectly re-samples indices before chunking.self._set_dataset_epoch(dataset, main_epoch)# Separate random generator for subset samplingg = torch.Generator()g.manual_seed(main_epoch)self.chunks[d_idx] = torch.chunk(torch.randperm(len(dataset), generator=g),self.phases_per_epoch,)dataset = Subset(dataset, self.chunks[d_idx][local_phase])# 如果是新的数据周期或工作在中断后恢复,则设置数据集的周期,并为数据集创建随机分块。else:self._set_dataset_epoch(dataset, epoch)# 创建DistributedSampler采样器,用于在分布式环境中对数据集进行采样sampler = DistributedSampler(dataset, shuffle=self.shuffle)sampler.set_epoch(epoch)# 创建 BatchSampler 对象,用于从采样器中按批次大小采样数据。batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)# 创建数据加载器dataloaders.append(DataLoader(dataset,num_workers=self.num_workers,pin_memory=self.pin_memory,batch_sampler=batch_sampler,collate_fn=self.collate_fn,worker_init_fn=self.worker_init_fn,))# 返回混合数据加载器return MixedDataLoader(dataloaders, self.dataset_prob)

相关文章:

  • 集成产品开发(IPD)核心框架:阶段门流程设计与跨部门协同实施要点
  • Linux 检查口令策略设置是否符合复杂度要求
  • kubernetes》》k8s》》Service 、Ingress 区别
  • Vue+tdesign t-input-number 设置长度和显示X号
  • [论文精读]Agent综述—— A survey on large language model based autonomous agents
  • Sigmoid函数简介及其Python实现
  • SQL命令二:SQL 高级查询与特殊算法
  • 《JDK 1.7 vs JDK 1.8 ConcurrentHashMap 深度对比与实战解析》
  • EWM 流程全自动化实现方法
  • MySQL explain
  • 《可信数据空间 技术架构》技术文件正式发布
  • Gas 优化不足、升级机制缺陷问题
  • 【区块链】区块链技术介绍
  • 『深夜_MySQL』详解数据库 探索数据库是如何存储的
  • MySQL 中的索引数量是否越多越好?为什么?
  • 华为发布全球首个L3商用智驾ADS4.0
  • vue+django农产品价格预测和推荐可视化系统[带知识图谱]
  • DeepSeek最新大模型发布-DeepSeek-Prover-V2-671B
  • harmonyOS 手机,双折叠,平板,PC端屏幕适配
  • 分布式链路ID实现
  • 解放日报:人形机器人新赛道正积蓄澎湃动能
  • 周劼已任中国航天科技集团有限公司董事、总经理、党组副书记
  • 国务院任免国家工作人员:颜清辉任人社部副部长
  • 上海科创再出发:“造星”的城和“摘星”的人
  • 国家能源局通报上月投诉情况:赤峰有群众反映电费异常增高,已退费
  • 华夏幸福:去年营业收入237.65亿元,同比减亏12亿元