PyTorch CNN 改进:全局平均池化与 CIFAR10 测试分析
昨天我们基于 CIFAR10 数据集搭建基础 CNN 并完成训练,今天继续对模型进行结构优化(引入全局平均池化),同时开展测试阶段的细致性能分析~
一、网络优化:全局平均池化(Global Average Pooling)的引入
在深度学习中,全局平均池化(GAP) 常被用于替代 “全连接层前的展平 + 大参数量全连接” 结构,优势显著:
- 减少参数量:大幅降低模型复杂度,缓解过拟合风险;
- 保留全局特征:对整个特征图做平均,能捕获全局空间信息,提升泛化能力。
我们重新定义Net
类来实践这一改进:
python
运行
import torch.nn as nn
import torch.nn.functional as F
# 自动选择设备(GPU优先,无GPU则用CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 前两层卷积+池化(结构与基础CNN类似)self.conv1 = nn.Conv2d(3, 16, 5) # 输入3通道,输出16通道,卷积核5×5self.pool1 = nn.MaxPool2d(2, 2) # 最大池化,核2×2、步长2self.conv2 = nn.Conv2d(16, 36, 5) # 输入16通道,输出36通道,卷积核5×5self.pool2 = nn.MaxPool2d(2, 2) # 第二次最大池化# 核心改进:全局平均池化层self.aap = nn.AdaptiveAvgPool2d(1) # 自适应将特征图压缩为1×1# 简化的全连接层(因GAP后特征维度固定,参数量骤减)self.fc3 = nn.Linear(36, 10) # 输入36(GAP后每个通道的平均值),输出10类(匹配CIFAR10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = self.aap(x) # 全局平均池化x = x.view(x.shape[0], -1) # 展平(保留批量维度,压缩特征维度)x = self.fc3(x)return xnet = Net()
net = net.to(device) # 将模型部署到指定设备
参数量对比:改进后模型总参数量仅约 1.6 万(此前基础 CNN 参数量约 17 万),参数量减少超 90%,模型更轻量化。
二、模型测试:从 “整体准确率” 到 “类别级准确率” 的深度分析
训练完成后,需在测试集验证模型性能。我们不仅要关注整体准确率,还要分析每个类别上的表现(模型对不同类别易出现 “偏科”)。
1. 整体测试集准确率
遍历测试集所有样本,统计预测正确的比例:
python
运行
correct = 0
total = 0
with torch.no_grad(): # 测试阶段无需计算梯度,加速运算for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)# 取预测概率最大的类别_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('测试集整体准确率: %d %%' % (100 * correct / total))
输出结果(示例):测试集整体准确率: 66 %
轻量化模型在 10000 张测试图像上达到 66% 准确率,性能与复杂度的平衡尚可。
2. 类别级准确率分析
进一步统计每个类别的预测准确率,可细致发现模型的 “强项” 与 “弱项”:
python
运行
# 初始化每个类别的正确数与总数
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)# 逐样本判断是否预测正确c = (predicted == labels).squeeze()for i in range(4): # 每批处理4张图像(batch_size=4)label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1# 打印每个类别的准确率
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
for i in range(10):print('类别 %5s 的准确率: %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
输出结果示例(因训练细节略有差异):
plaintext
类别 plane 的准确率: 72 %
类别 car 的准确率: 82 %
类别 bird 的准确率: 51 %
类别 cat 的准确率: 45 %
...(其余类别省略)
从结果可观察到:
- 模型对 car(汽车) 这类物体的识别准确率最高(82%),说明特征提取对 “汽车” 的模式捕捉更到位;
- 对 cat(猫) 的识别准确率较低(45%),可能因 “猫” 的形态多样(姿态、毛色等),增加了识别难度;
- 不同类别准确率的差异,也反映了 CIFAR10 数据集中各类别样本的 “易区分度” 不同。
三、改进思考:全局平均池化的价值
对比昨天的 “基础 CNN”(参数量大、全连接层复杂),今天的改进有两大明显优势:
- 模型更轻量:参数量从 17 万 + 骤减到 1.6 万,训练 / 推理速度更快,且更难过拟合;
- 特征更全局:全局平均池化替代 “展平 + 大全连接”,能更好利用特征图的全局信息,一定程度提升泛化性。
当然,准确率仍有提升空间(如增加训练轮次、加入数据增强、调整网络深度等),但从 “结构简化 + 性能平衡” 的角度,全局平均池化是非常实用的改进技巧~
通过今天的实践,我们不仅优化了网络结构,还学会了从 “整体 + 类别” 两个维度分析模型性能 —— 这对深度学习任务的迭代优化至关重要~