PyTorch 中 Dataset 与 DataLoader 的定义和使用
什么是Dataset和Dataloader
- Dataset指定了数据集包含了什么,可以是自定义数据集,也可以是官方数据集
- Dataloader指定了这个数据集应该以怎样的方式进行加载
定义Dataset
自定义的Dataset格式如下所示
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Minimal template for a custom map-style dataset.

    Samples are stored as two parallel lists, ``x`` (inputs) and ``y``
    (targets); sample ``idx`` is the pair ``(x[idx], y[idx])``.
    """

    def __init__(self):
        # Define what the dataset contains; fill these with real data.
        self.x = []
        self.y = []

    def __len__(self):
        """Return the total number of samples in the dataset."""
        # Follow ``x`` so the length agrees with what ``__getitem__`` indexes.
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair stored at position ``idx``."""
        # Any per-sample processing belongs here, before the values are returned.
        return self.x[idx], self.y[idx]
案例1:导入两个列表到Dataset
from torch.utils.data import Dataset, DataLoader


class NewDataset(Dataset):
    """Toy dataset of 12 samples pairing each integer 0..11 with its double."""

    def __init__(self):
        self.x = list(range(12))
        self.y = [2 * value for value in self.x]

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, item):
        """Return the ``(x, y)`` pair at position ``item``."""
        return self.x[item], self.y[item]


if __name__ == '__main__':
    newdataset = NewDataset()

    # Iterate the same dataset three ways: loader defaults, batches of 2,
    # and shuffled batches of 4.
    loader_options = ({}, {"batch_size": 2}, {"batch_size": 4, "shuffle": True})
    for options in loader_options:
        newdataloader = DataLoader(newdataset, **options)
        for x_i, y_i in newdataloader:
            print(x_i, y_i)
案例2:导入Excel数据到Dataset
# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    """Dataset backed by an Excel sheet with columns ``x1``..``x4`` and ``y``.

    Each sample is the tuple ``(x1, x2, x3, x4, y)`` taken from one row.

    Args:
        filename: Path of the workbook handed to ``pd.read_excel``. The
            default is the path the example was written against.
    """

    def __init__(self, filename="./anli2/data.xlsx"):
        data = pd.read_excel(filename)
        # Keep each column as its own Series; rows are looked up lazily.
        self.x1 = data['x1']
        self.x2 = data['x2']
        self.x3 = data['x3']
        self.x4 = data['x4']
        self.y = data['y']

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.x1)

    def __getitem__(self, item):
        """Return the ``(x1, x2, x3, x4, y)`` values of the ``item``-th row."""
        # ``.iloc`` is positional, so ``item`` always means "the item-th row"
        # even if the columns do not carry a default 0..n-1 index.
        return (
            self.x1.iloc[item],
            self.x2.iloc[item],
            self.x3.iloc[item],
            self.x4.iloc[item],
            self.y.iloc[item],
        )


if __name__ == '__main__':
    mydataset = MyDataset()
    mydataloader = DataLoader(mydataset, shuffle=True, batch_size=4)
    for x1, x2, x3, x4, y in mydataloader:
        print(f"x1={x1},x2={x2},x3={x3},x4={x4},y={y}")
案例3:导入图像数据集
# -*- coding: utf-8 -*-
import os
import cv2 as cv
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np


class MyImageDataset(Dataset):
    """Image-folder dataset: every file found under ``image_root`` is a sample.

    The label of a sample is the name of the directory that directly contains
    the file (the last path component of the directory yielded by ``os.walk``).

    Args:
        image_root: Directory that is walked recursively to collect files.
            Defaults to the folder the example was written against.
    """

    def __init__(self, image_root=r"anli3/image"):
        self.file_path_list = []
        self.labels = []
        for root, _dirs, files in os.walk(image_root):
            # Every file in this directory shares the directory name as label.
            label = root.split(os.sep)[-1]
            for file_i in files:
                self.file_path_list.append(os.path.join(root, file_i))
                self.labels.append(label)

    def __len__(self):
        """Return the number of collected files."""
        return len(self.file_path_list)

    def __getitem__(self, item):
        """Load sample ``item`` and return ``(image_tensor, label)``.

        The image is resized to 256x256 and returned channel-first.

        Raises:
            OSError: If the file cannot be read as an image.
        """
        path = self.file_path_list[item]
        img = cv.imread(path)
        if img is None:
            # cv.imread reports failure by returning None rather than raising.
            raise OSError(f"cannot read image file: {path}")
        img = cv.resize(img, dsize=(256, 256))
        # After resizing the array is (height, width, channel) = (256, 256, 3).
        # Move the channel axis to the front while keeping height before
        # width, giving (3, 256, 256) without transposing the picture.
        img = np.transpose(img, (2, 0, 1))
        img_tensor = torch.from_numpy(img)
        return img_tensor, self.labels[item]


if __name__ == '__main__':
    mydataset = MyImageDataset()
    mydataloader = DataLoader(mydataset, batch_size=4, shuffle=True, num_workers=4)
    for x_i, y_i in mydataloader:
        print(x_i.shape, y_i)
for root, dirs, files in os.walk(image_root):
os.walk() 是 Python 标准库 os 模块中的函数。它会递归遍历指定目录及其所有子目录,每次迭代返回一个三元组:当前目录的路径(root)、该目录下的子目录名列表(dirs)和文件名列表(files)。
label = root.split(os.sep)[-1]
使用路径分隔符 os.sep 将字符串 root 分割成一个列表,并取最后一个元素(即文件所在文件夹的名称)作为标签。os.sep 是随操作系统而不同的路径分隔符:Windows 中为 \,Unix/Linux 中为 /。