当前位置: 首页 > news >正文

深度学习:自定义数据集处理、数据增强与最优模型管理

目录

一、整体流程概述

二、自定义数据集处理

1. 数据集结构设计

2. 数据集索引文件生成

三、自定义数据集加载类

四、数据增强策略

五、卷积神经网络模型构建

六、模型训练与最优模型管理

1. 训练与测试函数

2. 训练主流程

七、模型加载与推理


一、整体流程概述

深度学习项目通常遵循 "数据→模型→训练→部署" 的闭环流程,本文以食物识别任务为例,详解从自定义数据集处理到模型训练、优化及调用的完整流程。核心环节包括:

  1. 自定义数据集构建与预处理
  2. 数据增强策略设计
  3. 卷积神经网络 (CNN) 模型构建
  4. 模型训练与最优模型保存
  5. 训练后模型调用与推理

二、自定义数据集处理

1. 数据集结构设计

通常采用 "类别目录 + 图片文件" 的组织结构:

food_dataset/
├── train/
│   ├── 苹果/
│   │   ├── apple1.jpg
│   │   └── apple2.jpg
│   └── 香蕉/
│       └── ...
└── test/└── ...

2. 数据集索引文件生成

需要将图片路径与标签映射关系保存为文本文件(如train.txttest.txt),方便模型加载。

核心代码(食物识别 1 - 预处理.py)

import osdef train_test_file(root, dir):file_txt = open(dir + '.txt', 'w+', encoding='utf-8')file_canzhao = open('canzhao.txt', 'w+', encoding='utf-8')path = os.path.join(root, dir)dirs = []  # 存储类别目录名recorded_classes = set()  # 去重集合for roots, directories, files in os.walk(path):# 获取一级子目录作为类别if not dirs and roots == path:dirs = directoriesif files:  # 处理图片文件current_class = roots.split(os.sep)[-1]  # 提取类别名class_index = dirs.index(current_class)  # 分配类别索引# 写入图片路径和标签for file in files:img_path = os.path.join(roots, file)file_txt.write(f"{img_path} {class_index}\n")# 写入类别-索引映射(去重)if current_class not in recorded_classes:file_canzhao.write(f"{current_class} {class_index}\n")recorded_classes.add(current_class)file_txt.close()file_canzhao.close()# 生成训练集和测试集索引文件
root = r'..\\food_dataset'
train_test_file(root, 'train')
train_test_file(root, 'test')
..\\food_dataset\train\八宝粥\img_八宝粥罐_22.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_29.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_65.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_68.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_84.jpeg 0
..\\food_dataset\train\八宝粥\img_八宝粥罐_98.jpeg 0
..\\food_dataset\train\哈密瓜\img_水果_103.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_13.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_136.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_142.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_163.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_174.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_191.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_209.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_238.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_30.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_34.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_42.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_44.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_57.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_81.jpeg 1
..\\food_dataset\train\哈密瓜\img_水果_92.jpeg 1
..\\food_dataset\train\圣女果\img_圣女果_104.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_105.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_116.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_12.jpeg 2
..\\food_dataset\train\圣女果\img_圣女果_13.jpeg 2
略

