深度学习PyTorch之数据加载DataLoader
深度学习pytorch之简单方法自定义9类卷积即插即用
文章目录
- 数据加载基础架构
- 1、Dataset类详解
- 2、DataLoader核心参数解析
- 3、数据增强
数据加载基础架构
核心类关系图
torch.utils.data
├── Dataset (抽象基类)
├── DataLoader (数据加载器)
├── Sampler (采样策略)
├── BatchSampler (批量采样)
└── IterableDataset (流式数据集)
1、Dataset类详解
自定义数据集模板
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.file_list = glob.glob(f"{data_dir}/*.jpg")
self.labels = self._load_labels()
self.transform = transform
def __len__(self):
return len(self.file_list)
def __getitem__(self, idx):
image = Image.open(self.file_list[idx])
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image, label
def _load_labels(self):
# 实现标签加载逻辑
return [...]
关键方法说明:
-
init: 初始化数据路径、预处理方法等
-
len: 返回数据集样本总数
-
getitem: 根据索引返回单个样本数据
2、DataLoader核心参数解析
基础配置示例
from torch.utils.data import DataLoader
dataloader = DataLoader(
dataset=dataset,
batch_size=64,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=False
)
参数详解表
参数 | 类型 | 默认值 | 作用 |
---|---|---|---|
batch_size | int | 1 | 批量大小 |
shuffle | bool | False | 是否打乱数据顺序 |
sampler | Sampler | None | 自定义采样策略 |
batch_sampler | Sampler | None | 批量采样策略 |
num_workers | int | 0 | 数据加载子进程数 |
collate_fn | callable | default_collate | 批量样本处理函数 |
pin_memory | bool | False | 是否锁页内存加速传输 |
drop_last | bool | False | 是否丢弃最后不完整批次 |
3、数据增强
深度学习 PyTorch 中 18 种数据增强策略与实现