每天五分钟深度学习框架PyTorch:基于Dataset封装自定义数据集
本文重点
我们前面使用过torchvision中的数据集,我们可以直接使用torchvison来加载,然后使用dataLoader来获取batch个数据,用于模型的训练和测试。现在的问题是如果我们的所要使用的数据pytorch中没有进行封装怎么办?
Dataset
我们还记得torchvision中封装的数据集dataset继承Dataset,所以我们可以继承Dataset封装处理我们自己数据集的类。
我们需要定义一个类,当然这个类要继承抽象类Dataset,此时我们需要完成三个方法,第一个初始化方法__init__,这个方法完成读取数据集,__len__该方法提供数据集大小和__getitem__通过整数索引返回一个样本
大概的模式
高负载的操作放在__getitem__中,如加载图片,图片transform等,dataset中应尽量只包含只读对象,避免修改任何可变对象,这样多线程操作就不会出现问题。
from torch.utils.data import Datasetclass myDataset(Dataset): def __init__ (self, csv_file, txt_file, root_dir, other_file) : self.csv_data = pd.read_csv(csv_file) with open(txt_file, 'r' ) as f: