当前位置: 首页 > news >正文

深度学习复现:CIFAR-10 数据集任务的实现(测试集)

什么是全局平均池化?

全局平均池化是一种特殊的池化操作,它将每个特征图直接降维为一个标量值。具体来说,对于一个尺寸为H×W的特征图,全局平均池化会计算该特征图上所有元素的平均值,最终输出一个1×1的特征值。

全局平均池化具有以下优势:

  1. 显著减少参数量:避免全连接层带来的参数爆炸

  2. 降低过拟合风险:由于参数减少,模型复杂度降低

  3. 增强平移不变性:对输入图像的位置变化更加鲁棒

  4. 更好的可解释性:每个特征图对

RGB通道的关键点:

  1. 三通道输入:分别对应红、绿、蓝颜色信息

  2. 并行处理:每个卷积核同时处理所有输入通道

  3. 特征融合:不同通道的信息在卷积过程中自动融合

通道数增加的原因:

  1. 特征多样性:更多通道意味着可以检测更多类型的特征

  2. 层次化学习:深层网络需要更复杂的特征表示

  3. 信息容量:增加通道数提高网络的信息处理能力

  4. 性能提升:适度的通道数增加通常能提升模型性能

测试

import torch
import torchvision
import torchvision.transforms as transformscorrect = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

全局平均池化

import torch.nn as nn
import torch.nn.functional as Fdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 36, 5)#self.fc1 = nn.Linear(16 * 5 * 5, 120)self.pool2 = nn.MaxPool2d(2, 2)#使用全局平均池化层self.aap = nn.AdaptiveAvgPool2d(1)self.fc3 = nn.Linear(36, 10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = self.aap(x)x = x.view(x.shape[0], -1)x = self.fc3(x)return xnet = Net()
net = net.to(device)print("net_gvp have {} parameters in total".format(sum(x.numel() for x in net.parameters())))

http://www.dtcms.com/a/415481.html

相关文章:

  • 【Spring 1】Spring IoC:颠覆传统编程的控制反转艺术
  • 如何为网站做面包屑导航网站必须要备案吗
  • AI 动画视频创作:技巧升级与行业未来趋势
  • 数字化转型:概念性名词浅谈(第五十三讲)
  • 制作网站参考案例wordpress推介联盟
  • 当遇到人生低谷期,该怎么度过?别装坚强,熬过去才是真本事
  • 电商网站开发报价单濮阳网站建设陈帅
  • 医联媒体网站建设网站建设网站制作公司
  • Detectron2 - 下一代目标检测与分割算法库
  • CSS过渡效果完全指南
  • 木门行业网站该怎么做封面制作网站
  • AIPyApp - Python 智能执行环境
  • 深度学习中Bootstrap详解
  • 网站关键字优化合同深圳网站制作公司资讯
  • 网络销售型网站有哪些内容百度推广培训机构
  • html制作一个个人主页网站wordpress首页调用指定文章
  • 安宝特科技丨【行业首发】Vuzix LX1智能眼镜:仓储物流的下一代智能助手
  • 无锡建行网站重庆网站备案最快几天
  • 河津网站建设湖南建设工程信息网官网
  • Ubuntu服务器版增加中文支持
  • 宁波网站推广营销江苏中南建设集团网站是多少
  • 那些网站企业可以免费展示动画制作软件flash官方下载
  • C++笔记(面向对象)类的定义
  • 电子信息工程专业课《数字信号处理》课程简介
  • 【攻防实战】对抗中的特殊场景上线cs和msf
  • 如何利用网站做demo北京公司网站建设报价
  • 显存带宽瓶颈突破:基于TensorRT的实时4K视频渲染优化
  • 陕西网站制作公司宁波网站建设排名
  • 网站开发设计会议网站怎么做
  • OpenHarmony中的系统服务管理配置讲解