知识点解析

  • os.walk():递归遍历目录结构,获取所有文件路径
  • os.sep:适配不同操作系统的路径分隔符(Windows 用\,Linux 用/
  • 类别索引映射:通过dirs.index(current_class)建立类别与数字索引的映射,便于模型处理

三、自定义数据集加载类

使用 PyTorch 的Dataset类封装数据集,实现数据的按需加载。

核心代码

from torch.utils.data import Dataset
from PIL import Image
import torch
import numpy as npclass food_dataset(Dataset):def __init__(self, file_path, transform=None):self.imgs = []self.labels = []self.transform = transform# 从索引文件读取数据with open(file_path, 'r', encoding='utf-8') as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)  # 返回样本总数def __getitem__(self, idx):# 读取图片并应用转换image = Image.open(self.imgs[idx]).convert('RGB')  # 确保RGB格式if self.transform:image = self.transform(image)# 标签转换为Tensorlabel = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))return image, label

知识点解析

  • Dataset抽象类:必须实现__len__(返回样本数)和__getitem__(获取单个样本)方法
  • 延迟加载:仅在需要时才读取图片,节省内存
  • 数据转换接口:通过transform参数灵活接入数据增强 pipeline

四、数据增强策略

数据增强是提升模型泛化能力的关键手段,通过对训练数据进行随机变换,增加数据多样性。

核心代码(食物识别 1 - 数据增强.py)

from torchvision import transformsdata_transforms = {'train': transforms.Compose([transforms.Resize((300, 300)),  # 缩放至300x300transforms.RandomRotation(45),  # 随机旋转(-45°~45°)transforms.CenterCrop(256),  # 中心裁剪至256x256transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转transforms.RandomVerticalFlip(p=0.5),  # 50%概率垂直翻转transforms.ColorJitter(  # 颜色抖动brightness=0.2,  # 亮度contrast=0.1,    # 对比度saturation=0.1,  # 饱和度hue=0.1          # 色调),transforms.RandomGrayscale(p=0.1),  # 10%概率转为灰度图transforms.ToTensor(),  # 转为Tensor(0-1范围)transforms.Normalize(  # 标准化[0.485, 0.456, 0.406],  # 均值[0.229, 0.224, 0.225]   # 标准差)]),'valid': transforms.Compose([  # 验证集仅做必要转换transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}

知识点解析

  • 训练集增强原则:保留语义信息的同时增加多样性(旋转、翻转、颜色变化等)
  • 验证集处理:仅做必要的尺寸调整和标准化,保证评估一致性
  • 标准化(Normalize):使用 ImageNet 数据集的均值和标准差,使输入分布更稳定
  • Compose:将多个变换组合成一个 pipeline,按顺序执行

五、卷积神经网络模型构建

设计适用于食物识别的 CNN 模型,通过卷积层提取视觉特征,全连接层完成分类。

核心代码

import torch
from torch import nnclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积块1:特征提取+降维self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3,  # 输入通道数(RGB)out_channels=16,  # 卷积核数量kernel_size=5,  # 卷积核大小5x5stride=1,  # 步长padding=2  # 填充,保持尺寸),nn.ReLU(),  # 激活函数nn.MaxPool2d(kernel_size=2)  # 2x2池化,尺寸减半)# 卷积块2:加深特征提取self.conv2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),nn.ReLU())# 卷积块3:进一步提取+降维self.conv3 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 卷积块4:高级特征提取self.conv4 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU())# 全连接层:分类输出self.line = nn.Linear(128 * 64 * 64, 20)  # 20类食物def forward(self, x):# 前向传播路径x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = self.conv4(x)x = x.view(x.size(0), -1)  # 展平特征图out = self.line(x)return out

知识点解析

  • 卷积层作用:通过滑动窗口提取局部特征(边缘、纹理、形状等)
  • 池化层作用:降低特征图尺寸,减少参数数量,增强平移不变性
  • 激活函数(ReLU):引入非线性,使模型能拟合复杂特征关系
  • 全连接层:将卷积提取的高维特征映射到类别空间

六、模型训练与最优模型管理

1. 训练与测试函数

# 训练函数
def train(dataloader, model, loss_fn, optimizer):model.train()  # 训练模式(启用dropout等)batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)  # 数据移至计算设备pred = model(X)  # 前向传播loss = loss_fn(pred, y)  # 计算损失# 反向传播与参数更新optimizer.zero_grad()  # 清空梯度loss.backward()  # 计算梯度optimizer.step()  # 更新参数# 打印训练进度if batch_size_num % 100 == 1:print(f"Loss: {loss.item():.4f} [batch:{batch_size_num}]")batch_size_num += 1# 测试函数(含最优模型保存)
best_acc = 0
def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)correct = 0loss_sum = 0model.eval()  # 评估模式(关闭dropout等)with torch.no_grad():  # 关闭梯度计算for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)loss_sum += loss.item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()acc = correct / sizelatest_loss = loss_sum / num_batchesprint(f"Accuracy: {(acc * 100)}%, Loss: {latest_loss:.4f}")# 保存最优模型if acc > best_acc:best_acc = acctorch.save(model, 'best1.pt')  # 保存完整模型print(f"保存最优模型,准确率:{best_acc*100:.2f}%")

2. 训练主流程

# 设备选择
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f'Using device: {device}')# 加载数据集
train_data = food_dataset('train.txt', data_transforms['train'])
test_data = food_dataset('test.txt', data_transforms['valid'])
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()  # 多分类损失
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)  # Adam优化器# 执行训练
epochs = 100
print('训练开始')
for epoch in range(epochs):print(f"Epoch {epoch + 1}/{epochs}")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print(f'训练结束, 最佳准确率:{best_acc*100:.2f}%')

知识点解析

  • 设备加速:自动选择 GPU(CUDA/MPS)或 CPU 进行计算
  • 训练模式与评估模式:model.train()model.eval()控制 dropout、BN 层等行为
  • 梯度管理:optimizer.zero_grad()避免梯度累积,with torch.no_grad()节省评估时内存
  • 最优模型保存:通过跟踪验证集准确率,只保存表现最好的模型,避免过拟合模型

七、模型加载与推理

训练完成后,加载最优模型进行实际预测。

核心代码(食物识别 1 - 调用最优模型.py)

# 加载模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.load('best1.pt', map_location=device)  # 加载完整模型
model.eval()  # 切换到评估模式# 加载类别映射
def load_class_mapping(canzhao_path='canzhao.txt'):index_to_name = {}with open(canzhao_path, 'r', encoding='utf-8') as f:for line in f.readlines():food_name, index = line.strip().split(' ')index_to_name[int(index)] = food_namereturn index_to_name# 单张图片预测
def predict_image(image_path, model, index_to_name):try:# 图片预处理(与验证集一致)transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])image = Image.open(image_path).convert('RGB')image = transform(image).unsqueeze(0)  # 增加批次维度image = image.to(device)# 推理with torch.no_grad():pred = model(image)pred_index = pred.argmax(1).item()  # 获取最高概率类别return index_to_name.get(pred_index, "未知类别")except Exception as e:return f"预测失败:{str(e)}"# 交互预测
index_to_name = load_class_mapping()
while True:img_path = input("请输入图片路径(q退出):")if img_path.lower() == 'q':breakprint("预测结果:", predict_image(img_path, model, index_to_name))

知识点解析

  • 模型加载:torch.load()加载保存的模型,map_location适配不同计算设备
  • 推理预处理:必须与训练时的验证集处理完全一致,否则会导致分布不匹配
  • 批次维度:模型输入需为(batch_size, channel, height, width),通过unsqueeze(0)添加批次维度
  • 类别映射:将模型输出的数字索引转换为实际类别名称

文章转载自:

http://xeOffRKJ.kbgzj.cn
http://1u2HL8ku.kbgzj.cn
http://oca8EwWG.kbgzj.cn
http://5uL3WNfy.kbgzj.cn
http://EMf2UrMF.kbgzj.cn
http://XBRzDL6a.kbgzj.cn
http://rKc7WjWN.kbgzj.cn
http://3oSiptI2.kbgzj.cn
http://OCMakmny.kbgzj.cn
http://YLreX827.kbgzj.cn
http://JwVjYxRQ.kbgzj.cn
http://PlTc5cUQ.kbgzj.cn
http://eVz5AO6k.kbgzj.cn
http://YxCBTHAq.kbgzj.cn
http://S1GGVQ3A.kbgzj.cn
http://tglGTKrX.kbgzj.cn
http://ExjNKYgI.kbgzj.cn
http://3iZ3pPVI.kbgzj.cn
http://i3ZL27rV.kbgzj.cn
http://beBPXR8z.kbgzj.cn
http://yfCqGfH6.kbgzj.cn
http://65vS77dQ.kbgzj.cn
http://GQdSEGt5.kbgzj.cn
http://FBdriD81.kbgzj.cn
http://WNFwkD80.kbgzj.cn
http://xVV2xAHD.kbgzj.cn
http://By7JqAxC.kbgzj.cn
http://2nsbVKwW.kbgzj.cn
http://qP0uMtcw.kbgzj.cn
http://TXbqerU1.kbgzj.cn
http://www.dtcms.com/a/367763.html

相关文章:

  • ASRPRO语音模块
  • 一个开源的企业官网简介
  • Linux的权限详解
  • 【ICCV 2025 顶会论文】,新突破!卷积化自注意力 ConvAttn 模块,即插即用,显著降低计算量和内存开销。
  • HTB Jerry
  • 微信支付--在线支付实战,引入Swagger,定义统一结果,创建并连接数据库
  • 为什么串口发送一串数据时需要延时?
  • 决策树算法详解:从原理到实战
  • 生成式AI优化新纪元:国产首个GEO工具的技术架构剖析
  • 2025年高教社杯全国大学生数学建模竞赛B题思路(2025数学建模国赛B题思路)
  • 【C语言】第一课 环境配置
  • git命令行打patch
  • day2today3夏暮客的Python之路
  • 随时学英语5 逛生活超市
  • Web相关知识(草稿)
  • 计算机组成原理:GPU架构、并行计算、内存层次结构等
  • 用服务器搭 “私人 AI 助手”:不用联网也能用,支持语音对话 / 文档总结(教程)
  • 学生时间管理系统设计与实现(代码+数据库+LW)
  • 【3D 入门-6】大白话解释 SDF(Signed Distance Field) 和 Marching Cube 算法
  • 并发编程——17 CPU缓存架构详解高性能内存队列Disruptor实战
  • Pycharm终端pip install的包都在C:\Users\\AppData\Roaming\Python\解决办法
  • Linux中用于线程/进程同步的核心函数——`sem_wait`函数
  • Day2p2 夏暮客的Python之路
  • C++虚函数虚析构函数纯虚函数的使用说明和理解
  • Process Explorer 学习笔记(第三章3.1.1):度量 CPU 的使用情况详解
  • 机器学习入门,第一个MCP示例
  • Spring Boot项目中MySQL索引失效的常见场景与解决方案
  • 2025 年高教社杯全国大学生数学建模竞赛C 题 NIPT 的时点选择与胎儿的异常判定 完整成品思路模型代码分享,全网首发高质量!!!
  • 代码随想录学习摘抄day6(二叉树1-11)
  • 吴恩达机器学习(五)