当前位置: 首页 > news >正文

(补)CNN 模型搭建与训练:PyTorch 实战 CIFAR10 任务的应用

一、代码核心定位:承接训练,实现单图预测

前文CNN 模型搭建与训练:PyTorch 实战 CIFAR10 任务-CSDN博客

已完成 CIFAR10 模型的三大核心步骤:

  1. 定义了Prayer卷积神经网络结构(model.py);
  2. 完成了 10 轮训练,得到了prayer_0.pthprayer_9.pth等训练好的模型文件;
  3. 验证了模型在测试集上的正确率最终达到约 55.6%。

而当前代码的核心目标是:

用训练好的模型(如prayer_29.pth),对一张自定义的图像(如dog.png)进行类别预测

把 “离线训练的模型” 转化为 “可实时预测的工具”。

二、代码逐段详解:从图像到预测结果的全流程

1. 前置准备:导入库与定义模型

这部分是模型推理的 “基础保障”,确保代码能调用 PyTorch 工具和匹配训练时的模型结构。

import torch                # PyTorch核心库,负责张量运算和模型推理
import torchvision          # 提供图像预处理工具
from PIL import Image       # 读取和处理图像的经典库
from torch import nn        # 神经网络模块,用于定义模型结构
  • 模型类Prayer的重复定义:这里重新定义了与model.py完全一致的Prayer类,原因是torch.load加载完整模型时,需要当前环境中有对应的模型类定义(否则无法解析模型结构)。
  • 核心是保证推理时的模型结构与训练时完全一致,从输入通道(3)、卷积 / 池化层级,到全连接层维度(最终输出 10 类),均和训练阶段完全匹配。

2. 图像预处理:让输入符合模型要求

CIFAR10 训练时,图像是 “32×32 像素的 RGB 彩色图 + 张量格式”,因此自定义图像必须经过相同预处理,否则模型无法识别。

  • 步骤 1:读取图像image = Image.open(image_path).convert('RGB')

    • PIL.Image读取图像文件;
    • convert('RGB')强制转为 3 通道彩色图,避免灰度图(1 通道)或透明图(4 通道)导致通道数不匹配。
  • 步骤 2:标准化预处理

transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),  # 缩放到32×32,匹配训练输入尺寸torchvision.transforms.ToTensor()         # 转为Tensor:像素值从[0,255]→[0,1],维度从(HWC)→(CHW)
])
image = transform(image)  # 处理后形状:(3, 32, 32)
  • 这一步是关键匹配项:如果图像尺寸、格式与训练数据不一致,模型会因输入维度错误直接报错。


3. 调整输入形状:适配模型的批量推理逻辑

  • 模型训练时处理的是 “批量数据”(如batch_size=64,输入形状为(64, 3, 32, 32)),即使推理单张图像,也需要调整为 “批量维度为 1” 的格式。
  • image = torch.reshape(image, (1, 3, 32, 32)):将(3, 32, 32)转为(1, 3, 32, 32),其中1代表 “当前批量只有 1 张图”。

模型在训练时已经 “习惯了” 接收4 个维度的输入(批量大小 + 通道 + 高 + 宽)。就像你去自动售货机买水,机器的投币口只接受 “竖着插卡”,如果你横着插,即使卡是对的,机器也不认 —— 模型也有这样的 “输入格式洁癖”。

比如:

  • 模型的第一层是卷积层 nn.Conv2d(in_channels=3, ...),它要求输入必须是 4 维张量(批量大小 ×3×32×32);
  • 如果你直接输入单张图的 3 维张量 (3, 32, 32),模型会 “困惑”:“第一个维度应该是批量大小,怎么没有了?” 然后直接报错。

4. 模型加载与推理:核心预测环节

