当前位置: 首页 > news >正文

使用 PyTorch 构建并训练 CNN 模型

卷积神经网络(CNN)在计算机视觉领域占据核心地位,尤其在图像分类任务中表现出色。CIFAR-10 数据集是入门计算机视觉的经典数据集,包含10 类(飞机、汽车、鸟等)、分辨率为32×32的彩色图像,非常适合验证 CNN 模型的效果。

本文将基于 PyTorch 框架,从数据预处理模型定义训练过程性能评估,完整演示如何构建并训练 CNN 模型完成 CIFAR-10 分类任务。

一、环境准备与库导入

首先导入所需的 Python 库,确保 PyTorch、TorchVision 等已安装(可通过pip install torch torchvision安装)。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from collections import Counter
  • torch:PyTorch 核心库,提供张量运算、自动微分等功能。
  • torch.nn:神经网络模块库,包含卷积、池化、全连接层等。
  • torch.optim:优化器库,如 SGD、Adam 等。
  • torchvision:计算机视觉工具库,提供数据集、图像变换等功能。
  • numpy:数值计算库,辅助数据处理。

二、数据预处理与加载

为了提升模型泛化能力,训练集需做数据增强(随机裁剪、水平翻转);测试集保持 “干净”,仅做标准化等基础变换。

2.1 数据变换定义

# 设备自动检测:优先使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 训练集变换:数据增强 + 标准化
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),  # 随机裁剪(填充4像素后裁为32×32)transforms.RandomHorizontalFlip(),     # 随机水平翻转transforms.ToTensor(),                 # 转为Tensor(范围缩至[0,1])transforms.Normalize(                  # 标准化(减均值、除标准差)mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])# 测试集变换:仅标准化(无数据增强)
transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

2.2 数据集与数据加载器

# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data',        # 数据保存路径train=True,           # 训练集download=False,       # 若本地已有,设为Falsetransform=transform_train
)
trainloader = DataLoader(trainset, batch_size=128,       # 批次大小shuffle=True,         # 打乱数据num_workers=2         # 多线程加载数据
)testset = torchvision.datasets.CIFAR10(root='./data', train=False,          # 测试集download=False, transform=transform_test
)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2
)# CIFAR-10的10个类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

数据增强通过增加训练样本多样性,减少模型过拟合;DataLoader负责按批次高效加载数据,提升训练效率。

三、CNN 模型定义

CNN 的核心是卷积层(提取局部特征)、池化层(缩小特征图并保留关键信息)、全连接层(分类决策)。下面定义两个模型:自定义 CNN(Net)和经典 LeNet。

3.1 自定义 CNN 模型(Net

class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 第一层卷积:输入3通道(彩色),输出16通道,卷积核5×5self.conv1 = nn.Conv2d(3, 16, 5)# 最大池化:核2×2,步长2self.pool1 = nn.MaxPool2d(2, 2)# 第二层卷积:输入16通道,输出36通道,卷积核5×5self.conv2 = nn.Conv2d(16, 36, 5)# 第二层池化self.pool2 = nn.MaxPool2d(2, 2)# 自适应平均池化:将特征图缩为1×1(通道数保留)self.aap = nn.AdaptiveAvgPool2d(1)# 全连接层:输入36(通道数),输出10(类别数)self.fc3 = nn.Linear(36, 10)def forward(self, x):# 卷积 → ReLU激活 → 池化x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))# 自适应池化(统一特征图尺寸)x = self.aap(x)# 展平(保留批次维度,其余维度合并)x = x.view(x.shape[0], -1)# 全连接层输出类别概率x = self.fc3(x)return x

3.2 经典 LeNet 模型

LeNet 是 CNN 的经典雏形,结构更简洁,适合理解 CNN 基本流程:

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)   # 输入3通道,输出6通道,核5×5self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 特征图经池化后尺寸为5×5self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# 卷积 → ReLU → 池化out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)# 展平为一维向量out = out.view(out.size(0), -1)# 全连接层 → ReLUout = F.relu(self.fc1(out))out = F.relu(self.fc2(out))# 输出层(无激活,配合交叉熵损失)out = self.fc3(out)return out

四、模型训练过程

训练的核心是 **“前向传播计算损失 → 反向传播求梯度 → 优化器更新参数”** 的循环。

4.1 初始化模型、损失函数与优化器

# 初始化模型并放到设备(GPU/CPU)
net = Net().to(device)
# 交叉熵损失(适合多分类任务)
criterion = nn.CrossEntropyLoss()
# SGD优化器(带动量,加速收敛)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4.2 训练循环

epochs = 10  # 训练轮数
for epoch in range(epochs):running_loss = 0.0  # 累计损失# 遍历训练数据加载器for i, data in enumerate(trainloader, 0):# 获取输入和标签,并转移到设备inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 1. 梯度清零(防止累积)optimizer.zero_grad()# 2. 前向传播:模型预测outputs = net(inputs)# 计算损失loss = criterion(outputs, labels)# 3. 反向传播:计算梯度loss.backward()# 4. 优化器更新参数optimizer.step()# 统计损失running_loss += loss.item()# 每2000个mini-batch打印一次平均损失if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')running_loss = 0.0print('Finished Training')

训练过程中,running_loss用于监控损失变化:若损失持续下降,说明模型在学习;若损失震荡或上升,可能是学习率过大或模型过拟合。

五、模型结构可视化(可选)

为了更清晰地了解模型各层的输入、输出和参数,我们实现类似 Keras model.summary()的功能:

