当前位置: 首页 > news >正文

day9.27

在计算机视觉领域,卷积神经网络(CNN)是图像分类任务的核心工具。本文基于 PyTorch 框架,实现一个含全局平均池化的 CNN 模型,并从 “整体准确率” 和 “类别级准确率” 两个维度评估模型性能。

一、CNN 模型定义(引入全局平均池化)

全局平均池化(Global Average Pooling)可简化网络结构、减少参数数量,同时增强特征鲁棒性。以下是模型的完整定义:

python

运行

import torch.nn as nn
import torch.nn.functional as F# 自动选择设备:优先GPU,否则CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 第1层:卷积 + 最大池化self.conv1 = nn.Conv2d(3, 16, 5)   # 输入3通道(RGB),输出16通道,卷积核5×5self.pool1 = nn.MaxPool2d(2, 2)    # 池化核2×2,步长2# 第2层:卷积 + 最大池化self.conv2 = nn.Conv2d(16, 36, 5)  # 输入16通道,输出36通道,卷积核5×5self.pool2 = nn.MaxPool2d(2, 2)    # 池化核2×2,步长2# 全局平均池化:将任意尺寸特征图压缩为1×1self.aap = nn.AdaptiveAvgPool2d(1)# 全连接层:输入36(全局池化后的通道数),输出10(假设10分类)self.fc3 = nn.Linear(36, 10)def forward(self, x):# 第1轮:卷积 → ReLU激活 → 池化x = self.pool1(F.relu(self.conv1(x)))# 第2轮:卷积 → ReLU激活 → 池化x = self.pool2(F.relu(self.conv2(x)))# 全局平均池化x = self.aap(x)# 展平特征(为全连接层做准备)x = x.view(x.shape[0], -1)# 全连接层输出类别得分x = self.fc3(x)return x# 实例化模型并部署到设备
net = Net()
net = net.to(device)# 统计模型总参数数量
print("模型总参数数量:{}".format(sum(x.numel() for x in net.parameters())))

核心层解析:

  • 卷积层(nn.Conv2d:提取图像局部特征(如边缘、纹理),通过 “输入通道→输出通道→卷积核大小” 定义。
  • 最大池化层(nn.MaxPool2d:对特征图降采样,减少计算量同时保留关键特征。
  • 全局平均池化(nn.AdaptiveAvgPool2d(1):将每个通道的特征图压缩为单个平均值,替代传统 “大尺寸全连接层”,简化结构且更鲁棒。
  • 全连接层(nn.Linear:将全局池化后的特征映射到 “类别空间”(如 10 分类任务输出 10 个得分)。

二、模型评估:整体与类别准确率

模型性能需在测试集上验证。我们不仅要计算 “整体准确率”,还要分析 “每个类别的准确率”,以定位模型的薄弱环节。

python

运行

# ========== 1. 整体准确率计算 ==========
correct = 0
total = 0
with torch.no_grad():  # 测试阶段关闭梯度,节省内存for data in testloader:  # testloader为测试集数据加载器images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)  # 前向传播得到类别得分_, predicted = torch.max(outputs.data, 1)  # 取得分最高的类别total += labels.size(0)  # 累加总样本数correct += (predicted == labels).sum().item()  # 累加正确数print('模型在测试集上的整体准确率: %d %%' % (100 * correct / total))# ========== 2. 类别级准确率计算 ==========
class_correct = list(0. for i in range(10))  # 每个类别的正确数
class_total = list(0. for i in range(10))    # 每个类别的总样本数
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()  # 压缩匹配结果的维度for i in range(4):  # 假设每个批次含4个样本label = labels[i]class_correct[label] += c[i].item()  # 累加当前类别的正确数class_total[label] += 1             # 累加当前类别的总样本数# 打印每个类别的准确率
for i in range(10):print('类别 %5s 的准确率: %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

评估逻辑解析:

  • 整体准确率:遍历测试集所有样本,统计 “预测类别与真实标签一致” 的比例。with torch.no_grad() 关闭梯度计算,避免测试阶段的冗余内存消耗。
  • 类别级准确率:为每个类别单独统计 “正确数” 和 “总样本数”,最终计算每个类别的准确率。这能帮助我们发现模型在哪些类别上 “偏科”(如对 “猫” 分类准度低,但对 “汽车” 分类准度高)。

三、结果与分析(示例输出)

运行代码后,典型输出如下(以模拟结果为例):

plaintext

模型总参数数量:16022
模型在测试集上的整体准确率: 66 %
类别 plane 的准确率: 72 %
类别   car 的准确率: 82 %
类别  bird 的准确率: 51 %
类别   cat 的准确率: 45 %
...(剩余类别省略)
  • 参数效率:模型仅约 1.6 万参数,这是全局平均池化的优势 —— 替代了传统大参数量的全连接层,让模型更 “轻量”。
  • 整体性能:66% 的整体准确率反映模型对测试集的综合分类能力,但仍有优化空间。
  • 类别差异:不同类别准确率差距明显(如 “car” 82% vs “cat” 45%),说明模型在某些类别(如 “cat”)的特征提取或分类能力不足,需针对性优化(如数据增强、调整网络结构)。

四、总结

本文基于 PyTorch 完成了 “含全局平均池化的 CNN 模型定义” 与 “多维度性能评估”:

  • 模型设计:通过 “卷积 + 池化 + 全局平均池化 + 全连接” 的流程,在简化结构的同时保留特征表达能力。
  • 评估维度:从 “整体准确率” 和 “类别级准确率” 双维度分析模型,既看综合性能,也定位薄弱环节。

若需进一步优化,可尝试数据增强、调整网络深度、引入正则化等策略。欢迎在评论区交流讨论~

http://www.dtcms.com/a/415222.html

相关文章:

  • 做动画人设有哪些网站可以借鉴谷歌chrome浏览器下载
  • c++ 之多态虚函数表
  • 全屏网站 图片优化网站主机免费
  • 谷歌广告联盟网站同一个网站绑定多个域名
  • Java 大视界 -- Java 大数据机器学习模型在金融产品创新与客户需求匹配中的实战应用(417)
  • 美团网站是用什么做的网站开发企业开发
  • C语言风格哈希表vs C++风格哈希表的区别
  • 做数据分析网站做网站与数据库的关系
  • 六节tslib移植 、Qt移植到嵌入式linux
  • 做动漫图片的网站seo推广费用
  • 设计模式与原则精要
  • asp网站怎么做301定向系统商店
  • 大连html5网站建设价格泉州快速建站模板
  • LeetCode:64.搜索二维矩阵
  • 特殊矩阵的压缩存储
  • Qwen3-Omni多模态prompt输入解析
  • CVPR-2025 | 具身导航指令高效生成!MAPInstructor:基于场景图的导航指令生成Prompt调整策略
  • PRP (Product Requirement Prompts) - AI辅助开发提示词库
  • 昆明网站seo多少钱金舵设计园在线设计平台
  • AI识图 + MinIO图床 + 钉钉推送:打造全自动水质监测系统
  • EIGRP
  • 旅游电子商务网站开发方案网站运营数据周报表怎么做
  • 计算机视觉:人脸关键点定位与轮廓绘制
  • 手机网站建设基本流程专业的集团网站开发开发
  • Spring AI Alibaba:Java生态下的智能体开发全栈解决方案
  • 这么做网站网站三合一
  • Kurt-Blender零基础教程:第3章:材质篇——第3节:给模型上材质
  • Unity-导航寻路系统
  • 辽宁网站建设学校赣州建设局网站
  • 高功耗显卡兼容性难题全解析