当前位置: 首页 > news >正文

用 PyTorch 搭建 CIFAR10 线性分类器:从数据加载到模型推理全流程解析

在深度学习入门过程中,图像分类是最经典的任务之一,而 CIFAR10 数据集则是入门图像分类的 "练手神器"。

一、前置知识:CIFAR10 数据集是什么?

1.1 CIFAR10 核心参数

CIFAR10(Canadian Institute for Advanced Research 10)是由加拿大高级研究所发布的图像分类数据集,包含10 个类别的彩色图像,具体信息如下:

类别包含内容样本数量(训练集)样本数量(测试集)
0飞机(airplane)50001000
1汽车(automobile)50001000
2鸟类(bird)50001000
3猫(cat)50001000
4鹿(deer)50001000
5狗(dog)50001000
6青蛙(frog)50001000
7马(horse)50001000
8船(ship)50001000
9卡车(truck)50001000

1.2 图像尺寸与格式

CIFAR10 的每张图像都是3 通道彩色图像(RGB),尺寸固定为32×32 像素(即高度 32、宽度 32)。

  • 通道数(C):3(R 红、G 绿、B 蓝)
  • 高度(H):32
  • 宽度(W):32
  • 单张图像展平后特征数:3×32×32 = 3072(这是后续线性层输入维度的关键依据)

二、代码拆解:从导入库到数据加载

代码的第一部分是 "数据准备",核心是将 CIFAR10 数据集加载到 PyTorch 中,并按批次处理。我们逐行解析:

2.1 导入必要库

import torch          # PyTorch核心库(张量操作、自动求导等)
import torchvision    # PyTorch视觉库(数据集、图像变换、预训练模型等)
from torch import nn  # PyTorch神经网络模块(含线性层、卷积层等)
from torch.utils.data import DataLoader  # 数据加载器(按批次加载数据)

这是 PyTorch 视觉任务的 "标准开头",每个库的作用必须明确:

  • torch:所有操作的基础,比如张量(Tensor)的创建和计算。
  • torchvision:专门为计算机视觉设计,提供了 CIFAR10 等常用数据集,以及图像预处理工具。
  • torch.nn:搭建神经网络的核心,比如nn.Linear(线性层)、nn.Conv2d(卷积层)都在这里。
  • DataLoader:将数据集按批次分割,支持多线程加载,是训练时高效喂数据的关键。

2.2 加载 CIFAR10 测试集

dataset = torchvision.datasets.CIFAR10(root='./data',          # 数据集保存路径(当前目录下的data文件夹)train=False,            # 是否为训练集:False表示加载测试集,True表示加载训练集download=True,          # 如果root路径下没有数据集,是否自动下载transform=torchvision.transforms.ToTensor()  # 图像变换:将PIL图像转为Tensor
)

(1)train=False的意义

  • train=True时,加载的是50000 张图像的训练集(用于模型训练);
  • train=False时,加载的是10000 张图像的测试集(用于验证模型性能);
  • 我们这里用测试集做演示,后续实际训练时需要切换为train=True

(2)transform=ToTensor()的作用

图像在计算机中原始存储格式是PIL 图像(或 numpy 数组),像素值范围是[0, 255](整数),但 PyTorch 模型要求输入是Tensor 格式,且像素值归一化到[0, 1](浮点数)。ToTensor()做了两件事:

  1. 将 PIL 图像转为形状为[C, H, W]的 Tensor(注意:PIL 图像默认是[H, W, C],这里会自动转置通道顺序);
  2. 将像素值从[0, 255]除以 255,归一化到[0, 1]

举个例子:一张 PIL 格式的 CIFAR10 图像(32×32×3),经过ToTensor()后会变成[3, 32, 32]的 Tensor,每个元素值在 0~1 之间。

2.3 用 DataLoader 按批次加载数据

dataloader = DataLoader(dataset, batch_size=64)

