当前位置: 首页 > news >正文

用 PyTorch 搞定 CIFAR10 数据集

一、训练 CNN 模型

我们将分 6 个步骤完成项目

步骤 1:数据加载与预处理

深度学习的第一步是 “喂给模型高质量的数据”。我们需要加载 CIFAR10 数据集,并做必要的预处理(转张量、归一化),让数据符合模型输入要求。

python

运行

# 导入必要库
import torch
import torchvision
import torchvision.transforms as transforms
import os# 1. 确保数据目录存在(避免路径错误)
data_dir = './data'  # 数据集保存路径(与代码同文件夹)
if not os.path.exists(data_dir):os.makedirs(data_dir)print(f"已创建数据目录:{data_dir}")
else:print(f"数据目录已存在:{data_dir}")# 2. 数据预处理:将图像转为张量+归一化
# 为什么要预处理?
# - ToTensor():把PIL图像(0-255像素)转为PyTorch张量(0-1浮点数),模型只认张量;
# - Normalize():将张量归一化到[-1,1],加速模型收敛(避免像素值过大导致梯度爆炸)。
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # RGB三通道的均值和标准差都是0.5
])# 3. 加载训练集(自动下载缺失数据)
trainset = torchvision.datasets.CIFAR10(root=data_dir,    # 数据保存路径train=True,       # True=加载训练集,False=加载测试集download=True,    # 若路径下没有数据,自动从官网下载(约170MB)transform=transform  # 应用上面定义的预处理
)# 4. 创建训练数据加载器(按批次喂给模型)
# batch_size=4:每次给模型喂4张图,小批次训练更稳定;
# shuffle=True:训练时打乱数据顺序,避免模型学“数据顺序”而非“特征”;
# num_workers=0:Windows系统设为0(避免多进程错误),Linux/Mac可设2/4加速加载。
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=0
)# 5. 加载测试集(流程和训练集一致,只是train=False、shuffle=False)
testset = torchvision.datasets.CIFAR10(root=data_dir,train=False,download=True,transform=transform
)testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,  # 测试时不需要打乱顺序num_workers=0
)# 6. 定义类别标签(与数据集的索引对应)
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 7. 验证数据是否加载成功
print(f"\n训练集样本数:{len(trainset)}(50000张,符合CIFAR10标准)")
print(f"测试集样本数:{len(testset)}(10000张,符合CIFAR10标准)")# 尝试加载一个批次的数据,确认格式正确
try:dataiter = iter(trainloader)  # 创建数据迭代器images, labels = next(dataiter)  # 获取第一个批次(4张图+4个标签)print(f"单批次图片形状:{images.shape}(格式:[批次大小, 通道数, 高度, 宽度])")print("数据加载成功!可以开始搭建模型了~")
except Exception as e:print(f"数据加载出错:{e}")print("解决方案:1. 检查网络(确保能下载数据);2. 手动下载数据放入./data目录(官网:https://www.cs.toronto.edu/~kriz/cifar.html)")

运行这段代码后:会自动在./data目录下下载 CIFAR10 数据,最后打印 “数据加载成功”,说明第一步完成!

步骤 2:数据可视化( Optional,但很重要)

加载数据后,我们可以用matplotlib显示几张图片,直观看看数据长什么样,也能确认预处理是否正确(比如有没有归一化反推)。

python

运行

import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline  # Jupyter中用,普通Python文件可删除(改用plt.show())# 定义图像显示函数:将预处理后的张量转回可显示的图像
def imshow(img):img = img / 2 + 0.5  # 反归一化:把[-1,1]转回[0,1](因为之前Normalize用了mean=0.5, std=0.5)npimg = img.numpy()  # 张量转numpy数组(matplotlib只认numpy)# 调整维度顺序:PyTorch是[通道数, 高度, 宽度],matplotlib需要[高度, 宽度, 通道数]plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.axis('off')  # 隐藏坐标轴,更美观plt.show()# 1. 获取训练集中的一个批次
dataiter = iter(trainloader)
images, labels = next(dataiter)  # images:4张图,labels:4个标签# 2. 显示这4张图(用torchvision.make_grid拼接成一张图)
print("训练集样本示例:")
imshow(torchvision.utils.make_grid(images))# 3. 打印这4张图的真实标签
print("真实类别:", ' '.join(f"{classes[labels[j]]:5s}" for j in range(4)))  # :5s表示占5个字符位,对齐显示

运行效果:会弹出一张包含 4 张图片的网格图,下方打印它们的真实类别(比如 “plane car bird cat”),能直观看到我们要分类的物体长什么样。

