当前位置: 首页 > news >正文

深度学习之第五课卷积神经网络 (CNN)如何训练自己的数据集(食物分类)

简介

之前一直使用的是现有人家的数据集,现在我们将使用自己的数据集进行训练。

基于卷积神经网络 (CNN) 的 MNIST 手写数字识别模型

一、训练自己数据集

1.数据预处理

我们现在有这样的数据集如下图:

每一个文件夹里面有着对应的图片。我们要将这些图片转换成数据集的标准格式(也就是x、y标签)

def train_test_file(root, dir):# 创建存储图像路径和标签的文本文件file_txt = open(dir + '.txt', 'w')path = os.path.join(root, dir)  # 拼接完整路径(根目录+训练/测试文件夹)# 遍历文件夹结构for roots, directories, files in os.walk(path):# 记录类别目录(如"苹果"、"香蕉"等文件夹名)if len(directories) != 0:dirs = directories  # 保存所有类别文件夹名称else:# 处理图像文件:获取当前文件夹名(即类别名)now_dir = roots.split('\\')  # 按路径分隔符拆分for file in files:# 拼接图像完整路径path_1 = os.path.join(roots, file)print(path_1)  # 打印图像路径(调试用)# 写入格式:图像路径 + 空格 + 类别索引(如"apple.jpg 0")file_txt.write(path_1 + ' ' + str(dirs.index(now_dir[-1])) + '\n')file_txt.close()  # 关闭文件# 保存类别名称到class_names.txt(方便后续查看类别对应关系)with open('class_names.txt', 'w', encoding='gbk') as f:f.write('\n'.join(dir))print(f"已生成{dir}.txt,类别列表:{dir}")# 调用函数生成训练集和测试集文件
train_test_file(r'.\food_dataset\food_dataset', 'train')
train_test_file(r'.\food_dataset\food_dataset', 'test')
  • 这个函数遍历指定目录,生成图像路径和对应标签的文本文件
  • 每个图像路径后面跟着它所属类别的索引(用于训练时的标签)
  • 同时生成类别名称文件 class_names.txt

这样我们就通过代码生成下面这些文件标签,模型可以读取这些文件,前面是图片的地址,这样模型就能通过地址去读取对应的图片,后面的0等数字就是对应的标签。

        对于class_names.txt文件后面我们可以输入一张自己的图片进行检测调用,里面包含了预测的不同食物名称。

2.定义数据转换

        我们的数据集图片大小不一样我们要进行将图片大小统一,如果数据集图片大小不同意,我们的模型中全链接层就不能确定,导致我们的参数的个数都不能确定下来。

data_transforms={     # 字典存储不同的数据转换方式'train':transforms.Compose([  # 组合多个转换操作transforms.Resize([256,256]),  # 调整图像大小为256x256transforms.ToTensor(),  # 转换为Tensor格式]),'valid':transforms.Compose([transforms.Resize([256, 256]),transforms.ToTensor(),]),
}

3. 自定义数据集类

        这里就是调用dataset类,让后面的dataloader去通过我们创建的train.txt和 test.txt读取自己的数据集图片,然后返回图片跟标签类别

class food_dataset(Dataset):  # 继承PyTorch的Dataset类def __init__(self, file_path, transform=None):# 初始化:读取文件列表并存储图像路径和标签self.file_path = file_path  # 数据文件路径(如train.txt)self.imgs = []  # 存储图像路径列表self.labels = []  # 存储标签列表self.transform = transform  # 数据转换函数# 读取train.txt/test.txt文件with open(file_path, 'r') as f:# 按行拆分,每行按空格分割为[图像路径, 标签]samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)  # 存储图像路径self.labels.append(label)  # 存储标签(字符串格式)def __len__(self):# 返回数据集总样本数(必须实现的方法)return len(self.imgs)def __getitem__(self, idx):# 根据索引获取单个样本(必须实现的方法)# 1. 读取图像image = Image.open(self.imgs[idx])  # 用PIL打开图像# 2. 应用预处理if self.transform:image = self.transform(image)  # 转换为张量并调整大小# 3. 处理标签:转换为整数张量label = self.labels[idx]  # 原始标签是字符串label = torch.from_numpy(np.array(label, dtype=np.int64))  # 转换为int64类型张量return image, label  # 返回(图像张量,标签张量)
  • 为什么需要自定义 Dataset:PyTorch 的 DataLoader 需要通过 Dataset 类加载数据,自定义类可灵活适配不同数据格式。
  • 核心方法
    • __init__:初始化时读取文件列表,无需一次性加载所有图像(节省内存)。
    • __len__:让 DataLoader 知道数据集大小,用于迭代。
    • __getitem__:按需加载单个样本(延迟加载),避免内存溢出。

