当前位置: 首页 > news >正文

【深度学习】自定义实现DataSet和DataLoader

dataset数据集

作用:

  • 存储数据集的信息
  • 获取数据集长度 __len__
  • 获取数据集某特定条目的内容 __getitem__

dataloader 数据加载器

作用:

  • 从数据集中随机加载数据, 并拼接为一个 batch
  • 实现迭代器, 可以使用时, 迭代获取数据内容

代码实现:

import numpy as np
class ImageDataset():
    def __init__(self, raw_data):
        """
        数据集初始化
        """
        self.raw_data = raw_data
    
    def __len__(self):
        """
        返回数据集的长度
        """
        return len(self.raw_data)
    
    def __getitem__(self, index):
        """
        根据索引获取数据集中某一条数据
        """
        image, label = self.raw_data[index]
        return image, label

class  DataLoader():
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        
    def __iter__(self):
        self.indexes = np.arange(len(self.dataset))
        self.cursor = 0
        np.random.shuffle(self.indexes)
        return self

    def __next__(self):
        # 计算起始索引和终止索引
        begin = self.cursor
        end = self.cursor + self.batch_size
    
        # 若超出范围,抛出停止迭代异常
        if end > len(self.dataset):
            raise StopIteration
        
        # 更新游标位置
        self.cursor = end
        
        # 根据索引获取对应的数据
        batch_data = []
        for index in self.indexes[begin:end]:
            item = self.dataset[index]
            batch_data.append(item)
        
        return batch_data

if __name__ == "__main__":        
    images = [[f"image{i}", i] for i in range(10)]
    dataset = ImageDataset(images)
    loader = DataLoader(dataset, batch_size=5)
    
    for index, batch_data in enumerate(loader, 1):
        print(f"第{index}个批次:", batch_data)

代码中存在的问题:

当最后一个batch的样本数量不足 batch_size 时,比如总样本数不是 batch_size 的整数倍,不会返回最后一个不足的batch
改进后的 DataLoader

class DataLoader():
    def __init__(self,dataset, batch_size, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        
    def __iter__(self):
        """
        初始化迭代器, 每个epoch开始时自动调用
        """
        self.cursor = 0
        self.indexes = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(self.indexes)
        return self
    
    def __next__(self):
        """
        获取下一批次数据
        """
        begin = self.cursor
        end = self.cursor + self.batch_size
        
        # 当剩余数据不足一个批次时全部返回剩余数据
        if begin >= len(self.dataset):
            raise StopIteration
        
        end = min(end, len(self.dataset))
        self.cursor = end
        
        batch_data = []
        for index in self.indexes[begin:end]:
            item = self.dataset[index]
            batch_data.append(item)
        
        return batch_data

本文参考:

https://www.bilibili.com/video/BV12s4y1N72y/?spm_id_from=333.1387.favlist.content.click&vd_source=cf0b4c9c919d381324e8f3466e714d7a

相关文章:

  • zlm启用webrtc交叉编译指南
  • [免费]SpringBoot+Vue外卖(点餐)平台系统【论文+源码+SQL脚本】
  • 「出海匠」借助CloudPilot AI实现AWS降本60%,支撑AI电商高速增长
  • 鸿蒙开发-动画
  • C++核心机制-this 指针传递与内存布局分析
  • 读者、写者问题优化
  • 在AMGCL中使用多个GPU和多个计算节点求解大规模稀疏矩阵方程
  • JVM考古现场(十九):量子封神·用鸿蒙编译器重铸天道法则
  • 智能合约安全审计平台——以太坊虚拟机安全沙箱
  • Font Maker的成功之路:产品迭代与创新营销助力增长
  • 国达陶瓷重磅推出陶瓷罗马柱外墙整装尖端新产品“冠岩臻石”
  • Profibus DP主站转modbusTCP网关与dp从站通讯案例
  • 在vue项目中package.json中的scripts 中 dev:“xxx“中的xxx什么概念
  • 爬虫:一文掌握 curl-cffi 的详细使用(支持 TLS/JA3 指纹仿真的 cURL 库)
  • Nacos集群搭建和mysql持久化配置
  • 第三篇:[特殊字符] 深入理解MyBatis[特殊字符] 掌握MyBatis动态SQL——应对复杂查询的有力武器
  • 【vue】轮播图案例
  • 关于python字典的所有操作
  • 性能优化-Spring参数配置、数据库连接参数配置、JVM调优
  • 行锁(Row Locking)和MVCC(多版本并发控制)
  • 本周看啥|《乘风》迎来师姐们,《天赐》王蓉搭Ella
  • 党政机关停车场免费、食堂开放,多地“五一”游客服务暖心周到
  • 金融监管总局修订发布《行政处罚办法》,7月1日起施行
  • 赵乐际主持十四届全国人大常委会第十五次会议闭幕会并作讲话
  • 五一去哪儿| 追着花期去旅行,“赏花经济”绽放文旅新活力
  • 国际锐评:菲律宾“狐假虎威”把戏害的是谁?