步骤 3:搭建 CNN 模型

这是项目的核心!我们将搭建一个简单但有效的 CNN,包含 2 个卷积块(卷积 + ReLU + 池化)和 2 个全连接层,结构如下:输入(3×32×32)→ 卷积1(16个5×5核)→ ReLU → 池化1 → 卷积2(36个3×3核)→ ReLU → 池化2 → 展平 → 全连接1(128维)→ ReLU → 全连接2(10维输出)

python

运行

import torch.nn as nn
import torch.nn.functional as F# 1. 选择计算设备:优先用GPU(训练更快),没有则用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"\n使用计算设备:{device}")  # 打印设备,确认是否用上GPU(显示cuda:0就是GPU)# 2. 定义CNN模型类(必须继承nn.Module,PyTorch规定)
class CIFAR10_CNN(nn.Module):def __init__(self):super(CIFAR10_CNN, self).__init__()  # 调用父类初始化# 第一个卷积块:卷积层+池化层# 卷积层(conv1):输入3通道(RGB),输出16通道(16个不同的卷积核,提取16种特征)# kernel_size=5:卷积核大小5×5,stride=1:步长1(每次移动1个像素)self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)# 池化层(pool1):最大池化,核大小2×2,步长2(压缩特征图尺寸到1/2)# 为什么用最大池化?保留局部最显著特征(比如边缘、纹理),同时减少计算量self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# 第二个卷积块:卷积层+池化层# 卷积层(conv2):输入16通道(上一层输出),输出36通道(更复杂的特征)self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层1(fc1):将池化后的特征图展平为1维向量,再映射到128维# 输入维度怎么算?池化2后特征图尺寸是6×6(32→16→8→6,具体看前向传播注释),36个通道 → 36×6×6=1296self.fc1 = nn.Linear(36 * 6 * 6, 128)# 全连接层2(fc2):输出10维(对应10个类别),是模型的最终预测self.fc2 = nn.Linear(128, 10)# 前向传播:定义数据在模型中的流动路径(必须实现)def forward(self, x):# 输入x形状:[batch_size, 3, 32, 32]x = self.pool1(F.relu(self.conv1(x)))  # 卷积1 → ReLU激活 → 池化1# 经过conv1:[batch_size,16,28,28](32-5+1=28);经过pool1:[batch_size,16,14,14](28/2=14)x = self.pool2(F.relu(self.conv2(x)))  # 卷积2 → ReLU激活 → 池化2# 经过conv2:[batch_size,36,12,12](14-3+1=12);经过pool2:[batch_size,36,6,6](12/2=6)x = x.view(-1, 36 * 6 * 6)  # 展平:[batch_size, 36*6*6=1296](-1表示自动匹配batch_size)x = F.relu(self.fc1(x))  # 全连接1 → ReLU激活:[batch_size, 128]x = self.fc2(x)  # 全连接2:[batch_size, 10](10个类别得分)return x# 3. 创建模型实例,并移动到指定设备(GPU/CPU)
net = CIFAR10_CNN()
net = net.to(device)# 4. 查看模型结构和参数数量
print("\nCNN模型结构:")
print(net)
total_params = sum(x.numel() for x in net.parameters())  # 计算总参数数
print(f"\n模型总参数数量:{total_params}(约17万,属于轻量级模型,适合入门)")

关键理解

  • 卷积层负责 “提取局部特征”(比如猫的耳朵、汽车的车轮);
  • 池化层负责 “压缩特征、保留关键信息”(减少计算量,避免过拟合);
  • 全连接层负责 “整合全局特征、输出类别预测”(把分散的局部特征串联成全局判断)。

步骤 4:定义损失函数与优化器

模型搭好了,还需要两个 “助手”:

  • 损失函数:衡量模型预测的错误程度(比如预测是 “猫”,实际是 “狗”,损失值就大);
  • 优化器:根据损失函数的 “反馈” 调整模型参数,让损失越来越小(即模型越来越准)。

python

运行

import torch.optim as optim# 1. 学习率:控制参数更新的步长(太小收敛慢,太大可能不收敛,0.001是常用值)
LR = 0.001# 2. 损失函数:交叉熵损失(CrossEntropyLoss)
# 为什么用交叉熵?适合多分类任务,能有效衡量“预测概率分布”与“真实标签分布”的差异
# 注意:PyTorch的CrossEntropyLoss内置了Softmax激活,所以模型输出层不用加Softmax
criterion = nn.CrossEntropyLoss()# 3. 优化器:随机梯度下降(SGD)+ 动量(momentum)
# 为什么用SGD+动量?SGD是基础优化器,动量能加速收敛(避免在局部最优值徘徊)
# 可选优化器:Adam(收敛更快,适合新手,把下面一行注释解开即可用)
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# optimizer = optim.Adam(net.parameters(), lr=LR)print(f"损失函数:{criterion.__class__.__name__}")
print(f"优化器:{optimizer.__class__.__name__}(学习率:{LR},动量:0.9)")

