当前位置: 首页 > news >正文

DataSet-深度学习中的常见类

深度学习中Dataset类通用的架构思路

Dataset 类设计的必备部分

1. 初始化 __init__

  • 配置和路径管理:保存 config,区分 train/val/test 路径。
  • 加载原始数据:CSV、JSON、Numpy、Parquet 等。
  • 预处理器/归一化器:如 StandardScaler,或者 Tokenizer(在 NLP 任务里)。
  • 准备辅助信息:比如 meta 特征、文本 embedding。
  • 构造样本列表(self.samples):保证后面取样时直接 O(1) 访问。

2. 数据预处理

  • normalize / inverse_transform:数值数据标准化和反变换。
  • tokenize / pad:文本分词、对齐。
  • feature engineering:特征拼接、缺失值处理。

3. 核心接口

  • __len__: 返回数据集样本数。
  • __getitem__: 返回一个样本(通常是 (features, label) 的 tuple 或 dict)。

4. 可选接口

  • get_scaler(): 返回归一化器。
  • get_vocab(): NLP 任务里返回词表。
  • collate_fn: 定义 batch 内如何拼接(特别是变长序列)。
  • save_cache / load_cache: 大数据集可以存缓存,避免每次都重新处理。

5. 继承关系

  • BaseDataset:负责

    • 通用逻辑(加载文件、归一化、拼装 sample)。
    • 提供钩子函数,比如 load_paths(flag)process_sample(sample)
  • 子类:只需要实现 路径差异样本加工方式差异


通用代码结构示意

class BaseDataset(Dataset):def __init__(self, config, flag="train", scaler=None):self.config = configself.flag = flagself.scaler = scaler or StandardScaler()self.samples = []self._load_data()self._build_samples()def _load_data(self):"""子类可重写,加载原始数据"""raise NotImplementedErrordef _build_samples(self):"""子类可重写,拼装每个样本的x, y, feats"""raise NotImplementedErrordef __len__(self):return len(self.samples)def __getitem__(self, idx):return self.samples[idx]def get_scaler(self):return self.scalerdef inverse_transform(self, x):return x * self.std + self.mean

子类只管:

class ElectricityDataset(BaseDataset):def _load_data(self):# 只写路径和文件加载逻辑passdef _build_samples(self):# 根据任务需要定义样本结构pass

调用示例

data_config = {"root": "data/electricity/","train_file": "train.json","train_meta_file": "train_meta.npy","train_news_file": "train_news.npy"
}train_config = {"batch_size": 64,"learning_rate": 1e-3,"epochs": 20
}train_ds = ElectricityDataset(data_config, flag="train")train_loader = DataLoader(train_ds,batch_size=train_config["batch_size"],shuffle=True,collate_fn=custom_collate_fn
))

文章转载自:

http://DPPa83tv.yqLrq.cn
http://Kx1f0JtB.yqLrq.cn
http://A3OnMe0c.yqLrq.cn
http://6h5h0zt4.yqLrq.cn
http://dttrztUq.yqLrq.cn
http://8IKLsylL.yqLrq.cn
http://5wGVK84N.yqLrq.cn
http://QyX1azpo.yqLrq.cn
http://7r8NDlL9.yqLrq.cn
http://avseGjpG.yqLrq.cn
http://2GcU1rSo.yqLrq.cn
http://YMF0n2WY.yqLrq.cn
http://VHkJryds.yqLrq.cn
http://BPNOUvNA.yqLrq.cn
http://zxItlQPj.yqLrq.cn
http://mLkIWTfj.yqLrq.cn
http://wy3IDgTv.yqLrq.cn
http://ndWZFHrf.yqLrq.cn
http://OkxLtSvG.yqLrq.cn
http://jT9BHsrp.yqLrq.cn
http://jQtxyzWV.yqLrq.cn
http://YhchnCGg.yqLrq.cn
http://4vIhyXcu.yqLrq.cn
http://rFGKyXsh.yqLrq.cn
http://05VN0J6A.yqLrq.cn
http://c4nWKPwg.yqLrq.cn
http://u2BPHcSi.yqLrq.cn
http://2dN0YgQl.yqLrq.cn
http://tpP0NET9.yqLrq.cn
http://vct9jHgO.yqLrq.cn
http://www.dtcms.com/a/381813.html

相关文章:

  • Python编辑器的安装及配置(Pycharm、Jupyter的安装)从0带你配置,小土堆视频
  • SystemVerilog 学习之SystemVerilog简介
  • 中国联通卫星移动通信业务分析
  • 学习游戏制作记录(实现震动效果,文本提示和构建游戏)9.13
  • 【CMake】循环——foreach(),while()
  • 对比Java学习Go——函数、集合和OOP
  • AI时代的内容创作革命:深度解析xiaohongshu-mcp项目的技术创新与实战价值
  • 3-11〔OSCP ◈ 研记〕❘ WEB应用攻击▸存储型XSS攻击
  • 贪心算法应用:配送路径优化问题详解
  • 神经网络稀疏化设计构架中的网络剪枝技术:原理、实践与前沿探索
  • p5.js 绘制 3D 椭球体 ellipsoid
  • Qt中自定义控件的三种实现方式
  • leetcode34(环形链表)
  • Jupyter Notebook 介绍、安装及使用
  • 高并发场景下限流算法实践与性能优化指南
  • 基于stm32的智能井盖系统设计(4G版本)
  • 考研408计算机网络第36题真题解析(2021-2023)
  • 【Linux系统】单例式线程池
  • FreeSWITCH一键打包Docker镜像(源码编译)
  • POI和EasyExcel
  • 力扣-单调栈想法
  • 芯片厂常用的溶液—TMAH全方位介绍
  • Leetcode sql 50 ~5
  • 《大数据之路1》笔记2:数据模型
  • python小项目——学生管理系统
  • 格密码--从FFT到NTT(附源码)
  • HTML中css的基础
  • 软考中级习题与解答——第六章_计算机硬件基础(2)
  • UDP 深度解析:传输层协议核心原理与套接字编程实战
  • MySQL在Ubuntu 20.04 环境下的卸载与安装