当前位置: 首页 > news >正文

PyTorch实现CIFAR-10图像分类:从数据加载到模型训练全流程

本文将详细介绍如何使用PyTorch框架构建一个卷积神经网络(CNN)来对CIFAR-10数据集进行图像分类。内容涵盖数据加载、网络构建、模型训练及测试评估等关键步骤。

1. 数据加载与预处理

首先,我们需要加载CIFAR-10数据集,并进行适当的预处理。以下是数据加载的代码:

import torch
import torchvision
import torchvision.transforms as transforms# 定义数据预处理流程
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)# CIFAR-10类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck')

数据可视化

我们可以随机查看部分训练图像,确保数据加载正确:

import matplotlib.pyplot as plt
import numpy as np
%matplotlib inlinedef imshow(img):img = img / 2 + 0.5  # 反标准化npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# 随机显示4张图像
dataiter = iter(trainloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print('Labels:', ' '.join('%s' % classes[labels[j]] for j in range(4)))

2. 构建卷积神经网络

我们构建一个包含两个卷积层、两个池化层和两个全连接层的CNN模型:

import torch.nn as nn
import torch.nn.functional as Fclass CNNNet(nn.Module):def __init__(self):super(CNNNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5, stride=1)self.pool1 = nn.MaxPool2d(2, stride=2)self.conv2 = nn.Conv2d(16, 36, 3, stride=1)self.pool2 = nn.MaxPool2d(2, stride=2)self.fc1 = nn.Linear(1296, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(-1, 36 * 6 * 6)  # 展平操作x = F.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型并移至GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = CNNNet().to(device)print("Total parameters:", sum(p.numel() for p in net.parameters()))

该模型共有约17.3万个参数

3. 定义损失函数与优化器

我们使用交叉熵损失函数和随机梯度下降(SGD)优化器:

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 训练模型

接下来进行模型训练,共训练10个epoch:

for epoch in range(10):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()  # 梯度清零outputs = net(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新权重running_loss += loss.item()if i % 2000 == 1999:  # 每2000个batch打印一次损失print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')

训练过程中损失值逐渐下降,表明模型正在有效学习

5. 模型测试与预测

最后,我们使用测试集评估模型性能,并随机查看一些预测结果:

# 随机选择4张测试图像
dataiter = iter(testloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
print('GroundTruth:', ' '.join('%s' % classes[labels[j]] for j in range(4)))# 模型预测
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted:', ' '.join('%s' % classes[predicted[j]] for j in range(4)))

总结

本文完整展示了使用PyTorch实现CIFAR-10图像分类的流程,包括:

数据加载与预处理:使用torchvision加载CIFAR-10数据集并进行标准化。

网络构建:设计了一个包含卷积层、池化层和全连接层的CNN模型。

模型训练:使用SGD优化器和交叉熵损失函数进行训练。

模型测试:对测试集进行预测并可视化结果。

http://www.dtcms.com/a/410379.html

相关文章:

  • 鸿蒙应用内存优化全攻略:从泄漏排查到对象池实战
  • ReactUse 与ahook对比
  • 网站建设与维护属于什么岗位wordpress免费企业站主题
  • 长安网站设计仿照别的网站做
  • 如何快速定位bug,编写测试用例?
  • 【LeetCode 142】环形链表 II:寻找环的入口
  • 卷轴 缓冲绘制 超级玛丽demo5
  • 1.9 IP地址和Mac地址
  • C# WinForms的入门级画板实现
  • 云南网站建设方案简述营销型网站开发流程
  • 随时随地学算法:Hello-Algo与cpolar的远程学习方案
  • App 上架全流程指南,iOS 应用发布步骤、ipa 文件上传工具、TestFlight 分发与 App Store 审核经验分享
  • 网站建设公司推荐常德网站开发服务
  • 全球知名的Java Web开发平台Vaadin上线慧都网
  • 【QT】高级主题
  • 详细对比web请求post和put的区别
  • dedecms 营销网站模板免费下载专业设计网址青岛网站开发
  • 正在招 | 2025.9 福建 IT 相关岗位招聘信息
  • 树莓派4B+ubuntu20.04:不插显示器能不能正常开机?
  • 开发大型网站的最主流语言上海seo网站优化_搜索引擎排名_优化型企业网站建设_锦鱼网络
  • 从远程控制到AI赋能:ToDesk如何重塑未来办公新生态?
  • Python爬虫进阶:突破反爬机制(UA伪装+代理池+验证码识别)
  • 华为发布开源超节点架构,以开放战略叩响AI算力生态变局
  • 从格伦的角度理解信息哲学
  • 网站建设分金手指专业三十WordPress 多用户数据
  • obsidian git操作及踩坑记录:ssh秘钥设置以及推送到多个远程仓库
  • 【Linux】网络部分——网络基础(Socket 编程预备)
  • 【音频】SIP服务器Yate搭建
  • 贵阳网站建设宏思锐达网站挂服务器后图片不显示
  • @tanstack/react-query:React 服务器状态管理与数据同步解决方案