(补)CNN 模型搭建与训练:PyTorch 实战 CIFAR10 任务的应用
一、代码核心定位:承接训练,实现单图预测
前文CNN 模型搭建与训练:PyTorch 实战 CIFAR10 任务-CSDN博客
已完成 CIFAR10 模型的三大核心步骤:
- 定义了
Prayer卷积神经网络结构(model.py); - 完成了 10 轮训练,得到了
prayer_0.pth到prayer_9.pth等训练好的模型文件; - 验证了模型在测试集上的正确率最终达到约 55.6%。
而当前代码的核心目标是:
用训练好的模型(如prayer_29.pth),对一张自定义的图像(如dog.png)进行类别预测,
把 “离线训练的模型” 转化为 “可实时预测的工具”。
二、代码逐段详解:从图像到预测结果的全流程
1. 前置准备:导入库与定义模型
这部分是模型推理的 “基础保障”,确保代码能调用 PyTorch 工具和匹配训练时的模型结构。
import torch # PyTorch核心库,负责张量运算和模型推理
import torchvision # 提供图像预处理工具
from PIL import Image # 读取和处理图像的经典库
from torch import nn # 神经网络模块,用于定义模型结构
- 模型类
Prayer的重复定义:这里重新定义了与model.py完全一致的Prayer类,原因是torch.load加载完整模型时,需要当前环境中有对应的模型类定义(否则无法解析模型结构)。 - 核心是保证推理时的模型结构与训练时完全一致,从输入通道(3)、卷积 / 池化层级,到全连接层维度(最终输出 10 类),均和训练阶段完全匹配。
2. 图像预处理:让输入符合模型要求
CIFAR10 训练时,图像是 “32×32 像素的 RGB 彩色图 + 张量格式”,因此自定义图像必须经过相同预处理,否则模型无法识别。

-
步骤 1:读取图像
image = Image.open(image_path).convert('RGB'):- 用
PIL.Image读取图像文件; convert('RGB')强制转为 3 通道彩色图,避免灰度图(1 通道)或透明图(4 通道)导致通道数不匹配。
- 用
-
步骤 2:标准化预处理
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)), # 缩放到32×32,匹配训练输入尺寸torchvision.transforms.ToTensor() # 转为Tensor:像素值从[0,255]→[0,1],维度从(HWC)→(CHW)
])
image = transform(image) # 处理后形状:(3, 32, 32)
-
这一步是关键匹配项:如果图像尺寸、格式与训练数据不一致,模型会因输入维度错误直接报错。
3. 调整输入形状:适配模型的批量推理逻辑
- 模型训练时处理的是 “批量数据”(如
batch_size=64,输入形状为(64, 3, 32, 32)),即使推理单张图像,也需要调整为 “批量维度为 1” 的格式。 image = torch.reshape(image, (1, 3, 32, 32)):将(3, 32, 32)转为(1, 3, 32, 32),其中1代表 “当前批量只有 1 张图”。
模型在训练时已经 “习惯了” 接收4 个维度的输入(批量大小 + 通道 + 高 + 宽)。就像你去自动售货机买水,机器的投币口只接受 “竖着插卡”,如果你横着插,即使卡是对的,机器也不认 —— 模型也有这样的 “输入格式洁癖”。
比如:
- 模型的第一层是卷积层
nn.Conv2d(in_channels=3, ...),它要求输入必须是 4 维张量(批量大小 ×3×32×32); - 如果你直接输入单张图的 3 维张量
(3, 32, 32),模型会 “困惑”:“第一个维度应该是批量大小,怎么没有了?” 然后直接报错。
4. 模型加载与推理:核心预测环节
这部分是连接 “训练成果” 与 “预测结果” 的桥梁。
-
加载训练好的模型
-
model = torch.load("prayer_29.pth", map_location='cpu', weights_only=False):prayer_29.pth:训练保存的模型文件(前文训练 10 轮,此处文件名可能为示例,实际对应某一轮训练结果);map_location='cpu':指定在 CPU 上推理(无需 GPU 也能运行,兼容更多环境);weights_only=False:允许加载 “完整模型”(包含结构 + 权重),适配前文torch.save(prayer, ...)的保存方式。
-
切换模型为评估模式
model.eval():将模型从 “训练模式” 切换为 “评估模式”,关闭 Dropout(此处模型未用,但为通用规范)、固定 BatchNorm 等层的参数,确保推理结果稳定。 -
无梯度推理
with torch.no_grad():output = model(image) # 模型输出:(1, 10)的张量 with torch.no_grad():关闭梯度计算,减少内存占用、加快推理速度(推理阶段无需更新参数,梯度无用);output形状为(1, 10):对应 1 个样本、10 个类别的 “预测分数”(非概率,数值越大代表模型认为属于该类的可能性越高)。
5. 输出预测结果:解读模型输出
- 打印预测分数:
print(output)输出 10 个类别的原始分数, - 例如某类分数为
2.5,另一类为-1.2,分数越高概率越大。 - 打印预测类别索引:
print(output.argmax(1)):argmax(1):在 “类别维度”(第 1 维,对应 10 个类别)上取最大值的索引,结果为0-9中的一个;- 该索引对应 CIFAR10 的类别(如
0=飞机、1=汽车、3=猫、5=狗等,需对照 CIFAR10 类别表解读)。
CIFAR10 类别索引 - 名称映射表
| 类别索引 | 对应类别名称 | 英文名称 |
|---|---|---|
| 0 | 飞机 | airplane |
| 1 | 汽车 | automobile |
| 2 | 鸟 | bird |
| 3 | 猫 | cat |
| 4 | 鹿 | deer |
| 5 | 狗 | dog |
| 6 | 青蛙 | frog |
| 7 | 马 | horse |
| 8 | 船 | ship |
| 9 | 卡车 | truck |

