DataSet-深度学习中的常见类
深度学习中Dataset类通用的架构思路
Dataset 类设计的必备部分
1. 初始化 __init__
- 配置和路径管理:保存
config
,区分train/val/test
路径。 - 加载原始数据:CSV、JSON、Numpy、Parquet 等。
- 预处理器/归一化器:如
StandardScaler
,或者 Tokenizer(在 NLP 任务里)。 - 准备辅助信息:比如 meta 特征、文本 embedding。
- 构造样本列表(self.samples):保证后面取样时直接
O(1)
访问。
2. 数据预处理
- normalize / inverse_transform:数值数据标准化和反变换。
- tokenize / pad:文本分词、对齐。
- feature engineering:特征拼接、缺失值处理。
3. 核心接口
__len__
: 返回数据集样本数。__getitem__
: 返回一个样本(通常是(features, label)
的 tuple 或 dict)。
4. 可选接口
get_scaler()
: 返回归一化器。get_vocab()
: NLP 任务里返回词表。collate_fn
: 定义 batch 内如何拼接(特别是变长序列)。save_cache
/load_cache
: 大数据集可以存缓存,避免每次都重新处理。
5. 继承关系
-
BaseDataset:负责
- 通用逻辑(加载文件、归一化、拼装 sample)。
- 提供钩子函数,比如
load_paths(flag)
、process_sample(sample)
。
-
子类:只需要实现 路径差异 或 样本加工方式差异。
通用代码结构示意
class BaseDataset(Dataset):def __init__(self, config, flag="train", scaler=None):self.config = configself.flag = flagself.scaler = scaler or StandardScaler()self.samples = []self._load_data()self._build_samples()def _load_data(self):"""子类可重写,加载原始数据"""raise NotImplementedErrordef _build_samples(self):"""子类可重写,拼装每个样本的x, y, feats"""raise NotImplementedErrordef __len__(self):return len(self.samples)def __getitem__(self, idx):return self.samples[idx]def get_scaler(self):return self.scalerdef inverse_transform(self, x):return x * self.std + self.mean
子类只管:
class ElectricityDataset(BaseDataset):def _load_data(self):# 只写路径和文件加载逻辑passdef _build_samples(self):# 根据任务需要定义样本结构pass
调用示例
data_config = {"root": "data/electricity/","train_file": "train.json","train_meta_file": "train_meta.npy","train_news_file": "train_news.npy"
}train_config = {"batch_size": 64,"learning_rate": 1e-3,"epochs": 20
}train_ds = ElectricityDataset(data_config, flag="train")train_loader = DataLoader(train_ds,batch_size=train_config["batch_size"],shuffle=True,collate_fn=custom_collate_fn
))