4.创建数据加载器(DataLoader)

# 实例化数据集
train_data = food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='test.txt', transform=data_transforms['train'])# 创建数据加载器(批处理、打乱数据)
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)
  • DataLoader 作用
    • 批量加载数据(batch_size=32:每次加载 32 张图像)。
    • 打乱训练数据(shuffle=True:每个 epoch 重新打乱顺序)。
    • 支持多线程加载(默认参数,加速数据读取)。
  • 输出格式:每次迭代返回(images, labels),其中images形状为(32, 3, 256, 256)(批次大小 × 通道数 × 高 × 宽),labels形状为(32,)

5.选择计算设备

# 自动选择最优计算设备
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
  • 优先级:GPU (cuda) > 苹果芯片 GPU (mps) > CPU。
  • 作用:将模型和数据迁移到指定设备,加速计算(GPU 比 CPU 快 10~100 倍)

6.定义 CNN 模型结构

class CNN(nn.Module):  # 继承PyTorch的神经网络基类def __init__(self):super(CNN, self).__init__()  # 初始化父类# 第一个卷积块:卷积+激活+池化self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3,  # 输入通道数(RGB图像为3)out_channels=32,  # 输出通道数(卷积核数量)kernel_size=5,  # 卷积核大小(5×5)stride=1,  # 步长(每次滑动1像素)padding=2  # 填充(边缘补0,保持输出尺寸与输入一致)),nn.ReLU(),  # 激活函数(引入非线性)nn.MaxPool2d(2)  # 最大池化(2×2窗口,输出尺寸减半))# 第二个卷积块:卷积+激活+卷积+池化self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 输入32通道,输出64通道nn.ReLU(),nn.Conv2d(64, 64, 5, 1, 2),  # 输入64通道,输出64通道nn.MaxPool2d(2)  # 再次池化,尺寸减半)# 第三个卷积块:卷积+激活(无池化)self.conv3 = nn.Sequential(nn.Conv2d(64, 128, 5, 1, 2),  # 输入64通道,输出128通道nn.ReLU())# 全连接层:将特征映射到20个类别self.out = nn.Linear(128 * 64 * 64, 20)  # 输入维度=128通道×64×64特征图def forward(self, x):  # 前向传播(定义数据流向)x = self.conv1(x)  # 经过第一个卷积块:输出形状(32, 128, 128)x = self.conv2(x)  # 经过第二个卷积块:输出形状(64, 64, 64)x = self.conv3(x)  # 经过第三个卷积块:输出形状(128, 64, 64)x = x.view(x.size(0), -1)  # 展平特征图:(batch_size, 128×64×64)output = self.out(x)  # 全连接层输出:(batch_size, 20)return output
  • 模型结构解析
    • 卷积层:通过滑动窗口提取图像局部特征(如边缘、纹理)。
    • 池化层:降低特征图尺寸,减少参数数量(如 2×2 池化将尺寸减半)。
    • 全连接层:将卷积提取的特征映射到类别空间(20 个类别)。
  • 尺寸计算:输入 256×256 图像经过两次池化(每次减半)后,得到 64×64 特征图,最终展平为128×64×64=524,288维向量。

7.训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()  # 切换到训练模式(启用 dropout/batchnorm等训练特有的层)batch_size_num = 1  # 记录当前批次编号for X, y in dataloader:  # 迭代所有批次# 将数据迁移到计算设备X, y = X.to(device), y.to(device)# 1. 前向传播:计算预测值pred = model.forward(X)  # 等价于 model(X)# 2. 计算损失loss = loss_fn(pred, y)  # 交叉熵损失:比较预测值与真实标签# 3. 反向传播与参数更新optimizer.zero_grad()  # 清空历史梯度(避免累积)loss.backward()  # 计算梯度(反向传播)optimizer.step()  # 根据梯度更新参数(梯度下降)# 打印训练进度(每32个批次)loss = loss.item()  # 提取损失值(从张量转为Python数值)if batch_size_num % 32 == 0:print(f"loss: {loss:>7f} [批次: {batch_size_num}]")batch_size_num += 1
  • 核心流程:前向传播→计算损失→反向传播→更新参数(标准的深度学习训练循环)。
  • 细节说明
    • model.train():启用训练模式(例如 BatchNorm 层会计算均值和方差)。
    • optimizer.zero_grad():必须清空梯度,否则会累积上一轮的梯度。