步骤 5:训练模型(最耗时但最关键的一步)

训练的核心逻辑是 “循环迭代→正向传播算损失→反向传播算梯度→优化器更参数”,不断让模型 “学习” 数据规律。

python

运行

# 训练轮次(epoch):完整遍历训练集的次数(10轮足够初步收敛,太多会过拟合)
num_epochs = 10print(f"\n开始训练!共{num_epochs}轮,使用{device}加速...")# 记录训练开始时间
import time
start_time = time.time()# 外层循环:遍历每一轮
for epoch in range(num_epochs):running_loss = 0.0  # 累计每2000个批次的损失(用于打印进度)# 内层循环:遍历训练集中的每个批次# enumerate(trainloader, 0):返回(批次索引i,批次数据data)for i, data in enumerate(trainloader, 0):# 1. 获取批次数据(输入图像+真实标签),并移动到设备inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 2. 梯度清零!(关键步骤,避免上一轮的梯度累积)optimizer.zero_grad()# 3. 正向传播:输入模型,得到预测输出outputs = net(inputs)# 4. 计算损失:预测输出与真实标签的差异loss = criterion(outputs, labels)# 5. 反向传播:计算损失关于每个参数的梯度(指导参数更新)loss.backward()# 6. 优化器更新参数:根据梯度调整权重和偏置optimizer.step()# 7. 累计损失,每2000个批次打印一次running_loss += loss.item()  # loss.item():获取张量的数值(避免计算图占用内存)if i % 2000 == 1999:  # 每2000个批次触发一次(i从0开始,1999是第2000个批次)# 计算耗时(分钟)elapsed_time = (time.time() - start_time) / 60# 打印:轮次、批次索引、平均损失、耗时print(f"第{epoch+1:2d}轮, 第{i+1:5d}批次 | 平均损失:{running_loss/2000:.3f} | 耗时:{elapsed_time:.1f}分钟")running_loss = 0.0  # 重置累计损失# 计算总训练耗时
total_time = (time.time() - start_time) / 60
print(f"\n训练完成!总耗时:{total_time:.1f}分钟")# 保存训练好的模型(可选,下次用不用重新训练)
torch.save(net.state_dict(), './cifar10_cnn_model.pth')
print("模型参数已保存到:./cifar10_cnn_model.pth")

训练过程解读

  • 每轮训练会遍历 50000 张训练图(12500 个批次,因为 batch_size=4);
  • 每 2000 个批次打印一次 “平均损失”,正常情况下损失会逐渐下降(比如从 2.3 降到 0.8 左右);
  • 训练完成后,模型参数会保存到cifar10_cnn_model.pth,下次用的时候直接加载即可(不用重新训练)。

注意:如果训练时损失不下降,可能是学习率太大 / 太小,或模型结构不合适,可尝试调整学习率(如 0.0005、0.002)或换 Adam 优化器。

步骤 6:模型预测与评估

训练好的模型到底准不准?我们用测试集来验证,看看它能不能正确分类从未见过的图片。

python

运行

# 1. 加载测试集中的一个批次(用于可视化预测结果)
dataiter = iter(testloader)
images, labels = next(dataiter)  # 测试集的图片和真实标签# 2. 显示测试图片
print("\n测试集样本:")
imshow(torchvision.utils.make_grid(images))
print("真实类别:", ' '.join(f"{classes[labels[j]]:5s}" for j in range(4)))# 3. 模型预测:将测试图片移动到设备,输入模型
images, labels = images.to(device), labels.to(device)
outputs = net(images)  # outputs是10个类别的得分# 4. 解析预测结果:取得分最高的类别作为预测类别
# torch.max(outputs, 1):返回第1维度(类别维度)的最大值和对应的索引
_, predicted = torch.max(outputs, 1)  # predicted是预测类别的索引# 5. 打印预测结果
print("预测类别:", ' '.join(f"{classes[predicted[j]]:5s}" for j in range(4)))# 6. 评估模型在整个测试集上的准确率(更全面的性能衡量)
print("\n正在评估模型在整个测试集上的准确率...")
correct = 0  # 正确分类的样本数
total = 0    # 总样本数# 测试时不需要计算梯度(节省内存,加速计算)
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)  # 预测类别total += labels.size(0)  # 累加总样本数(每个批次4个,labels.size(0)=4)correct += (predicted == labels).sum().item()  # 累加正确数# 计算准确率
accuracy = 100 * correct / total
print(f"模型在测试集上的准确率:{accuracy:.2f}%")# 7. 查看每个类别的准确率(更细致的分析)
class_correct = list(0. for i in range(10))  # 每个类别的正确数
class_total = list(0. for i in range(10))    # 每个类别的总数with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()  # 压缩维度(batch_size=4 → 4个布尔值)# 遍历每个样本,统计每个类别的正确数和总数for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1# 打印每个类别的准确率
print("\n每个类别的准确率:")
for i in range(10):print(f"{classes[i]:10s}:{100 * class_correct[i] / class_total[i]:.2f}%")