DataLoader的核心作用是将dataset(10000 张测试集图像)按batch_size=64分割成多个批次,方便模型批量处理(批量处理能提高计算效率,且符合梯度下降的原理)。

  • 总批次数量:10000 ÷ 64 ≈ 157(最后一个批次不足 64 张,实际为 10000 - 156×64 = 16 张);
  • 每个批次的数据格式:(imgs, targets),其中imgs是图像张量,targets是类别标签张量。

三、核心:搭建线性分类器(Prayer 类)

这部分是神经网络的 "骨架",我们用线性层(全连接层) 搭建一个最简单的分类器,理解模型的输入、输出和前向传播过程。

3.1 类的定义与初始化(__init__方法)

class Prayer(nn.Module):def __init__(self):super(Prayer, self).__init__()  # 继承nn.Module的初始化方法# 定义线性层:输入维度3072,输出维度10self.linear1 = nn.Linear(3072, 10)

这里有三个必须掌握的关键点:

(1)继承nn.Module的意义

nn.Module是 PyTorch 中所有神经网络模块的基类,自定义模型必须继承它。它的核心作用包括:

  • 自动管理模型中的可训练参数(比如线性层的权重和偏置);
  • 支持前向传播(forward方法)和反向传播(自动求导);
  • 提供模型保存、加载、移动到 GPU 等便捷功能。

(2)super(Prayer, self).__init__()的作用

这行代码是 "子类调用父类初始化方法" 的标准写法,目的是让父类nn.Module完成自身的初始化(比如初始化参数列表、计算设备等)。如果不写这行,模型会缺少必要的属性,后续调用时会报错。

(3)线性层nn.Linear(3072, 10)的参数含义

nn.Linear(in_features, out_features)是线性层的定义,本质是实现一个线性变换:y = x × W + b,其中:

  • in_features(输入维度):3072 → 对应 CIFAR10 图像展平后的特征数(3×32×32);
  • out_features(输出维度):10 → 对应 CIFAR10 的 10 个类别(每个输出值代表模型对该类别的 "置信度");
  • 线性层的可训练参数:
    • 权重W:形状为[out_features, in_features] → 这里是[10, 3072]
    • 偏置b:形状为[out_features] → 这里是[10]

3.2 前向传播(forward 方法)

def forward(self, input):output = self.linear1(input)  # 将输入传入线性层,得到输出return output

forward方法是模型的 "计算流程",定义了数据如何从输入经过模型层得到输出。在 PyTorch 中,不需要手动调用forward方法,只需将模型实例当作函数调用(比如prayer(output)),PyTorch 会自动触发forward方法。

举个例子:如果输入是一个形状为[64, 3072]的张量(64 个样本,每个样本 3072 个特征),经过self.linear1后,输出会是[64, 10]的张量(64 个样本,每个样本 10 个类别置信度)。

四、模型推理:数据流过模型的完整流程

代码的最后一部分是 "模型推理",即让加载好的批次数据通过模型,观察数据形状的变化(这是理解模型是否正确的关键)。

4.1 创建模型实例

prayer = Prayer()  # 实例化Prayer类,得到模型对象prayer

这行代码会调用Prayer类的__init__方法,创建线性层并初始化权重和偏置(默认是随机初始化)。此时prayer就是一个可使用的线性分类器模型。

4.2 遍历 DataLoader,执行推理

for data in dataloader:imgs, targets = data  # 拆分每个批次的数据:图像张量和标签张量print("原始图像形状:", imgs.shape)  # 打印原始图像形状# 展平操作:从第1维开始展平,保留批次维度output = torch.flatten(imgs, start_dim=1)print("展平后形状:", output.shape)  # 打印展平后形状output = prayer(output)  # 将展平后的特征传入模型,得到输出print("模型输出形状:", output.shape)  # 打印模型输出形状

我们逐句解析,并结合可视化图表理解数据形状的变化:

(1)原始图像形状:imgs.shape