这部分是连接 “训练成果” 与 “预测结果” 的桥梁。

  • 加载训练好的模型

  • model = torch.load("prayer_29.pth", map_location='cpu', weights_only=False)

    • prayer_29.pth:训练保存的模型文件(前文训练 10 轮,此处文件名可能为示例,实际对应某一轮训练结果);
    • map_location='cpu':指定在 CPU 上推理(无需 GPU 也能运行,兼容更多环境);
    • weights_only=False:允许加载 “完整模型”(包含结构 + 权重),适配前文torch.save(prayer, ...)的保存方式。
  • 切换模型为评估模式model.eval():将模型从 “训练模式” 切换为 “评估模式”,关闭 Dropout(此处模型未用,但为通用规范)、固定 BatchNorm 等层的参数,确保推理结果稳定。

  • 无梯度推理

    with torch.no_grad():output = model(image)  # 模型输出:(1, 10)的张量
  • with torch.no_grad():关闭梯度计算,减少内存占用、加快推理速度(推理阶段无需更新参数,梯度无用);
  • output形状为(1, 10):对应 1 个样本、10 个类别的 “预测分数”(非概率,数值越大代表模型认为属于该类的可能性越高)。

5. 输出预测结果:解读模型输出

  • 打印预测分数print(output)输出 10 个类别的原始分数,
  • 例如某类分数为2.5,另一类为-1.2,分数越高概率越大。
  • 打印预测类别索引print(output.argmax(1))
    • argmax(1):在 “类别维度”(第 1 维,对应 10 个类别)上取最大值的索引,结果为0-9中的一个;
    • 该索引对应 CIFAR10 的类别(如0=飞机1=汽车3=猫5=狗等,需对照 CIFAR10 类别表解读)。

CIFAR10 类别索引 - 名称映射表

类别索引对应类别名称英文名称
0飞机airplane
1汽车automobile
2bird
3cat
4鹿deer
5dog
6青蛙frog
7horse
8ship
9卡车truck

http://www.dtcms.com/a/519475.html

相关文章:

  • spring篇:一文读懂spring:工作原理之核心技术解析
  • docker 原理
  • 龙岩网站开发较好的公司王战山
  • vllm论文中figure3每个块的区别
  • 西安营销网站建设公司厦门建设局官网
  • 机器视觉的锂电池叠片应用
  • Rhino(犀牛)转换为 3DXML 全指南:迪威模型网在线实操 + 本地方案
  • react报错Cannot find module ‘ajv/dist/compile/codegen‘
  • uv如何配置阿里云源在 pyproject.toml 中 或在 uv.toml 中
  • 【算法】排序算法汇总1
  • 学习笔记 | 图论基础
  • 苏州要服务网站建设视频网站建设多少钱
  • Flink 使用 RocksDB 作为状态后端存储的原因详解
  • 历经一载编程路,褪去青涩踏新程
  • 面试随想录4:吉贝克后端
  • 使用Python操作你的手机(Appium入门)
  • Spire.Doc 实践指南:将Word 文档转换为 XML
  • 【2B篇】阿里通义 Qwen3-VL 新增 2B、32B 两个模型尺寸,手机也能轻松运行
  • 目标检测YOLO实战应用案例100讲-基于多模态和多模型融合 的三维目标检测
  • 【成长纪实】从“Hello World”到分布式实战的进阶之路
  • 图论理论基础(1)
  • 开源 Linux 服务器与中间件(十)Mqtt协议和Emqx服务器安装测试
  • 网站建设实践鉴定手机网站建设讯息
  • 网站管理文档怎么写晚上睡不着看点害羞的东西app
  • uni-app 广告弹窗最佳实践:不扰民、可控制频次、含完整源码
  • 使用eNSP模拟器搭建网络拓扑结构(笔记2):从 0 到 1 掌握华为网络仿真
  • UniApp 多页面编译优化:编译时间从10分钟到1分钟
  • C++变量与函数命名规范技术指南 (基于华为编码规范与现代C++最佳实践)
  • ELK1——elasticsearch
  • 【图像卷积基础】卷积过程卷积实现通道扩充与压缩池化Pooling原理和可视化