当前位置: 首页 > news >正文

深度学习——基于卷积神经网络实现食物图像分类之(保存最优模型)

引言

本文将详细介绍如何使用PyTorch框架构建一个完整的食物图像分类系统,包含数据预处理、模型构建、训练优化以及模型保存等关键环节。与上一篇博客介绍的版本相比,本版本增加了模型保存与加载功能,并优化了测试评估流程。

一、项目概述

本项目的目标是构建一个能够识别20种不同食物的图像分类系统。主要技术特点包括:

  1. 简化但高效的数据预处理流程
  2. 三层CNN网络架构设计
  3. 训练过程中自动保存最佳模型
  4. 完整的训练-评估流程实现

二、环境配置

首先确保已安装必要的Python库:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

三、数据预处理

3.1 数据转换设置

我们为训练集和验证集定义了不同的转换策略:

data_transforms = {'train': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),
}

简化说明

3.2 数据集准备
def train_test_file(root, dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots, directories, files in os.walk(path):if len(directories) != 0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()

该函数会生成包含图像路径和标签的文本文件,格式为:

path/to/image1.jpg 0
path/to/image2.jpg 1
...

四、自定义数据集类

我们继承PyTorch的Dataset类实现自定义数据集:

class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

关键改进

五、CNN模型架构

我们设计了一个三层CNN网络:

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)return self.out(x)

架构特点

  1. 每层包含卷积、ReLU激活和最大池化
  2. 使用padding保持特征图尺寸
  3. 最后通过全连接层输出分类结果

六、训练与评估流程

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()if batch_size_num % 1 == 0:print(f"loss: {loss.item():>7f} [batch:{batch_size_num}]")batch_size_num += 1
6.2 评估与模型保存
best_acc = 0def Test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= size# 保存最佳模型if correct > best_acc:best_acc = correcttorch.save(model.state_dict(), "best_model.pth")print(f"\n测试结果: \n 准确率:{(100*correct):.2f}%, 平均损失:{test_loss:.4f}")

关键改进

  1. 增加全局变量best_acc跟踪最佳准确率
  2. 实现两种模型保存方式:(1)只保存模型参数(state_dict)(2)保存整个模型
  3. 更详细的测试结果输出

七、完整训练流程

# 初始化
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练循环
epochs = 10
for t in range(epochs):print(f"Epoch {t+1}\n{'-'*20}")train(train_dataloader, model, loss_fn, optimizer)# 最终评估
Test(test_dataloader, model, loss_fn)

八、模型保存与加载

8.1 保存模型
# 方法1:只保存参数
torch.save(model.state_dict(), "model_params.pth")# 方法2:保存完整模型
torch.save(model, "full_model.pt")
8.2 加载模型
# 方法1对应加载方式
model = CNN().to(device)
model.load_state_dict(torch.load("model_params.pth"))# 方法2对应加载方式
model = torch.load("full_model.pt").to(device)

九、优化建议

  1. 数据增强:添加更多变换提高模型泛化能力
  2. 学习率调度:使用torch.optim.lr_scheduler动态调整学习率

http://www.dtcms.com/a/364213.html

相关文章:

  • leetcode-每日一题-人员站位的方案数-C语言
  • 基于飞算JavaAI的在线图书借阅平台设计与实现
  • 基于单片机雏鸡孵化恒温系统/孵化环境检测系统设计
  • GPIO的8种工作方式
  • 安装wsl报错0x800701bc
  • OCR识别在媒资管理系统的应用场景剖析与选择
  • 今天我们继续学习shell编程语言的内容
  • 数据结构之单链表的应用(一)
  • 【游戏开发】街景风格化运用到游戏中,一般有哪些风格可供选择?
  • ThreadLocal深度解析:线程本地存储的奥秘
  • 【模型学习】LoRA的原理,及deepseek-vl2下LoRA实现
  • 【渗透测试】使用 UV 简化 Python 工具和脚本管理
  • TypeScript:unknown 类型
  • 博维智航(彭州)——面试
  • C++高频误区:vector对象到底在堆上还是栈上?
  • flume扩展实战:自定义拦截器、Source 与 Sink 全指南
  • 博主必备神器~
  • 解锁复杂工作流:Roo Code 中的「Boomerang Tasks」机制 : Orchestrator Mode 的使用
  • 用好AI,从提示词工程到上下文工程
  • ARM - GPIO 标准库开发
  • 算法模板(Java版)_非负整数的高精度运算
  • Linux之Shell编程(五)命令工具与sed编辑
  • Java代码耗时统计的5种方法
  • 将 .vcproj 文件转换为 .pro 文件
  • Apache Doris:重塑湖仓一体架构的高效计算引擎
  • 常见机械机构的图graph表示
  • 【硬件测试】基于FPGA的16PSK+卷积编码Viterbi译码硬件片内测试,包含帧同步,信道,误码统计,可设置SNR
  • 新手也能懂的 MySQL 大表优化:40 字段表的规划思路 + 头表行表应用详解
  • Java8特性
  • MyBatis-Plus 实现用户分页查询(支持复杂条件)