当前位置: 首页 > news >正文

PyTorch CNN 改进:全局平均池化与 CIFAR10 测试分析

昨天我们基于 CIFAR10 数据集搭建基础 CNN 并完成训练,今天继续对模型进行结构优化(引入全局平均池化),同时开展测试阶段的细致性能分析

一、网络优化:全局平均池化(Global Average Pooling)的引入

在深度学习中,全局平均池化(GAP) 常被用于替代 “全连接层前的展平 + 大参数量全连接” 结构,优势显著:

  • 减少参数量:大幅降低模型复杂度,缓解过拟合风险;
  • 保留全局特征:对整个特征图做平均,能捕获全局空间信息,提升泛化能力。

我们重新定义Net类来实践这一改进:

python

运行

import torch.nn as nn
import torch.nn.functional as F
# 自动选择设备(GPU优先,无GPU则用CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 前两层卷积+池化(结构与基础CNN类似)self.conv1 = nn.Conv2d(3, 16, 5)   # 输入3通道,输出16通道,卷积核5×5self.pool1 = nn.MaxPool2d(2, 2)    # 最大池化,核2×2、步长2self.conv2 = nn.Conv2d(16, 36, 5)  # 输入16通道,输出36通道,卷积核5×5self.pool2 = nn.MaxPool2d(2, 2)    # 第二次最大池化# 核心改进:全局平均池化层self.aap = nn.AdaptiveAvgPool2d(1)  # 自适应将特征图压缩为1×1# 简化的全连接层(因GAP后特征维度固定,参数量骤减)self.fc3 = nn.Linear(36, 10)        # 输入36(GAP后每个通道的平均值),输出10类(匹配CIFAR10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = self.aap(x)           # 全局平均池化x = x.view(x.shape[0], -1)  # 展平(保留批量维度,压缩特征维度)x = self.fc3(x)return xnet = Net()
net = net.to(device)  # 将模型部署到指定设备

参数量对比:改进后模型总参数量仅约 1.6 万(此前基础 CNN 参数量约 17 万),参数量减少超 90%,模型更轻量化。

二、模型测试:从 “整体准确率” 到 “类别级准确率” 的深度分析

训练完成后,需在测试集验证模型性能。我们不仅要关注整体准确率,还要分析每个类别上的表现(模型对不同类别易出现 “偏科”)。

1. 整体测试集准确率

遍历测试集所有样本,统计预测正确的比例:

python

运行

correct = 0
total = 0
with torch.no_grad():  # 测试阶段无需计算梯度,加速运算for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)# 取预测概率最大的类别_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('测试集整体准确率: %d %%' % (100 * correct / total))

输出结果(示例):测试集整体准确率: 66 %轻量化模型在 10000 张测试图像上达到 66% 准确率,性能与复杂度的平衡尚可。

2. 类别级准确率分析

进一步统计每个类别的预测准确率,可细致发现模型的 “强项” 与 “弱项”:

python

运行

# 初始化每个类别的正确数与总数
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)# 逐样本判断是否预测正确c = (predicted == labels).squeeze()for i in range(4):  # 每批处理4张图像(batch_size=4)label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1# 打印每个类别的准确率
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
for i in range(10):print('类别 %5s 的准确率: %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

输出结果示例(因训练细节略有差异):

plaintext

类别 plane 的准确率: 72 %
类别   car 的准确率: 82 %
类别  bird 的准确率: 51 %
类别   cat 的准确率: 45 %
...(其余类别省略)

从结果可观察到:

  • 模型对 car(汽车) 这类物体的识别准确率最高(82%),说明特征提取对 “汽车” 的模式捕捉更到位;
  • 对 cat(猫) 的识别准确率较低(45%),可能因 “猫” 的形态多样(姿态、毛色等),增加了识别难度;
  • 不同类别准确率的差异,也反映了 CIFAR10 数据集中各类别样本的 “易区分度” 不同。

三、改进思考:全局平均池化的价值

对比昨天的 “基础 CNN”(参数量大、全连接层复杂),今天的改进有两大明显优势:

  1. 模型更轻量:参数量从 17 万 + 骤减到 1.6 万,训练 / 推理速度更快,且更难过拟合;
  2. 特征更全局:全局平均池化替代 “展平 + 大全连接”,能更好利用特征图的全局信息,一定程度提升泛化性。

当然,准确率仍有提升空间(如增加训练轮次、加入数据增强、调整网络深度等),但从 “结构简化 + 性能平衡” 的角度,全局平均池化是非常实用的改进技巧~

通过今天的实践,我们不仅优化了网络结构,还学会了从 “整体 + 类别” 两个维度分析模型性能 —— 这对深度学习任务的迭代优化至关重要~

http://www.dtcms.com/a/414060.html

相关文章:

  • 精读C++20设计模式——创造型设计模式:单例模式
  • 网络实践——基于epoll_ET工作、Reactor设计模式的HTTP服务
  • 设计模式-行为型设计模式(针对对象之间的交互)
  • 选手机网站彩票网站开发制作模版
  • qq钓鱼网站在线生成器北京网站设计公司地址
  • SQL流程控制函数完全指南
  • 做电商网站前端的技术选型是移动商城积分和积分区别
  • 弄一个关于作文的网站怎么做微信分销网站建设官网
  • 怎么做站旅游网站上泡到妞平面设计师服务平台
  • 温室大棚建设 网站及排名转卖类似淘宝网站建设有哪些模板
  • 广西网站建设-好发信息网阿里邮箱 wordpress
  • 便捷网站建设费用搜关键词网站
  • 网站添加百度地图导航wordpress安装 centos
  • 如何自己建一个网站企业简介宣传片视频
  • 成都美誉网站设计建设优惠券网站
  • 整形网站源码一个网站如何做盈利
  • 机械设备东莞网站建设石家庄开发区网站建设
  • 代制作网站公司网站建设包括
  • 怎么手动安装网站程序搭建微信小程序
  • 郑州建网站371怎么把东西发布到网上卖
  • wordpress 点图片链接拼多多seo怎么优化
  • 石家庄做网站wordpress 文章摘要
  • 网站建设服务类型现状做兼职上哪个网站
  • 重庆网站seo排名用dw制作一个网站
  • 太原模板建站定制深圳网站建设及推广
  • vps 网站 需要绑定域名吗建设部网站拆除资质
  • 六安网站自然排名优化价格遵义网站建设网帮你
  • 网站版面设计流程包括哪些盐城手机网站建设
  • 重庆网站搭建昆明网站建设报价
  • 设计制作网站的公司深圳全网整合营销