Python训练Day38
@浙大疏锦行
- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- Dataloader类
- minist手写数据集的了解
在遇到大规模数据集时,显存常常无法一次性存储所有数据,所以需要使用分批训练的方法。为此,PyTorch提供了DataLoader类,该类可以自动将数据集切分为多个批次batch,并支持多线程加载数据。此外,还存在Dataset类,该类可以定义数据集的读取方式和预处理方式。
1. DataLoader类:决定数据如何加载
2. Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理。
一、Dataset类
PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:
- __len__():返回数据集的样本总数。
- __getitem__(idx):根据索引idx返回对应样本的数据和标签。
PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。
__getitem__方法
__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。
# 示例代码
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2]) # 输出:30
__len__方法
__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj)) # 输出:5
二、Dataloader类
# 3. 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)