import collectionsdef params_summary(input_size, model):def register_hook(module):def hook(module, input, output):# 获取模块类名class_name = str(module.__class__).split('.')[-1].split("'")[0]module_idx = len(summary)m_key = f'{class_name}-{module_idx+1}'summary[m_key] = collections.OrderedDict()# 记录输入/输出形状(批次维度设为-1,代表任意大小)summary[m_key]['input_shape'] = list(input[0].size())summary[m_key]['input_shape'][0] = -1summary[m_key]['output_shape'] = list(output.size())summary[m_key]['output_shape'][0] = -1# 统计参数数量params = 0if hasattr(module, 'weight') and hasattr(module.weight, 'size'):params += torch.prod(torch.LongTensor(list(module.weight.size())))summary[m_key]['trainable'] = module.weight.requires_gradif hasattr(module, 'bias') and hasattr(module.bias, 'size'):params += torch.prod(torch.LongTensor(list(module.bias.size())))summary[m_key]['nb_params'] = params# 仅对非容器类模块注册hookif not isinstance(module, nn.Sequential) and \not isinstance(module, nn.ModuleList) and \not (module == model):hooks.append(module.register_forward_hook(hook))# 生成随机输入用于前向传播if isinstance(input_size[0], (list, tuple)):x = [torch.rand(1, *in_size) for in_size in input_size]else:x = torch.rand(1, *input_size)summary = collections.OrderedDict()hooks = []# 注册所有模块的hookmodel.apply(register_hook)# 执行前向传播(触发hook)model(x)# 移除hook(避免影响后续操作)for h in hooks:h.remove()return summary# 查看Net的结构摘要
summary = params_summary((3, 32, 32), Net())
for layer, info in summary.items():print(f"Layer: {layer}")print(f"  Input Shape: {info['input_shape']}")print(f"  Output Shape: {info['output_shape']}")print(f"  Params: {info['nb_params']}")print(f"  Trainable: {info.get('trainable', False)}")print("-" * 50)

运行后,会打印每一层的输入形状、输出形状、参数数量和是否可训练,帮助我们验证模型结构是否符合预期。

六、模型评估(测试集性能)

训练完成后,在测试集上评估模型准确率(泛化能力):

correct = 0
total = 0
# 测试时无需计算梯度,加快速度
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)# 获取最大概率对应的类别_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'模型在10000张测试图像上的准确率: {100 * correct / total:.2f}%')

还可以进一步分析每个类别的准确率

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()# 假设batch_size为100,可根据实际调整循环次数for i in range(100):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print(f'{classes[i]}的准确率: {100 * class_correct[i] / class_total[i]:.2f}%')

七、总结与优化方向

本文完整演示了 “数据预处理→模型定义→训练→评估” 的深度学习流程。若想进一步提升性能,可尝试:

  1. 网络结构优化:增加卷积层、加入 BatchNorm 层、使用残差连接(ResNet 思想)。
  2. 超参数调整:增大学习率(配合学习率衰减)、调整批次大小、增加训练轮数。
  3. 优化器升级:改用 Adam 优化器(自适应学习率,收敛更快)。
  4. 正则化手段:加入 Dropout 层、权重衰减(L2 正则),缓解过拟合。

通过不断迭代优化,CIFAR-10 的分类准确率可逐步提升至 80% 以上~

http://www.dtcms.com/a/418702.html

相关文章:

  • 如何做电影网站狼视听seo外包优化服务商
  • blender布局工作区突然变得很卡
  • 【计算机视觉】图像去雾技术
  • 工信部网站icp备案号文艺范wordpress主题
  • 树莓派无法播放哔哩哔哩等视频
  • 华为芯片泄密案警示:用Curtain e-locker阻断内部数据泄露
  • 记一次达梦数据库的查询异常
  • 泸州市建设工程管理局网站58网站怎么做品牌推广
  • 个人主题网站设计论文北京seo推广系统
  • AI编程开发系统001-基于SpringBoot+Vue的旅游民宿租赁系统
  • 通用人工智能(AGI):从技术探索到社会重构的 2025 展望
  • 【Web前端|第五篇】Vue进阶(一):Axios工具和前端工程化
  • RISE论文阅读
  • LeetCode 416 分割等和子集
  • web开发,在线%车辆管理%系统,基于Idea,html,css,vue,java,springboot,mysql
  • 《安富莱嵌入式周报》第358期:USB4雷电开源示波器,2GHz带宽,3.2Gsps采样率,开源亚微米级精度3D运动控制平台,沉浸式8声道全景声音频录制
  • Axure: 多级多选可交互树状列表
  • 打破线制,告别电脑:积木易搭发布无线一体式3D扫描仪Toucan
  • 做电影网站的资源从哪里换wordpress新建音乐界面
  • Conda环境激活全指南:bash、conda activate与source activate详解
  • 英国网站后缀爱做的小说网站吗
  • 第四部分:VTK常用类详解(第98章 vtkBalloonWidget气球控件类)
  • Git 应用与规范指南
  • 查网站 备案信息有没有好的网站可以学做头发
  • Leetcode 14. 最长公共前缀
  • 在 Windows 上安装 WSL 并配置 SSH 服务,让 FinalShell 连接 Ubuntu
  • 【操作系统】进程 + 环境变量
  • win10离线安装.net framework3.5
  • 做网站时怎样图片上传怎么才能让图片不变形_有什么插件吗西安seo网站管理
  • 网站域名备案注册证书查询编程软件哪个好用