Deep Learning Reproduction: Implementing the CIFAR-10 Task (Test Set)
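What is global average pooling?
Global average pooling (GAP) is a special pooling operation that reduces each feature map to a single scalar. Specifically, for a feature map of size H×W, it computes the mean of all H×W elements and outputs one 1×1 value for that feature map.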
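The definition above can be sketched in a few lines of plain Python. This is only an illustration under simple assumptions: feature maps are nested lists, the values are made up, and the helper name `global_avg_pool` is hypothetical rather than a PyTorch API.

```python
def global_avg_pool(feature_maps):
    """Reduce each H x W feature map (a list of rows) to its mean value."""
    pooled = []
    for fmap in feature_maps:
        values = [v for row in fmap for v in row]  # flatten the H x W grid
        pooled.append(sum(values) / len(values))   # mean over all positions
    return pooled


# Two toy 2 x 2 feature maps with made-up values.
maps = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [8.0, 0.0]],
]
pooled = global_avg_pool(maps)
print(pooled)  # [2.5, 2.0]
```

Each map contributes exactly one number to the output, regardless of its height and width.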
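Global average pooling has the following advantages:
Far fewer parameters: it avoids the parameter explosion that fully connected layers cause
Lower risk of overfitting: with fewer parameters, model complexity drops
Stronger translation invariance: more robust to shifts in position within the input image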
Better interpretability: each feature map …
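To make the parameter saving concrete, here is a rough back-of-the-envelope sketch. It assumes 36 feature maps of size 5×5 feeding a 10-class classifier, which matches the network defined later in this post for 32×32 inputs. The "flatten" head is a hypothetical alternative for comparison, not something the post trains.

```python
def linear_params(in_features, out_features):
    # One weight per input/output pair, plus one bias per output unit.
    return in_features * out_features + out_features


channels, height, width, num_classes = 36, 5, 5, 10

# Hypothetical head: flatten the 36 x 5 x 5 maps, then one linear layer.
flatten_head = linear_params(channels * height * width, num_classes)

# GAP head: average each map to one value, then one linear layer.
gap_head = linear_params(channels, num_classes)

print(flatten_head, gap_head)  # 9010 370
```

Under these assumptions the classifier head shrinks by roughly a factor of 24.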
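Key points about RGB channels:
Three-channel input: the channels carry the red, green, and blue color information
Parallel processing: each convolution kernel processes all input channels at once
Feature fusion: information from the different channels is combined automatically during convolution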
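The points above can be illustrated with a single output value of a convolution, written in plain Python. This is a minimal sketch with made-up numbers: the helper `conv_at_origin` is hypothetical, and it evaluates only the top-left position, with no stride or padding handling.

```python
def conv_at_origin(image, kernel, bias=0.0):
    """One output value: sum over all channels and all kernel positions.

    image:  [C][H][W] nested lists, kernel: [C][k][k] nested lists.
    """
    total = bias
    for c, kernel_c in enumerate(kernel):
        for i, row in enumerate(kernel_c):
            for j, w in enumerate(row):
                total += w * image[c][i][j]
    return total


# Toy 3-channel 2 x 2 patch (R, G, B) with made-up values.
image = [
    [[1.0, 1.0], [1.0, 1.0]],  # R
    [[2.0, 2.0], [2.0, 2.0]],  # G
    [[3.0, 3.0], [3.0, 3.0]],  # B
]
# One kernel has a separate 2 x 2 slice for every input channel.
kernel = [
    [[1.0, 0.0], [0.0, 0.0]],  # picks one R value
    [[0.0, 1.0], [0.0, 0.0]],  # picks one G value
    [[0.0, 0.0], [0.0, 1.0]],  # picks one B value
]
value = conv_at_origin(image, kernel)
print(value)  # 6.0
```

The single result mixes contributions from all three channels, which is the "fusion" described in the list.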
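Why the channel count increases:
Feature diversity: more channels let the network detect more types of features
Hierarchical learning: deeper layers need more complex feature representations
Information capacity: more channels increase how much information the network can process
Performance: a moderate increase in channel count usually improves model performance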
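More channels are not free. As a rough sketch of the cost side, the snippet below counts the parameters of a 5×5 convolution with 16 input channels as its output width grows. The formula is the standard weight-plus-bias count for a 2D convolution; the widths 16 and 64 are arbitrary comparison points, while 36 is the width used in this post's network.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    # Each output channel owns in_channels * k * k weights plus one bias.
    return out_channels * (in_channels * kernel_size * kernel_size + 1)


widths = (16, 36, 64)
counts = [conv2d_params(16, w, 5) for w in widths]
for w, n in zip(widths, counts):
    print(w, n)
# 16 6416
# 36 14436
# 64 25664
```

The count grows linearly with the number of output channels, which is one reason the increase is usually kept moderate.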
Testing
import torch
import torchvision
import torchvision.transforms as transforms

# testloader and classes are assumed to be defined elsewhere (not shown in this section);
# net and device are defined in the global average pooling section.

# Overall accuracy on the test set
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)  # index of the highest score
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# Per-class accuracy
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)
        # Loop over the actual batch size instead of a hard-coded 4
        for i in range(labels.size(0)):
            label = labels[i].item()
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
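The bookkeeping in the test loop can be checked on a toy batch without PyTorch. The sketch below uses made-up scores for three classes; `per_class_accuracy` is a hypothetical helper that mirrors the argmax-and-count logic, not a function from the post.

```python
def per_class_accuracy(logits, labels, num_classes):
    correct = [0] * num_classes
    total = [0] * num_classes
    for scores, label in zip(logits, labels):
        # Predicted class = index of the highest score (argmax).
        predicted = max(range(len(scores)), key=scores.__getitem__)
        total[label] += 1
        if predicted == label:
            correct[label] += 1
    return correct, total


# Four toy samples with made-up scores.
logits = [
    [2.0, 0.1, 0.3],  # predicted 0
    [0.2, 1.5, 0.1],  # predicted 1
    [0.9, 0.2, 0.1],  # predicted 0 (wrong, true label is 1)
    [0.1, 0.2, 3.0],  # predicted 2
]
labels = [0, 1, 1, 2]

correct, total = per_class_accuracy(logits, labels, 3)
overall = 100 * sum(correct) / sum(total)
print(correct, total, overall)  # [1, 1, 1] [1, 2, 1] 75.0
```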
Global average pooling
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 36, 5)
        # self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Use a global average pooling layer
        self.aap = nn.AdaptiveAvgPool2d(1)
        self.fc3 = nn.Linear(36, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.aap(x)               # (N, 36, H, W) -> (N, 36, 1, 1)
        x = x.view(x.shape[0], -1)    # flatten to (N, 36)
        x = self.fc3(x)
        return x

net = Net()
net = net.to(device)

print("net_gvp have {} parameters in total".format(sum(x.numel() for x in net.parameters())))
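As a sanity check on the network above, the spatial sizes and the parameter total can be worked out by hand. The sketch assumes 32×32 CIFAR-10 inputs and the `nn.Conv2d` defaults of stride 1 and no padding. The helper names are hypothetical.

```python
def conv_out(size, kernel):
    # Stride 1, no padding.
    return size - kernel + 1


def pool_out(size, kernel=2, stride=2):
    return (size - kernel) // stride + 1


size = 32
after_conv1 = conv_out(size, 5)         # 28
after_pool1 = pool_out(after_conv1)     # 14
after_conv2 = conv_out(after_pool1, 5)  # 10
after_pool2 = pool_out(after_conv2)     # 5; GAP then reduces 5 x 5 to 1 x 1

conv1 = 16 * (3 * 5 * 5 + 1)    # 1216
conv2 = 36 * (16 * 5 * 5 + 1)   # 14436
fc3 = 36 * 10 + 10              # 370
total_params = conv1 + conv2 + fc3

print(after_pool2, total_params)  # 5 16022
```

If these assumptions hold, the `print` call in the post's code should report 16022 parameters.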