8.测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集总样本数num_batches = len(dataloader)  # 批次数量model.eval()  # 切换到评估模式(关闭 dropout/batchnorm等)test_loss, correct = 0, 0  # 总损失和正确预测数with torch.no_grad():  # 禁用梯度计算(节省内存,加速计算)for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)  # 前向传播(无反向传播)# 累加损失和正确数test_loss += loss_fn(pred, y).item()# 计算正确预测数:pred.argmax(1)取概率最大的类别索引correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batches  # 平均损失correct /= size  # 准确率(正确数/总样本数)print(f"测试结果:\n 准确率: {(100*correct):>0.1f}%, 平均损失: {test_loss:>8f}")
  • 与训练的区别
    • model.eval():关闭训练特有的层(如 Dropout),确保预测稳定。
    • torch.no_grad():禁用梯度计算,减少内存占用和计算时间。
    • 无参数更新:仅计算损失和准确率,不调整模型参数。

9.训练与评估主流程

# 初始化损失函数、优化器和学习率调度器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失(适用于分类任务)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam优化器(学习率0.001)
# 学习率调度器:每10个epoch学习率乘以0.5(衰减)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)# 训练20个epoch
epochs = 20
acc_s = []  # 可用于记录准确率(此处未使用)
loss_s = []  # 可用于记录损失(此处未使用)
for t in range(epochs):print(f"第{t+1}轮训练\n-------------------")train(train_dataloader, model, loss_fn, optimizer)  # 训练一轮scheduler.step()  # 更新学习率(每轮训练后调用)
print("训练完成!")
test(test_dataloader, model, loss_fn)  # 最终测试
  • 关键组件
    • 损失函数CrossEntropyLoss适用于多分类任务,内置 SoftMax 激活。
    • 优化器:Adam 是常用的自适应学习率优化器,收敛速度快。
    • 学习率调度器:避免学习率过大导致不收敛或过小导致收敛缓慢,这里每 10 轮衰减一半。
  • 执行流程:重复 20 次「训练一轮 + 更新学习率」,最后在测试集上评估模型性能。

二、数据增强

        当然通过我们模型训练,发现训练的结果非常的低,损失也很高,这是为什么呢?分析后发现是因为我们的数据集非常的少,那有什么办法能对数据进行增加呢?我们就可以通过数据增强来使得训练的数据变多(但是数据集不会变多)

1.整体结构

data_transforms = {     # 字典结构,分别存储训练集和验证集的转换策略'train': transforms.Compose([...]),  # 训练集的数据增强和预处理'valid': transforms.Compose([...])   # 验证集的预处理
}
  • 使用transforms.Compose将多个预处理操作组合成一个管道,图像会按顺序依次经过这些操作
  • 训练集和验证集使用不同的预处理策略:训练集通常需要数据增强来提高泛化能力,验证集则只需要基本预处理

2.训练集转换 ('train')

尺寸调整

transforms.Resize([300, 300]),  # 将图像大小调整为300×300像素
  • 先将图像放大到比最终需要的尺寸更大,为后续的裁剪操作预留空间

随机旋转

transforms.RandomRotation(45),  # 在-45度到45度之间随机旋转图像
  • 数据增强手段:通过随机旋转增加样本多样性,使模型对图像旋转变化更鲁棒
  • 旋转角度范围是 [-45, 45] 度的随机值

中心裁剪

transforms.CenterCrop(256),  # 从图像中心裁剪出256×256像素的区域
  • 在旋转后进行中心裁剪,得到固定大小的图像
  • 与 Resize 配合使用:先放大再裁剪,避免旋转后图像边缘出现黑边

随机水平翻转

transforms.RandomHorizontalFlip(p=0.5),  # 以50%的概率随机水平翻转图像
  • 数据增强手段:模拟左右方向变化,例如 "猫" 的图像左右翻转后依然是 "猫"
  • p=0.5表示有一半的概率会执行翻转,另一半概率不翻转

随机垂直翻转

transforms.RandomVerticalFlip(p=0.5),  # 以50%的概率随机垂直翻转图像
  • 数据增强手段:模拟上下方向变化,适用于对垂直方向不敏感的场景
  • 注意:某些场景不适合垂直翻转(如人像),但食物类图像通常适用

