Dataset类代码实战
实战一
label 即文件夹名(同一类别的图片放在以类别名命名的文件夹中)
from torch.utils.data import Dataset, DataLoader
import cv2
from PIL import Image
import os
class MyDataset(Dataset):
    """Image dataset whose label is the name of the folder holding the images.

    Every entry found in ``root_dir/label_dir`` is treated as one sample.

    Args:
        root_dir: Directory that contains the class sub-directories.
        label_dir: Name of the class sub-directory; it is also returned as
            the label of every sample in this dataset.
    """

    def __init__(self, root_dir: str, label_dir: str):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # Directory that holds the images of this class.
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary order; sort them so that a
        # given index always refers to the same file across runs.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, index: int):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is opened with ``PIL.Image.open`` and the label is the
        folder name passed to the constructor.
        """
        img_name = self.img_path[index]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self) -> int:
        """Return the number of files listed in the class directory."""
        return len(self.img_path)


root_dair = "dataset/train"
# Each class lives in its own sub-directory; the directory name is the label.
ants_label_dair = "ants"
bees_label_dair = "bees"
ants_dataset = MyDataset(root_dair, ants_label_dair)
bees_dataset = MyDataset(root_dair, bees_label_dair)
# Adding two Dataset objects concatenates them into a single dataset.
train_dataset = ants_dataset + bees_dataset
实战二
label 单独存放在标签文件中(每张图片对应标签文件夹里的一个文本文件)
from torch.utils.data import Dataset
from PIL import Image
import os


class MyDataset(Dataset):
    """Image dataset whose labels are stored in separate text files.

    Images are read from ``root_dir/image_dir`` and labels from
    ``root_dir/label_dir``. Sample ``i`` pairs the i-th image file with the
    i-th label file, both taken in sorted file-name order.

    Args:
        root_dir: Root directory of the dataset split.
        image_dir: Name of the sub-directory holding the image files.
        label_dir: Name of the sub-directory holding the label files.
    """

    def __init__(self, root_dir: str, image_dir: str, label_dir: str):
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        # Full paths of the image and label directories.
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary order, so two independent
        # listings are not guaranteed to line up. Sorting both keeps image i
        # paired with label i.
        # NOTE(review): assumes each image has exactly one label file with a
        # matching base name -- confirm against the dataset layout.
        self.image_names = sorted(os.listdir(self.image_path))
        self.label_names = sorted(os.listdir(self.label_path))

    def __getitem__(self, index: int):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is opened with ``PIL.Image.open``; the label is the
        whitespace-stripped text content of the corresponding label file.
        """
        img_name = self.image_names[index]
        img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)
        img = Image.open(img_item_path)

        label_name = self.label_names[index]
        label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)
        with open(label_item_path, 'r') as f:
            label = f.read().strip()

        return img, label

    def __len__(self) -> int:
        """Return the number of image files in the dataset."""
        return len(self.image_names)


# Dataset configuration
# Root directory of the training split.
root_dir = "dataset_practice/train"
# Image / label folder names for the ant samples.
ants_img_dir, ants_label_dir = "ants_image", "ants_label"
# Image / label folder names for the bee samples.
bees_img_dir, bees_label_dir = "bees_image", "bees_label"

# Build one dataset instance per class.
ants_dataset = MyDataset(root_dir, ants_img_dir, ants_label_dir)
bees_dataset = MyDataset(root_dir, bees_img_dir, bees_label_dir)