当前位置: 首页 > news >正文

基于 PyTorch 的 CIFAR-10 图像分类学习总结

在本次学习中,我通过 PyTorch 实现了一个基于 CNN 的 CIFAR-10 图像分类模型,完整掌握了从数据加载、模型构建到训练评估的全流程。以下是具体学习内容总结,包含关键代码实现:

CIFAR-10图像分类完整代码

V1

创建时间:15:53

关键步骤解析

1. 数据处理

数据处理是深度学习任务的基础,主要包括:

  • 数据转换:使用transforms将图像转为张量并归一化,使模型更容易学习
  • 数据集加载:利用torchvision.datasets加载 CIFAR-10 数据集
  • 数据加载器:通过DataLoader实现批处理、打乱数据和多进程加载
  • 数据可视化:编写imshow函数直观查看数据,验证数据加载是否正确

2. 模型构建

CNN 模型是处理图像任务的有效工具,本模型结构包括:

  • 卷积层:使用nn.Conv2d提取图像特征,通过卷积核捕获局部特征
  • 池化层:使用nn.MaxPool2d降低特征图维度,减少计算量并增强鲁棒性
  • 全连接层:使用nn.Linear实现最终分类,将提取的特征映射到 10 个类别
  • 激活函数:使用 ReLU 增加模型非线性表达能力,解决梯度消失问题

3. 模型训练

训练过程是模型学习的核心,主要步骤包括:

  • 损失函数:选择交叉熵损失函数,适合多分类任务
  • 优化器:使用 SGD 优化器,通过学习率和动量控制参数更新
  • 训练循环:多轮迭代训练,每个批次包括前向传播、损失计算、反向传播和参数更新
  • 设备加速:自动检测 GPU 并利用 CUDA 加速训练过程

4. 模型评估

通过测试集验证模型性能:

  • 加载测试数据并可视化
  • 使用训练好的模型进行预测
  • 对比预测结果与真实标签,直观评估模型分类效果

学习心得

  1. 数据预处理对模型性能影响很大,合适的归一化能加速模型收敛
  2. 网络结构设计需要平衡复杂度和计算效率,过深或过浅的网络都可能影响性能
  3. 训练过程中的超参数(如学习率、批次大小、训练轮次)需要根据实际情况调整
  4. GPU 加速能显著提高训练速度,特别是对于图像等数据量大的任务
  5. 可视化是调试和理解模型的有效手段,有助于发现数据或模型中的问题

通过这个实例,我掌握了 PyTorch 的基本使用方法和 CNN 图像分类的完整流程,为后续更复杂的深度学习任务打下了基础。

# 基于PyTorch的CIFAR-10图像分类实现

# 一、准备工作:导入必要的库
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# 二、数据加载与预处理
# 1. 定义数据转换:将图像转为张量并归一化
transform = transforms.Compose(
[transforms.ToTensor(),  # 转换为PyTorch张量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]  # 归一化到[-1, 1]范围
)

# 2. 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(
root='./data',        # 数据存储路径
train=True,           # 训练集
download=True,        # 如果本地没有数据则下载
transform=transform   # 应用数据转换
)

testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,          # 测试集
download=True,
transform=transform
)

# 3. 创建数据加载器
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=4,         # 批处理大小
shuffle=True,         # 训练时打乱数据顺序
num_workers=2         # 多进程加载数据
)

testloader = torch.utils.data.DataLoader(
testset,
batch_size=4,
shuffle=False,        # 测试时不打乱顺序
num_workers=2
)

# 4. 定义类别标签
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 三、数据可视化
# 定义图像显示函数
def imshow(img):
img = img / 2 + 0.5  # 反归一化
npimg = img.numpy()  # 转换为numpy数组
# 调整通道顺序:从(C, H, W)转为(H, W, C)
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()

# 获取一些随机的训练图像
dataiter = iter(trainloader)
images, labels = next(dataiter)

# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

