当前位置: 首页 > news >正文

使用 PyTorch 实现 CIFAR-10 图像分类:从数据加载到模型训练全流程

在深度学习入门实践中,CIFAR-10 数据集分类是一个经典案例。本文将详细介绍如何使用 PyTorch 构建一个卷积神经网络 (CNN) 来完成 CIFAR-10 图像分类任务,涵盖数据加载、模型构建、训练过程和结果评估的完整流程。

项目概述

CIFAR-10 是一个包含 10 个类别的彩色图像数据集,每个类别有 6000 张 32×32 像素的图像,共 60000 张图像,分为 50000 张训练集和 10000 张测试集。10 个类别分别是:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。

本项目将实现一个简单的卷积神经网络,使用 PyTorch 框架完成从数据加载到模型评估的全流程,并达到不错的分类效果。

一、数据加载与预处理

首先需要加载 CIFAR-10 数据集并进行必要的预处理:

import torch
import torchvision
import torchvision.transforms as transforms# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到[-1, 1]范围
])# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='C:\\Users\\Administrator\\Desktop\\Untitled Folder\\cifar-10-batches-py',train=True,download=False,  # 已手动放置数据集,无需下载transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,  # 批次大小为4shuffle=True,  # 打乱数据顺序num_workers=0  # Windows环境建议设为0,避免多进程加载报错
)# 加载测试集
testset = torchvision.datasets.CIFAR10(root='C:\\Users\\Administrator\\Desktop\\Untitled Folder\\cifar-10-batches-py',train=False,download=False,transform=transform
)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,  # 测试集不需要打乱num_workers=0
)# CIFAR10类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

核心说明

  • transforms.Compose用于组合多个预处理操作
  • ToTensor()将 PIL 图像转换为 PyTorch 张量,并将像素值从 [0,255] 缩放到 [0,1]
  • Normalize()进行归一化,使数据均值为 0,标准差为 1,有助于模型收敛
  • DataLoader用于批量加载数据,并支持多进程加速(Windows 下建议关闭)
  • shuffle=True确保训练时每个 epoch 的数据顺序都不同,有助于模型泛化

二、数据可视化

加载数据后,我们可以编写一个函数来可视化数据,了解我们要处理的图像:

def show_images(tensor_images, labels, class_names):"""显示批量图像并打印对应标签"""# 反归一化(还原图像亮度)tensor_images = tensor_images / 2 + 0.5  # 转换为PIL图像网格img_grid = torchvision.utils.make_grid(tensor_images)img = torchvision.transforms.ToPILImage()(img_grid)# 显示图像(调用系统默认图像查看器)img.show()# 打印对应标签print("图像标签:", ' '.join(f"{class_names[labels[j]]:5s}" for j in range(len(labels))))# 测试:加载并显示一批训练数据
dataiter = iter(trainloader)
images, labels = next(dataiter)  # 获取一批数据(4张图像)# 显示图像和标签
show_images(images, labels, classes)

核心说明

  • 由于之前对图像进行了归一化,需要通过tensor_images / 2 + 0.5反归一化才能正确显示
  • torchvision.utils.make_grid()可以将多张图像组合成一个网格图像
  • ToPILImage()将张量转换回 PIL 图像格式以便显示

三、构建卷积神经网络模型

接下来我们构建一个简单的卷积神经网络:

import torch.nn as nn
import torch.nn.functional as Fdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class CNNNet(nn.Module):def __init__(self):super(CNNNet, self).__init__()# 第一个卷积层:3输入通道,16输出通道,5x5卷积核self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)# 第一个池化层:2x2池化核,步长为2self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# 第二个卷积层:16输入通道,36输出通道,3x3卷积核self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)# 第二个池化层:2x2池化核,步长为2self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 第一个全连接层self.fc1 = nn.Linear(1296, 128)# 第二个全连接层(输出层,10个类别)self.fc2 = nn.Linear(128, 10)def forward(self, x):# 卷积->激活->池化x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))# 展平特征图x = x.view(-1, 36 * 6 * 6)# 全连接层->激活x = F.relu(self.fc1(x))# 输出层x = self.fc2(x)return x# 初始化模型并移动到可用设备(GPU或CPU)
net = CNNNet()
net = net.to(device)# 打印模型总参数数量
print("net have {} parameters in total".format(sum(x.numel() for x in net.parameters())))

核心说明

  • 模型采用经典的卷积 - 池化 - 全连接结构
  • Conv2d层负责提取图像特征,通过卷积核捕获局部特征
  • MaxPool2d层进行下采样,减少特征图尺寸,同时保留重要特征
  • forward方法定义了数据在网络中的流动路径
  • x.view(-1, 36 * 6 * 6)将二维特征图展平为一维向量,以便输入全连接层
  • 代码会自动检测并使用 GPU(如果可用),否则使用 CPU

我们还可以通过以下代码获取网络的特征提取部分(前 4 层):

# 获取网络的特征提取部分(卷积层和池化层)
feature_extractor = nn.Sequential(*list(net.children())[:4])

四、定义损失函数和优化器

训练神经网络需要定义损失函数和优化器:

import torch.optim as optim# 学习率
LR = 0.001# 损失函数:交叉熵损失,适用于分类任务
criterion = nn.CrossEntropyLoss()
# 优化器:SGD(随机梯度下降)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 也可以使用Adam优化器,通常收敛更快
# optimizer = optim.Adam(net.parameters(), lr=LR)# 打印网络结构
print(net)

