【DataLoader的使用】
一、Dataset与DataLoader
- Pytorch中的 torch.utils.data 提供了两个抽象类:Dataset 和 Dataloader。Dataset 允许你自定义自己的数据集,用来存储样本及其对应的标签。而 Dataloader 则是在 Dataset 的基础上将其包装为一个可迭代对象,以便我们更方便地(小批量)访问数据集。
- 通俗来讲,如果说有一家包子店,Dataset负责处理每个包子,什么馅料的,大包子还是小包子。而DataLoader就是相当于服务员,它只负责怎么拿包子,送给客人,不需要关心怎么做包子。
二、实战
- 引入必要库
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torch.utils.data import DataLoader
- 数据预处理
data_transform = transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor()
])
统一图像尺寸(不然会报错),把PIL格式转为张量
- 创建ImageFolder数据集实例
test_set=torchvision.datasets.ImageFolder(
root=r'D:/My_Work/StudyDeepLearning/day3_code/dataset/test',
transform=data_transform)
指定数据集根目录路径,应用上面的数据转换
- 创建DataLoader实例
test_loader=DataLoader(
test_set,
batch_size=32,
shuffle=True,
num_workers=0,
drop_last=False)
dataset
:指定加载的数据集对象,包含所有样本和标签。batch_size
:设置每批次加载的样本数(如32),影响内存使用和训练效率。shuffle
:是否打乱数据顺序(训练建议True
,测试建议False
)。num_workers
:数据加载子进程数(0
为主进程,多核可设4-8加速加载)。drop_last
:是否丢弃末尾不足batch_size
的批次(False
保留,True
丢弃)。pin_memory
(可选):锁页内存加速GPU传输(GPU训练时建议True
)。persistent_workers
(可选):保持子进程存活,减少重复初始化开销(PyTorch 1.7+)。
- 放到tensorboard测试查看
#取测试集中第一张图片以及label
img,label=test_set[0]
print(img.shape)
print(label)writer=SummaryWriter("dataloader")
for epoch in range(2):step=0for data in test_loader:imgs,labels=datawriter.add_images("Epoch:{}".format(epoch),imgs,step)step+=1
writer.close()