每个批次的imgs是一个 4 维张量,形状为[batch_size, C, H, W]

  • batch_size=64时,形状为[64, 3, 32, 32]
  • 含义:64 张图像,每张图像 3 个通道,每个通道 32×32 像素。

可视化如下(用简化的维度图表示):

原始图像张量:[64(批次), 3(通道), 32(高度), 32(宽度)]
├─ 第1张图:[3, 32, 32]
├─ 第2张图:[3, 32, 32]
├─ ...
└─ 第64张图:[3, 32, 32]

2)展平操作:torch.flatten(imgs, start_dim=1)

线性层nn.Linear要求输入是2 维张量[batch_size, in_features]),而原始imgs是 4 维张量,因此需要用torch.flatten将其展平(只保留批次维度,将通道、高度、宽度合并为 "特征维度")。

  • start_dim=1:表示从第 1 个维度(通道维度)开始展平,第 0 个维度(批次维度)保持不变;
  • 展平后形状:[64, 3×32×32] = [64, 3072]

可视化展平过程:

原始形状:[64, 3, 32, 32]↓ 展平维度1~3(3×32×32=3072)
展平后形状:[64, 3072]
├─ 第1个样本:[3072个特征值](R通道32×32 + G通道32×32 + B通道32×32)
├─ 第2个样本:[3072个特征值]
├─ ...
└─ 第64个样本:[3072个特征值]

(3)模型输出形状:prayer(output).shape

将展平后的[64, 3072]张量传入模型,经过线性层nn.Linear(3072, 10)变换后,输出形状为[64, 10]

  • 含义:64 个样本,每个样本对应 10 个数值(分别代表模型对 10 个类别的置信度);
  • 后续步骤(未在代码中体现):通过torch.argmax(output, dim=1)取每个样本置信度最大的索引,即为模型预测的类别。

可视化模型输入输出:

模型输入(展平后):[64, 3072]↓ 经过线性变换 y = x×W + b(W: [10,3072], b: [10])
模型输出:[64, 10]
├─ 第1个样本:[置信度0, 置信度1, ..., 置信度9] → 预测类别=置信度最大的索引
├─ 第2个样本:[置信度0, 置信度1, ..., 置信度9]
├─ ...
└─ 第64个样本:[置信度0, 置信度1, ..., 置信度9]

4.3 实际运行输出结果

当你运行代码时,会看到如下输出(前两个批次为例):

原始图像形状: torch.Size([64, 3, 32, 32])
展平后形状: torch.Size([64, 3072])
模型输出形状: torch.Size([64, 10])
原始图像形状: torch.Size([64, 3, 32, 32])
展平后形状: torch.Size([64, 3072])
模型输出形状: torch.Size([64, 10])
...
# 最后一个批次(不足64张)
原始图像形状: torch.Size([16, 3, 32, 32])
展平后形状: torch.Size([16, 3072])
模型输出形状: torch.Size([16, 10])

这个结果验证了模型和数据处理的正确性:每个批次的输入都能顺利通过模型,输出形状符合预期。

五、常见问题与拓展:让代码更完整

虽然当前代码能正常运行,但它只是 "推理流程",实际深度学习项目还需要训练、损失计算、评估等步骤。

5.1 为什么线性层输入维度不能是 196608?

在之前的错误中,曾将线性层输入维度设为 196608,导致RuntimeError: mat1 and mat2 shapes cannot be multiplied。原因是:

  • 196608 = 64×3×32×32 → 这是整个批次所有像素的总数(包含了批次维度);
  • 线性层需要的是单个样本的特征数(3072),而不是整个批次的总像素数;
  • 记住:线性层输入维度 = 单样本特征数,与批次大小无关。

5.2 如何添加训练逻辑?

当前代码只有推理,要让模型能学习,需要添加损失函数、优化器和训练循环:

