成都营销型网站建设价格北京搜索引擎优化经理
dataset数据集
作用:
- 存储数据集的信息
- 获取数据集长度
__len__
- 获取数据集某特定条目的内容
__getitem__
dataloader 数据加载器
作用:
- 从数据集中随机加载数据, 并拼接为一个 batch
- 实现迭代器, 可以使用时, 迭代获取数据内容
代码实现:
import numpy as np
class ImageDataset():def __init__(self, raw_data):"""数据集初始化"""self.raw_data = raw_datadef __len__(self):"""返回数据集的长度"""return len(self.raw_data)def __getitem__(self, index):"""根据索引获取数据集中某一条数据"""image, label = self.raw_data[index]return image, labelclass DataLoader():def __init__(self, dataset, batch_size):self.dataset = datasetself.batch_size = batch_sizedef __iter__(self):self.indexes = np.arange(len(self.dataset))self.cursor = 0np.random.shuffle(self.indexes)return selfdef __next__(self):# 计算起始索引和终止索引begin = self.cursorend = self.cursor + self.batch_size# 若超出范围,抛出停止迭代异常if end > len(self.dataset):raise StopIteration# 更新游标位置self.cursor = end# 根据索引获取对应的数据batch_data = []for index in self.indexes[begin:end]:item = self.dataset[index]batch_data.append(item)return batch_dataif __name__ == "__main__": images = [[f"image{i}", i] for i in range(10)]dataset = ImageDataset(images)loader = DataLoader(dataset, batch_size=5)for index, batch_data in enumerate(loader, 1):print(f"第{index}个批次:", batch_data)
代码中存在的问题:
当最后一个batch的样本数量不足 batch_size
时,比如总样本数不是 batch_size
的整数倍,不会返回最后一个不足的batch
改进后的 DataLoader:
class DataLoader():def __init__(self,dataset, batch_size, shuffle=True):self.dataset = datasetself.batch_size = batch_sizeself.shuffle = shuffledef __iter__(self):"""初始化迭代器, 每个epoch开始时自动调用"""self.cursor = 0self.indexes = np.arange(len(self.dataset))if self.shuffle:np.random.shuffle(self.indexes)return selfdef __next__(self):"""获取下一批次数据"""begin = self.cursorend = self.cursor + self.batch_size# 当剩余数据不足一个批次时全部返回剩余数据if begin >= len(self.dataset):raise StopIterationend = min(end, len(self.dataset))self.cursor = endbatch_data = []for index in self.indexes[begin:end]:item = self.dataset[index]batch_data.append(item)return batch_data
本文参考:
https://www.bilibili.com/video/BV12s4y1N72y/?spm_id_from=333.1387.favlist.content.click&vd_source=cf0b4c9c919d381324e8f3466e714d7a