核心说明

  • 交叉熵损失 (CrossEntropyLoss) 是分类任务的常用损失函数
  • SGD 优化器带有动量 (momentum=0.9) 可以加速收敛并减少震荡
  • Adam 优化器通常收敛更快,但在某些任务上 SGD 可能泛化更好
  • 学习率 (lr) 是重要的超参数,过大会导致不收敛,过小会导致收敛太慢

五、训练模型

模型和数据准备就绪后,就可以开始训练了:

for epoch in range(10):  # 迭代10个epochrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# 获取训练数据并移动到设备inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 权重参数梯度清零optimizer.zero_grad()# 正向传播、计算损失、反向传播、参数更新outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 统计并显示损失值running_loss += loss.item()if i % 2000 == 1999:    # 每2000个mini-batch打印一次print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')

核心说明

  • 训练过程由多个 epoch 组成,每个 epoch 遍历整个训练集一次
  • 每个 epoch 又分为多个 mini-batch,按批次处理数据
  • optimizer.zero_grad()清除上一次迭代的梯度
  • loss.backward()计算梯度(反向传播)
  • optimizer.step()根据梯度更新参数
  • 定期打印损失值可以监控训练进度,损失总体应该呈下降趋势

六、模型测试与评估

训练完成后,我们需要测试模型的性能:

# 使用训练好的模型进行预测
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

我们还可以批量查看测试集的预测结果:

# 加载测试集并显示预测结果
dataiter = iter(testloader)
for i in range(100):  # 控制显示的批次数try:images, labels = next(dataiter)images, labels = images.to(device), labels.to(device)print(f"第{i+1}批图像:")show_pil_image(images.cpu())  # 转回CPU才能显示# 真实标签print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))# 预测结果outputs = net(images)_, predicted = torch.max(outputs, 1)print('Predicted:   ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))print('-'*50)  # 分隔线except StopIteration:break  # 数据迭代完则停止

核心说明

  • torch.max(outputs, 1)返回每个样本的最大预测值和对应的类别索引
  • 预测时需要将数据移动到与模型相同的设备(GPU 或 CPU)
  • 显示图像时需要将张量转回 CPU
  • 通过对比GroundTruth(真实标签)和Predicted(预测结果)可以直观了解模型性能

七、总结与改进方向

本项目实现了一个简单的 CNN 模型用于 CIFAR-10 分类,通过 10 个 epoch 的训练,通常可以达到 60%-70% 的准确率。这个结果对于基础模型来说已经不错,但还有很大的提升空间:

  1. 增加网络深度和宽度:可以尝试使用更深的网络结构,如 VGG、ResNet 等
  2. 数据增强:增加更多的数据增强手段,如随机裁剪、旋转、翻转等,提高模型泛化能力
  3. 调整超参数:尝试不同的学习率、批次大小、优化器等
  4. 正则化:添加 Dropout 层或 L2 正则化,减少过拟合
  5. 学习率调度:使用学习率衰减策略,使训练更稳定

通过这个项目,我们掌握了使用 PyTorch 进行图像分类的完整流程,包括数据加载与预处理、模型构建、训练过程和结果评估。这些技能可以迁移到其他图像分类任务中,是深度学习入门的重要实践。

http://www.dtcms.com/a/411296.html

相关文章:

  • 网站开发公司能否挣钱怎么在网站空间上传文件
  • 亭湖区建设局网站楼盘网站开发报价
  • java后端工程师进修ing(研一版‖day49)
  • opendds初入门之对inforepo模式运行探索
  • 简单公司网站最全的域名后缀
  • 比邻智联发布生活物联网家电应用白皮书和Cat.1模组新品
  • 第七章 Spring-Boot框架
  • 网站html静态化解决方案网站制作公司 北京
  • 金仓数据库实现电子证照系统从MongoDB平滑迁移,国产化替代迎来新典范
  • CAN总线学习(四)错误处理 STM32CAN外设一
  • 【OpenGL】LearnOpenGL学习笔记28 - 延迟渲染 Deferred Rendering
  • 莱芜梆子网站昆山网站建设需要多少钱
  • 站长交流装潢设计什么意思
  • web核心—HTTP
  • 线程池导入大数据量excel
  • Spring Boot 3.x + Security + OpenFeign:如何避免内部服务调用被重复拦截?
  • 全国免费发布信息网站大全wordpress 修改文章id
  • 公司网站设计费计入什么科目app科技网站建设
  • 从需求到实现:如何解决证件照标准化难题的?
  • C++第九篇:friend友元
  • 软件工程咋理解?用 “开奶茶店” 讲透瀑布模型 / 敏捷开发
  • 如何在WordPress中添加短代码
  • 资源型网站建设 需要多大硬盘招牌设计 创意logo
  • 数据库索引简介
  • 基于三角测量拓扑聚合优化的LSTM深度学习网络模型(TTAO-LSTM)的一维时间序列预测算法matlab仿真
  • 关键词网站查询产品展示网站源码php
  • TOGAF ® 标准与循环经济:为可持续与责任型 IT 而设计
  • C盘内存不足,清除或转移VS2022缓存文件Cache
  • 玉米病叶识别数据集,可识别褐斑,玉米锈病,玉米黑粉病,霜霉病,灰叶斑点,叶枯病等,使用yolo,coco,voc对4924张照片进行标注
  • 修改Linux上的ssh的默认端口号——及其客户端使用ssh连接不上Linux问题排查解决