Dataset和Dataloader
什么是Dataset和Dataloader
- Dataset指定了数据集包含了什么,可以是自定义数据集,也可以是官方提供的数据集
- Dataloader指定了这个数据集应该以怎样的方式进行加载
定义Dataset
自定义的Dataset格式如下所示
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Skeleton of a custom dataset holding paired samples and labels.

    Fill ``self.x`` and ``self.y`` in ``__init__``; the two lists are indexed
    in lockstep, so they are expected to have the same length.
    """

    def __init__(self):
        # Declare what the dataset contains: samples in x, labels in y.
        self.x = []
        self.y = []

    def __len__(self):
        """Return the total number of samples in the dataset."""
        # x and y are paired, so the length of x is the dataset size.
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        # Any per-sample preprocessing belongs here, before the return.
        return self.x[idx], self.y[idx]
案例1:导入两个列表到Dataset
from torch.utils.data import Dataset, DataLoader
class NewDataset(Dataset):
    """Toy dataset that pairs each integer ``i`` with its double ``2 * i``.

    Args:
        size: Number of samples to generate. Defaults to 12, so calling
            ``NewDataset()`` yields x = 0..11 and y = 0, 2, ..., 22.
    """

    def __init__(self, size=12):
        self.x = list(range(size))
        # Labels are derived from the samples so the two lists stay aligned.
        self.y = [value * 2 for value in self.x]

    def __getitem__(self, item):
        """Return the ``(x, y)`` pair at position ``item``."""
        return self.x[item], self.y[item]

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)
if __name__ == '__main__':
    newdataset = NewDataset()
    # Run the same print loop under three loader configurations, in order:
    # default settings, batches of 2, and shuffled batches of 4.
    for loader_options in (
        {},
        {"batch_size": 2},
        {"batch_size": 4, "shuffle": True},
    ):
        newdataloader = DataLoader(newdataset, **loader_options)
        for x_i, y_i in newdataloader:
            print(x_i, y_i)
案例2:导入Excel数据到Dataset
# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
    """Dataset backed by an Excel sheet with columns x1, x2, x3, x4 and y.

    Args:
        filename: Path of the Excel file to load. Defaults to
            "./anli2/data.xlsx".
    """

    def __init__(self, filename="./anli2/data.xlsx"):
        data = pd.read_excel(filename)
        self.x1 = data['x1']
        self.x2 = data['x2']
        self.x3 = data['x3']
        self.x4 = data['x4']
        self.y = data['y']

    def __len__(self):
        """Return the number of rows in the sheet."""
        return len(self.x1)

    def __getitem__(self, item):
        """Return ``(x1, x2, x3, x4, y)`` for the row at position ``item``.

        Raises:
            IndexError: If ``item`` is outside the range of rows.
        """
        # .iloc indexes by position rather than by index label, so this works
        # whatever index the columns carry, and an out-of-range position
        # raises IndexError as the sequence protocol expects.
        return (
            self.x1.iloc[item],
            self.x2.iloc[item],
            self.x3.iloc[item],
            self.x4.iloc[item],
            self.y.iloc[item],
        )
if __name__ == '__main__':
    mydataset = MyDataset()
    # Shuffled mini-batches of four rows each.
    mydataloader = DataLoader(dataset=mydataset, batch_size=4, shuffle=True)
    for batch in mydataloader:
        # Each batch holds the four feature columns followed by the label.
        x1, x2, x3, x4, y = batch
        print(f"x1={x1},x2={x2},x3={x3},x4={x4},y={y}")
案例3:导入图像数据集
# -*- coding: utf-8 -*-
import os
import cv2 as cv
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
class MyImageDataset(Dataset):
    """Image-folder dataset: each file found under ``image_root`` is a sample.

    The label of a sample is the name of the directory that directly
    contains the file.

    Args:
        image_root: Directory that is scanned recursively for files.
            Defaults to "anli3/image".
    """

    def __init__(self, image_root=r"anli3/image"):
        self.file_path_list = []
        self.labels = []
        for root, _dirs, files in os.walk(image_root):
            # normpath strips a trailing separator so the label is never
            # empty; basename then gives the containing directory's name.
            label = os.path.basename(os.path.normpath(root))
            for file_i in files:
                self.file_path_list.append(os.path.join(root, file_i))
                self.labels.append(label)

    def __len__(self):
        """Return the number of files collected by ``__init__``."""
        return len(self.file_path_list)

    def __getitem__(self, item):
        """Return ``(image_tensor, label)`` for the file at position ``item``.

        The image is resized to 256x256 and returned channel-first.

        Raises:
            ValueError: If the file cannot be read as an image.
        """
        file_path = self.file_path_list[item]
        img = cv.imread(file_path)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise ValueError(f"cannot read image file: {file_path}")
        img = cv.resize(img, dsize=(256, 256))
        # The resized array is (height, width, channel) = (256, 256, 3).
        # Move the channel axis to the front while keeping height before
        # width, giving (3, 256, 256).
        img = np.transpose(img, (2, 0, 1))
        img_tensor = torch.from_numpy(img)
        label = self.labels[item]
        return img_tensor, label
if __name__ == '__main__':
    mydataset = MyImageDataset()
    # Shuffled batches of four images, loaded by four worker processes.
    mydataloader = DataLoader(
        dataset=mydataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
    )
    for batch in mydataloader:
        x_i, y_i = batch
        print(x_i.shape, y_i)
for root, dirs, files in os.walk(image_root):
os.walk() 是 Python 标准库 os 模块中的函数。它递归遍历指定目录及其所有子目录,每访问一个目录就产生一个三元组:当前目录的路径(root)、该目录下的子目录名列表(dirs)和文件名列表(files)。
label = root.split(os.sep)[-1]
使用文件路径分隔符(os.sep)将字符串 root 分割成一个列表,并取最后一个元素(即文件所在文件夹的名称)作为标签。os.sep 是一个在不同操作系统中定义的路径分隔符,Windows 中为 \,而在 Unix/Linux 中为 /。