# 四、构建CNN模型
class CNNNet(nn.Module):
def __init__(self):
super(CNNNet, self).__init__()
# 第一个卷积层:3输入通道,16输出通道,5x5卷积核
self.conv1 = nn.Conv2d(3, 16, 5)
# 第一个池化层:2x2池化核,步长为2
self.pool = nn.MaxPool2d(2, 2)
# 第二个卷积层:16输入通道,36输出通道,3x3卷积核
self.conv2 = nn.Conv2d(16, 36, 3)
# 第一个全连接层
self.fc1 = nn.Linear(36 * 6 * 6, 128)
# 第二个全连接层(输出层,10个类别)
self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
# 第一个卷积块:卷积->ReLU->池化
x = self.pool(F.relu(self.conv1(x)))
# 第二个卷积块:卷积->ReLU->池化
x = self.pool(F.relu(self.conv2(x)))
# 展平特征图
x = x.view(-1, 36 * 6 * 6)
# 第一个全连接层->ReLU
x = F.relu(self.fc1(x))
# 输出层
x = self.fc2(x)
return x

# 实例化模型并移动到可用设备(GPU/CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = CNNNet()
net.to(device)

# 五、定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适合多分类任务
optimizer = optim.SGD(             # 随机梯度下降优化器
net.parameters(), 
lr=0.001,                      # 学习率
momentum=0.9                   # 动量参数
)

# 六、训练模型
for epoch in range(10):  # 训练10个epoch
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# 获取输入数据和标签,并移动到设备
inputs, labels = data[0].to(device), data[1].to(device)

# 清零梯度
optimizer.zero_grad()

# 前向传播、计算损失、反向传播、参数更新
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

# 打印训练状态
running_loss += loss.item()
if i % 2000 == 1999:    # 每2000个批次打印一次
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
running_loss = 0.0

print('Finished Training')

# 七、模型测试
# 在测试集上进行预测
dataiter = iter(testloader)
images, labels = next(dataiter)

# 显示测试图像
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

# 进行预测
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs, 1)  # 获取预测概率最大的类别

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))

http://www.dtcms.com/a/411856.html

相关文章:

  • (附源码)医院门诊综合管理系统
  • 做外贸经常用的网站网站中单选按钮怎么做
  • 国家合同模板网站wordpress 首页伪静态
  • vite是什么
  • 建设银联官方网站帮别人发广告赚钱平台
  • 【轨物方案】轨物科技|以数智化技术赋能成套开关柜
  • Android Studio 编辑器汉化解决方法(超简单)
  • 网站运营与推广计划书怎么做做网站客户给不了素材
  • 保山网站建设多少钱wordpress 文章排序
  • STM32 Hardfault异常调试-笔记
  • 网站做的好坏主要看公司做网站好吗
  • 太原找工作网站网站怎么做显得简洁美观
  • 凤岗镇仿做网站做网站哪个语言好
  • Kanass入门到实战(4) - 如何快速导入Jira、Mantis数据
  • JavaScript 事件冒泡与事件捕获
  • 外贸网站源码怎么建wordpress使用百度分享插件下载
  • C语言基础【26】:结构体2
  • 项目计划书模板10篇win7优化大师
  • SQL Server提示:安装程序无法与下载服务器联系。请提供 Microsoft机器学习服务器安装文件的位置。。。。
  • 无人机表演行业二手设备市场与性价比分析
  • 快速建站公司怎么样wordpress读取父分类列表
  • 黄埔网站建设厦门网站排名
  • 好的ftp网站宁夏住房和城乡建设官方网站
  • Redis 7.0 新特性深度解读:迈向生产级的新纪元
  • wordpress网站实例网站怎么建设后台
  • JVM内存分配
  • 兴化网站建设网站开发用什么语言比较好
  • 四川网站建设找珊瑚云公司装修通知告示怎么写
  • 从 inode 角度深入分析软硬链接的内核实现与设计
  • 游戏开发中的状态管理与定时器