# 1. 定义损失函数(分类任务常用交叉熵损失)
loss_fn = nn.CrossEntropyLoss()
# 2. 定义优化器(常用Adam优化器,学习率0.001)
optimizer = torch.optim.Adam(prayer.parameters(), lr=0.001)
# 3. 训练循环(以10轮训练为例)
epochs = 10
for epoch in range(epochs):running_loss = 0.0  # 记录每轮的总损失prayer.train()  # 切换模型为训练模式(启用 dropout、批量归一化等训练特有的操作)for data in dataloader:imgs, targets = data# 步骤1:前向传播(数据过模型)output = torch.flatten(imgs, start_dim=1)pred = prayer(output)# 步骤2:计算损失(预测值与真实标签的差距)loss = loss_fn(pred, targets)# 步骤3:反向传播(计算梯度)optimizer.zero_grad()  # 清空上一轮的梯度(避免梯度累积)loss.backward()  # 从损失值反向计算各参数的梯度# 步骤4:参数更新(用梯度优化器更新模型权重)optimizer.step()# 累加损失值(用于打印日志)running_loss += loss.item() * imgs.size(0)  # loss.item()是单批次损失,乘以批次大小得到总损失# 计算每轮的平均损失epoch_loss = running_loss / len(dataset)print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")

5.3 结果可视化:用 Matplotlib 展示预测效果​

为了更直观地理解模型的预测结果,我们可以用 Matplotlib 绘制 “图像 - 真实标签 - 预测标签” 的对应图,

六、完整代码

import matplotlib.pyplot as plt
import numpy as np# CIFAR10类别名称(与索引0-9对应)
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')def show_predictions(model, dataloader, num_images=5):model.eval()with torch.no_grad():# 取第一个批次的数据data_iter = iter(dataloader)imgs, targets = next(data_iter)# 前向传播得到预测结果output = torch.flatten(imgs, start_dim=1)pred = model(output)_, predicted = torch.max(pred, dim=1)# 转换图像格式(从[C, H, W]转为[H, W, C],方便Matplotlib显示)imgs = imgs.permute(0, 2, 3, 1).numpy()  # permute调整维度顺序imgs = imgs * 255  # 从[0,1]反归一化到[0,255](Matplotlib需要整数像素值)imgs = imgs.astype(np.uint8)  # 转为整数类型# 绘制图像plt.figure(figsize=(12, 4))for i in range(num_images):plt.subplot(1, num_images, i+1)plt.imshow(imgs[i])# 标题格式:真实标签 -> 预测标签(正确标绿,错误标红)true_label = classes[targets[i]]pred_label = classes[predicted[i]]color = 'green' if true_label == pred_label else 'red'plt.title(f"True: {true_label}\nPred: {pred_label}", color=color)plt.axis('off')  # 隐藏坐标轴plt.show()# 调用可视化函数
show_predictions(prayer, dataloader, num_images=5)

可视化效果说明:​

运行代码后,会显示 5 张 CIFAR10 测试集图像,每张图像下方标注 “真实类别” 和 “预测类别”:​

  • 若预测正确,标题为绿色;​
  • 若预测错误,标题为红色。

例如:​

  • 真实类别是 “cat”,预测类别也是 “cat” → 绿色标题;​
  • 真实类别是 “dog”,预测类别是 “cat” → 红色标题。​

通过可视化,你可以快速发现模型擅长预测哪些类别(如 “airplane”“ship” 这类轮廓清晰的类别),以及容易混淆的类别(如 “cat” 和 “dog” 这类细节相似的类别)。​

七、常见问题与解决方案(FAQ)​

在实际运行代码时,你可能会遇到以下问题,这里提前给出解决方案:​

常见问题​

错误原因​

解决方案​

RuntimeError: CUDA out of memory​

显卡内存不足(模型或批次太大)​

1. 减小batch_size(如从 64 改为 32、16);2. 使用torch.cuda.empty_cache()清空缓存;3. 改用 CPU 训练(速度慢但不占显存)​

训练损失不下降,准确率始终 10% 左右​

模型未学习(可能是梯度消失或学习率不合适)​