转换为张量

transforms.ToTensor(),  # 将PIL图像或NumPy数组转换为PyTorch张量
    • 转换后的数据格式为(C, H, W)(通道数 × 高度 × 宽度)
    • 同时会将像素值从 [0, 255] 范围归一化到 [0, 1] 范围

标准化

transforms.Normalize([0.485, 0.456, 0.486], [0.229, 0.224, 0.225])  # 标准化处理
  • 对每个通道进行标准化:output = (input - mean) / std
  • 这里使用的均值和标准差是 ImageNet 数据集的统计值,是计算机视觉中常用的预处理参数
  • 标准化的作用:使不同图像的像素值分布更一致,加速模型收敛

3.验证集转换 ('valid')

尺寸调整

transforms.Resize([256, 256]),  # 直接将图像调整为256×256像素
  • 验证集不需要数据增强,直接调整到模型输入需要的尺寸

转换为张量

transforms.ToTensor(),  # 与训练集相同,转换为PyTorch张量

标准化

transforms.Normalize([0.485, 0.456, 0.486], [0.229, 0.224, 0.225])  # 使用与训练集相同的均值和标准差
  • 必须使用与训练集完全相同的标准化参数,否则会导致数据分布不一致,影响模型预测

4.训练集与验证集处理差异的原因

  • 训练集:使用多种数据增强技术(旋转、翻转等),目的是增加样本多样性,防止模型过拟合,提高模型的泛化能力
  • 验证集:只进行必要的预处理(尺寸调整、标准化),不使用数据增强,目的是真实反映模型在测试数据上的表现

完整代码

data_transforms={     #字典'train':transforms.Compose([ #组合transforms.Resize([300,300]),# 图像变换大小transforms.RandomRotation(45),#图片旋转,45度到-45度之间随机旋转transforms.CenterCrop(256),# 从中心开始裁剪transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转,设置一个概率transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转# transforms.RandomGrayscale(p=0.1),#概率换成灰度值transforms.ToTensor(), #数据转换成ToTensortransforms.Normalize([0.485,0.456,0.486],[0.229,0.224,0.225])#归一化,均值,标准差]),'valid':transforms.Compose([transforms.Resize([256, 256]),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.486],[0.229,0.224,0.225])]),}

将上面的data_transforms换成数据增强的就行,其他代码不变。

通过数据增强,增加训练次数我们的模型真去了会有所提升(但是还是比较低,这只是一种提升训练数据量的一种方法,主要的还是要增加数据集)

http://www.dtcms.com/a/362872.html

相关文章:

  • SQLShift 实现Oracle 到 OceanBase 的存储过程转换初体验
  • FlowGPT-GPT提示词分享平台
  • 深入剖析Java设计模式之策略模式:从理论到实战
  • 【音视频】 WebRTC GCC 拥塞控制算法
  • 从Java全栈到前端框架:一场真实的技术面试实录
  • Leetcode二分查找(5)
  • 【算法】哈希表专题
  • 单元测试总结2
  • 【大前端】Vue 和 React 主要区别
  • dy图文批量下载
  • 【C++】模板(初阶)--- 初步认识模板
  • 从一行 var a = 1 开始,深入理解 V8 引擎的心脏
  • 【Linux我做主】进程退出和终止详解
  • 掌握设计模式--模板方法模式
  • 前缀树约束大语言模型解码
  • Ollama:本地大语言模型部署和使用详解
  • 【论文阅读】DeepSeek-LV2:用于高级多模态理解的专家混合视觉语言模型
  • ObjectMapper一个对象转json串为啥设计成注入?...
  • 【学Python自动化】 7. Python 输入与输出学习笔记
  • Pandas Python数据处理库:高效处理Excel/CSV数据,支持分组统计与Matplotlib可视化联动
  • 车载刷写架构 --- ECU软件更新怎么保证数据的正确性?
  • Ansible 循环、过滤器与判断逻辑
  • 【保姆级喂饭教程】把chrome谷歌浏览器中的插件导出为CRX安装包
  • Android init 实战项目
  • 文件页的预取逻辑
  • IAM(Identity and Access Management)
  • windows中使用cmd/powershell查杀进程
  • k8s的CRD自定义资源类型示例
  • 从全球视角到K8s落地的Apache IoTDB实战
  • 2025年新版C语言 模电数电及51单片机Proteus嵌入式开发入门实战系统学习,一整套全齐了再也不用东拼西凑