day9.27
在计算机视觉领域,卷积神经网络(CNN)是图像分类任务的核心工具。本文基于 PyTorch 框架,实现一个含全局平均池化的 CNN 模型,并从 “整体准确率” 和 “类别级准确率” 两个维度评估模型性能。
一、CNN 模型定义(引入全局平均池化)
全局平均池化(Global Average Pooling)可简化网络结构、减少参数数量,同时增强特征鲁棒性。以下是模型的完整定义:
python
运行
import torch.nn as nn
import torch.nn.functional as F# 自动选择设备:优先GPU,否则CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 第1层:卷积 + 最大池化self.conv1 = nn.Conv2d(3, 16, 5) # 输入3通道(RGB),输出16通道,卷积核5×5self.pool1 = nn.MaxPool2d(2, 2) # 池化核2×2,步长2# 第2层:卷积 + 最大池化self.conv2 = nn.Conv2d(16, 36, 5) # 输入16通道,输出36通道,卷积核5×5self.pool2 = nn.MaxPool2d(2, 2) # 池化核2×2,步长2# 全局平均池化:将任意尺寸特征图压缩为1×1self.aap = nn.AdaptiveAvgPool2d(1)# 全连接层:输入36(全局池化后的通道数),输出10(假设10分类)self.fc3 = nn.Linear(36, 10)def forward(self, x):# 第1轮:卷积 → ReLU激活 → 池化x = self.pool1(F.relu(self.conv1(x)))# 第2轮:卷积 → ReLU激活 → 池化x = self.pool2(F.relu(self.conv2(x)))# 全局平均池化x = self.aap(x)# 展平特征(为全连接层做准备)x = x.view(x.shape[0], -1)# 全连接层输出类别得分x = self.fc3(x)return x# 实例化模型并部署到设备
net = Net()
net = net.to(device)# 统计模型总参数数量
print("模型总参数数量:{}".format(sum(x.numel() for x in net.parameters())))
核心层解析:
- 卷积层(
nn.Conv2d
):提取图像局部特征(如边缘、纹理),通过 “输入通道→输出通道→卷积核大小” 定义。 - 最大池化层(
nn.MaxPool2d
):对特征图降采样,减少计算量同时保留关键特征。 - 全局平均池化(
nn.AdaptiveAvgPool2d(1)
):将每个通道的特征图压缩为单个平均值,替代传统 “大尺寸全连接层”,简化结构且更鲁棒。 - 全连接层(
nn.Linear
):将全局池化后的特征映射到 “类别空间”(如 10 分类任务输出 10 个得分)。
二、模型评估:整体与类别准确率
模型性能需在测试集上验证。我们不仅要计算 “整体准确率”,还要分析 “每个类别的准确率”,以定位模型的薄弱环节。
python
运行
# ========== 1. 整体准确率计算 ==========
correct = 0
total = 0
with torch.no_grad(): # 测试阶段关闭梯度,节省内存for data in testloader: # testloader为测试集数据加载器images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images) # 前向传播得到类别得分_, predicted = torch.max(outputs.data, 1) # 取得分最高的类别total += labels.size(0) # 累加总样本数correct += (predicted == labels).sum().item() # 累加正确数print('模型在测试集上的整体准确率: %d %%' % (100 * correct / total))# ========== 2. 类别级准确率计算 ==========
class_correct = list(0. for i in range(10)) # 每个类别的正确数
class_total = list(0. for i in range(10)) # 每个类别的总样本数
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze() # 压缩匹配结果的维度for i in range(4): # 假设每个批次含4个样本label = labels[i]class_correct[label] += c[i].item() # 累加当前类别的正确数class_total[label] += 1 # 累加当前类别的总样本数# 打印每个类别的准确率
for i in range(10):print('类别 %5s 的准确率: %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
评估逻辑解析:
- 整体准确率:遍历测试集所有样本,统计 “预测类别与真实标签一致” 的比例。
with torch.no_grad()
关闭梯度计算,避免测试阶段的冗余内存消耗。 - 类别级准确率:为每个类别单独统计 “正确数” 和 “总样本数”,最终计算每个类别的准确率。这能帮助我们发现模型在哪些类别上 “偏科”(如对 “猫” 分类准度低,但对 “汽车” 分类准度高)。
三、结果与分析(示例输出)
运行代码后,典型输出如下(以模拟结果为例):
plaintext
模型总参数数量:16022
模型在测试集上的整体准确率: 66 %
类别 plane 的准确率: 72 %
类别 car 的准确率: 82 %
类别 bird 的准确率: 51 %
类别 cat 的准确率: 45 %
...(剩余类别省略)
- 参数效率:模型仅约 1.6 万参数,这是全局平均池化的优势 —— 替代了传统大参数量的全连接层,让模型更 “轻量”。
- 整体性能:66% 的整体准确率反映模型对测试集的综合分类能力,但仍有优化空间。
- 类别差异:不同类别准确率差距明显(如 “car” 82% vs “cat” 45%),说明模型在某些类别(如 “cat”)的特征提取或分类能力不足,需针对性优化(如数据增强、调整网络结构)。
四、总结
本文基于 PyTorch 完成了 “含全局平均池化的 CNN 模型定义” 与 “多维度性能评估”:
- 模型设计:通过 “卷积 + 池化 + 全局平均池化 + 全连接” 的流程,在简化结构的同时保留特征表达能力。
- 评估维度:从 “整体准确率” 和 “类别级准确率” 双维度分析模型,既看综合性能,也定位薄弱环节。
若需进一步优化,可尝试数据增强、调整网络深度、引入正则化等策略。欢迎在评论区交流讨论~