1. 调整学习率(如从 0.001 改为 0.01 或 0.0001);2. 检查数据预处理是否正确(如是否忘记归一化);3. 增加训练轮次(epochs)​

评估准确率远低于训练准确率​

模型过拟合(在训练集上表现好,测试集上表现差)​

1. 增加训练数据(如数据增强,见下文拓展);2. 减少模型复杂度(如线性层改为更简单的结构);3. 添加正则化(如 L2 正则化)​

八、 进阶拓展:数据增强提升模型性能​

当前代码使用的是原始 CIFAR10 图像,若想进一步提升模型准确率,可以添加数据增强(通过随机变换图像,增加训练数据的多样性,减少过拟合)。修改数据加载代码如下:

# 定义数据增强变换(训练集用增强,测试集不用)
train_transform = torchvision.transforms.Compose([torchvision.transforms.RandomCrop(32, padding=4),  # 随机裁剪( padding=4表示先填充4像素,再裁剪32×32)torchvision.transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转(50%概率)torchvision.transforms.ToTensor()  # 转为Tensor并归一化
])test_transform = torchvision.transforms.ToTensor()  # 测试集只做归一化,不做增强# 加载训练集(用增强变换)
train_dataset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)# 加载测试集(不用增强)
test_dataset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=test_transform
)# 创建DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # 训练集打乱顺序
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)   # 测试集不打乱

数据增强的核心作用:​

  • 让模型看到更多 “变种” 图像(如裁剪后的局部图像、翻转后的图像);​
  • 避免模型过度依赖图像的固定位置或方向(如只认识向左的猫,不认识向右的猫);​
  • 通常能将 CIFAR10 线性分类器的准确率提升 10%-15%。
http://www.dtcms.com/a/494480.html

相关文章:

  • 什么是机械设备制造ERP?哲霖软件如何助力企业实现降本增效?
  • 【小白笔记】关于 Python 类、初始化以及 PyTorch 数据处理的问题
  • HTTPS 内容抓取实战 能抓到什么、怎么抓、不可解密时如何定位(面向开发与 iOS 真机排查)
  • Gartner发布数据安全态势管理市场指南:将功能扩展到AI的特定数据安全保护是DSPM发展方向
  • 建站系统的应用场景一条龙搭建网站
  • 公司网站自己做的网站怎么被搜录
  • item_video:获得淘宝商品视频 API 接口实战演示说明
  • appium学习
  • [Linux]学习笔记系列 -- [kernel][irq]softirq
  • 家庭相册私有化:Immich+cpolar构建你的数字记忆堡垒
  • 存储同步管理器SyncManager 归纳
  • 做游戏网站多少钱建设电子商务网站要多少钱
  • iBizModel 实体通知(PSDENOTIFY)模型详解
  • mysql函数大全及举例
  • 人工智能综合项目开发3-----农业病虫害识别dataclean.py
  • R语言手搓一个计算生存分析C指数(C-index)的函数算法
  • 使用leaflet库加载服务器离线地图瓦片(这边以本地nginx服务器为例)
  • 无状态协议HTTP/HTTPS (笔记)
  • 模式识别与机器学习课程笔记(8):特征提取与选择
  • python+uniapp基于微信美食点餐系统小程序
  • 【邀请函】锐成信息 × Sectigo | CLM - SSL 证书自动化运维解决方案发布会
  • 基于MATLAB实现基于距离的离群点检测算法
  • 冠县网站建设电话wordpress插件 电商
  • 【Android】RecyclerView LayoutManager 重写方法详解
  • 数据流通合规新基建 隐私计算平台的三重安全防线
  • MySQL-2--数据库的查询
  • 微信公众号商城网站开发wordpress 留言板制作
  • 虚幻基础:角色旋转控制角色视角控制
  • 【轨物方案】智慧供暖全景运营物联网解决方案
  • 超越“接收端”:解析视频推拉流EasyDSS在RTMP推流生态中的核心价值与中流砥柱作用