项目介绍:图像分类项目的最小可用骨架--代码细节讲解
讲解代码如下:
'''创建数据集的类'''
import torch
from torch.utils.data import Dataset,DataLoader #用于处理数据集的
import numpy as np
from PIL import Image #
from torchvision import transforms #对数据进行处理工具 转换data_transforms = { #字典'train':transforms.Compose([ # 对图片做预处理的。组合,transforms.Resize([256,256]), #数据进行改变大小[256,256]transforms.ToTensor(), #数据转换为tensor,默认把通道维度放在前面]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),
}#数组增强,
#Dataset是用来处理数据的。 包含1张图片的数据(tensor),包含标签结果,并且能通过索引得到数据
class food_dataset(Dataset): # food_dataset是自己创建的类名称,可以改为你需要的名称def __init__(self, file_path,transform=None): #类的初始化,解析数据文件txtself.file_path = file_pathself.imgs = []#存储图片的路径self.labels = []#存储图片的标签结果self.transform = transformwith open(self.file_path) as f:#是把train.txt文件中图片的路径保存在 self.imgs,train.txt文件中标签保存在 self.labelssamples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path) #图像的路径self.labels.append(label) #标签,还不是tensor
#初始化:把图片目录加载到self,def __len__(self): #类实例化对象后,可以使用len函数测量对象的个数 ls=[12,3,4,4] len(training_data)return len(self.imgs)#training_data[1]def __getitem__(self, idx): #关键,可通过索引的形式获取每一个图片数据及标签image = Image.open(self.imgs[idx]) #读取到图片数据,还不是tensor,BGRif self.transform: #将pil图像数据转换为tensorimage = self.transform(image) #图像处理为256*256,转换为tenorlabel = self.labels[idx] #label还不是tensorlabel = torch.from_numpy(np.array(label,dtype = np.int64)) #label也转换为tensor,return image, label
#training_data包含了本次需要训练的全部数据集?
training_data = food_dataset(file_path = 'train.txt',transform = data_transforms['train'])
test_data = food_dataset(file_path = 'test.txt',transform = data_transforms['valid'])#training_data需要具备索引的功能,还要确保数据是tensor
train_dataloader = DataLoader(training_data, batch_size=64,shuffle=True)#64张图片为一个包,
test_dataloader = DataLoader(test_data, batch_size=64,shuffle=True)
数据索引功能的必要性
在机器学习训练中,
data loader
需通过索引批量获取数据(如每次取64个样本)。若数据对象不具备索引功能,则无法被data loader
有效管理。因此,需为数据对象添加__getitem__方法,使其支持索引操作。
def __getitem__(self, idx): #关键,可通过索引的形式获取每一个图片数据及标签image = Image.open(self.imgs[idx]) #读取到图片数据,还不是tensor,BGRif self.transform: #将pil图像数据转换为tensorimage = self.transform(image) #图像处理为256*256,转换为tenorlabel = self.labels[idx] #label还不是tensorlabel = torch.from_numpy(np.array(label,dtype = np.int64)) #label也转换为tensor,return image, label
图像预处理流程
- 裁剪图片:统一将图片尺寸调整为256×256,避免因输入尺寸不一致导致全连接层维度冲突。
- 数据转换:将图像数据转换为PyTorch的tensor格式,便于神经网络处理。
def __getitem__(self, idx): #关键,可通过索引的形式获取每一个图片数据及标签image = Image.open(self.imgs[idx]) #读取到图片数据,还不是tensor,BGRif self.transform: #将pil图像数据转换为tensorimage = self.transform(image) #图像处理为256*256,转换为tenorlabel = self.labels[idx] #label还不是tensorlabel = torch.from_numpy(np.array(label,dtype = np.int64)) #label也转换为tensor,return image, label
Pillow库的使用场景
- Pillow库与OpenCV类似,但功能更简单,适合基础图像操作(如裁剪、格式转换)。因其轻量高效,常用于无需复杂算法的预处理任务。
def __init__(self, file_path,transform=None): #类的初始化,解析数据文件txtself.file_path = file_pathself.imgs = []#存储图片的路径self.labels = []#存储图片的标签结果self.transform = transformwith open(self.file_path) as f:#是把train.txt文件中图片的路径保存在 self.imgs,train.txt文件中标签保存在 self.labelssamples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path) #图像的路径self.labels.append(label) #标签,还不是tensor
- Pillow库与OpenCV类似,但功能更简单,适合基础图像操作(如裁剪、格式转换)。因其轻量高效,常用于无需复杂算法的预处理任务。
自定义数据集类实现
- 继承自父类
DataSet
,核心方法__getitem__
根据索引返回预处理后的数据(tensor格式的图片及标签)。 - 初始化阶段读取txt文件,存储图片路径和标签,为后续索引提供依据。
#training_data包含了本次需要训练的全部数据集? training_data = food_dataset(file_path = 'train.txt',transform = data_transforms['train']) # test_data = food_dataset(file_path = 'test.txt',transform = data_transforms['valid'])
- 继承自父类
标签处理逻辑
- 原始标签为字符串类型,需先转为NumPy数组,再转为整数类型的tensor,确保与模型输入匹配。