预期结果

  • 可视化部分:会显示 4 张测试图,下方打印 “真实类别” 和 “预测类别”,如果预测正确,类别会一致;
  • 整体准确率:10 轮训练后,测试集准确率通常在 65%-75% 之间(如果用 Adam 优化器或增加轮次,能到 80% 以上);
  • 类别准确率:“飞机、汽车、船” 等轮廓清晰的类别准确率较高(70%+),“猫、狗、鸟” 等相似物体准确率较低(50%-60%),这符合人类视觉判断的直觉。

二、项目总结与拓展

1. 你学到了什么?

通过这个项目,你已经掌握了深度学习图像分类的完整流程:

  • 数据加载与预处理(ToTensor、Normalize);
  • CNN 模型搭建(卷积、池化、全连接层的配合);
  • 损失函数与优化器的选择(交叉熵、SGD/Adam);
  • 模型训练(正向传播、反向传播、参数更新);
  • 模型评估(准确率、类别细分分析)。

2. 如何提升模型性能?

如果想让准确率更高,可以尝试这些优化方向:

  • 增加训练轮次:比如训练 20-30 轮(注意防止过拟合,可加 Dropout 层);
  • 换用更优的优化器:比如 Adam(收敛更快,准确率更高);
  • 数据增强:增加训练数据多样性(如随机翻转、裁剪、调整亮度,代码见文末);
  • 加深 / 加宽网络:比如用 ResNet18、VGG16 等预训练模型(迁移学习,适合小数据集)。

3. 数据增强代码示例(拓展)

如果想尝试数据增强,只需修改步骤 1 中的transform

python

运行

transform = transforms.Compose([transforms.RandomCrop(32, padding=4),  # 随机裁剪(padding=4,避免边缘信息丢失)transforms.RandomHorizontalFlip(),     # 随机水平翻转(如汽车翻转后还是汽车,增加多样性)transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

http://www.dtcms.com/a/411669.html

相关文章:

  • VLMs距离空间智能还有多远的路要走?
  • 做网站北京德国网站建设
  • 网站建立安全连接失败软装设计公司加盟
  • 搭建个人博客:云服务器IP如何使用
  • iis网站asp.net部署网站建设运营费计入什么科目
  • 建设外贸营销型网站需要什么青岛网站设计定制
  • 券商 做网站圣都装饰的口碑怎么样
  • 【算法训练营Day26】动态规划part2
  • 河北衡水建设网站公司电话wordpress ajax登录插件
  • 网站源码怎么搭建最新新闻热点事件2023年10月
  • 城乡建设部网站广州市国外学校网站设计
  • 泊头网站建设公司wordpress删除主题之后
  • 一站式营销平台wordpress学校网站模板
  • LeetCode 算法题【简单】338. 比特位计数
  • 买房网站排名福州做网站建设公司
  • 爱思强交付第100套G10-SiC系统
  • 网站的建设要多少钱求推荐专门做借条的网站
  • 在线旅游攻略网站建设方案做网站要注册第35类商标吗
  • RocketMQ 核心知识整理:工作原理、常用命令与常见问题解决
  • 做养生网站怎么赚钱麻涌建设网站
  • 域名备案 没有网站网站建设意见建议表
  • Unity-Statemachinebehaviour状态机行为脚本
  • 网站问题图片房子网站有哪些
  • 孝感应城网站建设长春网站建设 找源晟
  • 如何设置网站服务器常州做网站哪家便宜
  • 单片机引脚的高电平和低电平范围值
  • 设计师可以做兼职的网站创建网站的基本步骤
  • 网站后台开发做什么凡科网网站建设
  • 什么是合同管理系统?6个核心功能介绍
  • 数据采集技术:03 有关实时采集