PyTorch中 torch.utils.data.DataLoader 的详细解析和读取点云数据示例
一、DataLoader 是什么?
torch.utils.data.DataLoader
是 PyTorch 中用于加载数据的核心接口,它支持:
- 批量读取(batch)
- 数据打乱(shuffle)
- 多线程并行加载(num_workers)
- 自动将数据打包成 batch
- 数据预处理和增强(搭配 Dataset 使用)
二、常见参数详解
参数 | 含义 |
---|---|
dataset | 传入的 Dataset 对象(如自定义或 torchvision.datasets ) |
batch_size | 每个 batch 的样本数量 |
shuffle | 是否打乱数据(通常训练集为 True) |
num_workers | 并行加载数据的线程数(越大越快,但依机器决定) |
drop_last | 是否丢弃最后一个不足 batch_size 的 batch |
pin_memory | 若为 True,会将数据复制到 CUDA 的 page-locked 内存中(加速 GPU 训练) |
collate_fn | 自定义打包 batch 的函数(可用于变长序列、图神经网络等) |
sampler | 控制数据采样策略,不能与 shuffle 同时使用 |
persistent_workers | 若为 True,worker 在 epoch 间保持运行状态(提高效率,PyTorch 1.7+) |
三、基本使用示例
搭配 Dataset 使用
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self):self.data = [i for i in range(100)]def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]dataset = MyDataset()
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)for batch in loader:print(batch)
四、自定义 collate_fn 示例
适用于:变长数据(如文本、点云)或特殊处理需求
from torch.nn.utils.rnn import pad_sequencedef my_collate_fn(batch):# 假设每个样本是 list 或 tensor(变长)batch = [torch.tensor(item) for item in batch]padded = pad_sequence(batch, batch_first=True, padding_value=0)return paddedloader = DataLoader(dataset, batch_size=4, collate_fn=my_collate_fn)
五、使用注意事项
-
Windows 平台注意:
-
设置
num_workers > 0
时,必须使用:if __name__ == '__main__':DataLoader(...)
-
-
过多线程数可能导致瓶颈:
- 通常
num_workers = cpu_count() // 2
较稳定
- 通常
-
GPU 加速:
- 训练时推荐设置
pin_memory=True
可提高 GPU 训练数据传输效率。
- 训练时推荐设置
-
不要同时设置
shuffle=True
和sampler
:- 否则会报错,二者功能冲突。
六、训练中的典型使用方式
for epoch in range(num_epochs):for i, batch in enumerate(train_loader):inputs, labels = batchinputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()
七、调试技巧与加速建议
场景 | 建议 |
---|---|
数据加载慢 | 增加 num_workers |
GPU 等数据 | 设置 pin_memory=True |
Dataset 中有耗时操作 | 考虑预处理或使用缓存 |
debug 模式 | 设置 num_workers=0 ,禁用多进程 |
八、与 TensorDataset、ImageFolder 配合
from torchvision.datasets import ImageFolder
from torchvision import transformstransform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),
])dataset = ImageFolder(root='your/image/folder', transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
九、点云数据处理场景应用实例
在 点云数据处理 场景中,使用 torch.utils.data.DataLoader
时,常遇到如下需求:
- 每帧点云大小不同(变长 Tensor)
- 点云数据 + 标签(如语义、实例)
- 使用
.bin
、.pcd
或.npy
等格式加载 - 数据增强(如旋转、裁剪、噪声)
- GPU 加速 + 批量训练
1. 点云数据 Dataset 示例(以 .npy
文件为例)
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoaderclass PointCloudDataset(Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.files = sorted([f for f in os.listdir(root_dir) if f.endswith('.npy')])self.transform = transformdef __len__(self):return len(self.files)def __getitem__(self, idx):point_cloud = np.load(os.path.join(self.root_dir, self.files[idx])) # shape: [N, 3] or [N, 6]point_cloud = torch.tensor(point_cloud, dtype=torch.float32)if self.transform:point_cloud = self.transform(point_cloud)return point_cloud
2. 自定义 collate_fn
(处理变长点云)
def collate_pointcloud_fn(batch):"""输入: List of [N_i x 3] tensors输出: - 合并后的 [B x N_max x 3] tensor- 每个样本的真实点数 list"""max_points = max(pc.shape[0] for pc in batch)padded = torch.zeros((len(batch), max_points, batch[0].shape[1]))lengths = []for i, pc in enumerate(batch):lengths.append(pc.shape[0])padded[i, :pc.shape[0], :] = pcreturn padded, torch.tensor(lengths)
3. 加载器构建示例
dataset = PointCloudDataset("/path/to/your/pointclouds")loader = DataLoader(dataset,batch_size=8,shuffle=True,num_workers=4,pin_memory=True,collate_fn=collate_pointcloud_fn
)for batch_points, batch_lengths in loader:# batch_points: [B, N_max, 3]# batch_lengths: [B]print(batch_points.shape)
4. 可选扩展功能
功能 | 实现方法 |
---|---|
点云旋转/缩放 | 自定义 transform (例如随机旋转矩阵乘点云) |
加载 .pcd | 使用 open3d , pypcd , 或 pclpy |
同时加载标签 | 在 Dataset 中返回 (point_cloud, label) ,修改 collate_fn |
voxel downsampling | 使用 open3d.geometry.VoxelDownSample |
GPU 加速 | point_cloud = point_cloud.cuda(non_blocking=True) |
5. 训练循环中使用
for epoch in range(num_epochs):for batch_pc, batch_len in loader:batch_pc = batch_pc.to(device)# 可用 batch_len 做 mask 或 attention maskout = model(batch_pc)...