当前位置: 首页 > news >正文

DenseNet详解与实现

DenseNet详解与实现

    • 0. 前言
    • 1. DenseNet 架构
      • 1.1 DenseNet 原理
      • 1.2 瓶颈层与过渡层
    • 2. 构建 DenseNet

0. 前言

DenseNet 允许每个卷积直接访问输入和较低层的特征图,从而进一步改进了 ResNet。 通过利用瓶颈层 (Bottleneck) 和过渡层 (Transition),还可以使深度网络中的参数数量保持较低。

1. DenseNet 架构

我们已经学习了 ResNet 解决深度卷积网络中消失的梯度问题,DenseNet 使用另一种方法来解决梯度消失的问题。

1.1 DenseNet 原理

DenseNet

所有先前的特征图都将成为下一层的输入。请注意,第 lll 层的输入是所有先前特征图的串联。如果用操作 H(x)H(x)H(x) 表示 BN-ReLU-Conv2D,则层 lll 的输出为:
xl=H(x0,x1,x2,...,xl−1)x_l = H (x_0,x_1,x_2,...,x_{l-1}) xl=H(x0,x1,x2,...,xl1)
Conv2D 使用大小为 3 的卷积核。每层生成的特征图的数量称为增长率 kkk。通常,k=12k = 12k=12。因此,如果特征图 x0x_0x0 的数量为 k0k_0k0,则 4 层密集块末尾的特征图的总数将为 4k+k04k + k_04k+k0
DenseNet 建议在Dense块之前加上 BN-ReLU-Conv2D,这些特征图的数量是增长率的两倍:k0=2kk_0 = 2kk0=2k。在密集块的末尾,特征图的总数将为 4k+2k=6k4k + 2k = 6k4k+2k=6k
在输出层,DenseNet 建议在带有 softmax 层的 Dense() 之前执行平均池化。如果未使用数据增强,则必须在 DenseConv2D 之后跟随一个 dropout 层。

1.2 瓶颈层与过渡层

随着网络的深入,将出现两个新问题。首先,由于每一层都增加了 kkk 个特征图,因此层 lll 的输入数量为 (l–1)k+k0(l – 1)k + k_0(l–1)k+k0。 特征图在深层网络中快速增长,从而减慢了计算速度。
其次,与 ResNet 相似,随着网络的不断深入,特征图的大小将减小,从而增加了核的感受野大小。 如果 DenseNet 在合并操作中使用串联,则必须协调大小上的差异。
为了防止特征图的数量增加导致计算效率低,DenseNet 引入了瓶颈层。这个想法是,在每次串联之后,现在应用 1 x 1 卷积,其卷积核数量为 4k4k4k。这种降维技术可防止特征图的数量迅速增加。
然后,瓶颈层将 DenseNet 层修改为 BN-ReLU-Conv2D(1)-BN-ReLU-Conv2D(3),而不仅仅是 BN-ReLU-Conv2D(3)。将卷积核大小作为 Conv2D 的参数。对于瓶颈层,每个 Conv2D(3) 仅处理 4k4k4k 特征图,而不是处理 lll 层的 (l–1)k+k0(l – 1)k + k_0(l–1)k+k0。 例如,对于 101 层网络,最后一个 Conv2D(3) 的输入仍然是 48 个特征图,其中 k = 12

瓶颈层

为了解决特征图尺寸不匹配的问题,DenseNet 将深度网络划分为多个 Dense 块,这些块通过过渡层连接在一起。 在每个密集块中,特征图的大小保持不变。
过渡层的作用是在两个 Dense 块之间从一个特征图大小过渡到较小的特征图大小。 尺寸减小通常为一半。 这是通过平均池化层完成的。过渡层的输入是前一个 Dense 块中最后一个串联层的输出。

过渡层

但是,在将特征图传递到平均池化之前,使用 Conv2D(1) 将其数量减少通过压缩因子 0<θ<10<θ<10<θ<1DenseNet 中使用 θ=0.5θ= 0.5θ=0.5。例如,上一个 Dense 块的最后一个串联的输出为 (64,64,512),则在 Conv2D(1) 之后,特征图的新尺寸为 (64,64,256)。压缩和降维在一起,过渡层由 BN-Conv2D(1)-AveragePooling2D 层组成。实际上,批归一化在卷积层之前。

2. 构建 DenseNet

我们已经介绍了 DenseNet 的重要概念。接下来,我们将使用 tf.kerasCIFAR10 数据集构建 DenseNet 模型。模型架构如下所示:

模型架构
根据模型架构实现 DenseNet 模型:

inputs = keras.layers.Input(shape=input_shape)
x = keras.layers.BatchNormalization()(inputs)
x = keras.layers.Activation('relu')(x)
x = keras.layers.Conv2D(num_filters_bef_dense_block,kernel_size=3,padding='same',kernel_initializer='he_normal')(x)
x = keras.layers.concatenate([inputs,x])for i in range(num_dense_blocks):# 瓶颈层for j in range(num_bottleneck_layers):y = keras.layers.BatchNormalization()(x)y = keras.layers.Activation('relu')(y)y = keras.layers.Conv2D(4 * growth_rate,kernel_size=1,padding='same',kernel_initializer='he_normal')(y)if not data_augmentation:y = keras.layers.Dropout(0.2)(y)y = keras.layers.BatchNormalization()(y)y = keras.layers.Activation('relu')(y)y = keras.layers.Conv2D(growth_rate,kernel_size=3,padding='same',kernel_initializer='he_normal')(y)if not data_augmentation:y = keras.layers.Dropout(0.2)(y)x = keras.layers.concatenate([x,y])if i == num_dense_blocks - 1:continue#压缩特征图数量,并减小特征图尺寸num_filters_bef_dense_block += num_bottleneck_layers * growth_ratenum_filters_bef_dense_block = int(num_filters_bef_dense_block * compression_factor)y = keras.layers.BatchNormalization()(x)y = keras.layers.Conv2D(num_filters_bef_dense_block,kernel_size=1,padding='same',kernel_initializer='he_normal')(y)if not data_augmentation:y = keras.layers.Dropout(0.2)(y)x = keras.layers.AveragePooling2D()(y)x = keras.layers.AveragePooling2D(pool_size=8)(x)
x = keras.layers.Flatten()(x)
outputs = keras.layers.Dense(num_classes,kernel_initializer='he_normal',activation='softmax')(x)model = keras.Model(inputs=inputs,outputs=outputs)
model.compile(loss='categorical_crossentropy',optimizer=keras.optimizers.RMSprop(1e-3),metrics=['acc'])
model.summary()

使用 tf.keras 实现的 DenseNet 模型训练 200epoch 后,准确率达到 93.74%,训练过程中采用了数据增强技术,如果需构建更深层的 DenseNet 模型,需调整 growth_ratedepth 参数。


文章转载自:

http://svZPMlmc.wxcsm.cn
http://u8gs6XZO.wxcsm.cn
http://fylaqfmY.wxcsm.cn
http://13mctzoV.wxcsm.cn
http://c84URFYP.wxcsm.cn
http://K3e1Ij8p.wxcsm.cn
http://Fa6SKLpw.wxcsm.cn
http://MbV1w042.wxcsm.cn
http://epEdCpr9.wxcsm.cn
http://yrSvvwlO.wxcsm.cn
http://U2HGa8Tg.wxcsm.cn
http://UdMTkSds.wxcsm.cn
http://7vakPxPz.wxcsm.cn
http://nolibaAC.wxcsm.cn
http://YCHrs0Yg.wxcsm.cn
http://EiVtKZIg.wxcsm.cn
http://nDAGZDpx.wxcsm.cn
http://zffd9Yx9.wxcsm.cn
http://j1KMEcGB.wxcsm.cn
http://n0APJUEO.wxcsm.cn
http://UufnCiQX.wxcsm.cn
http://jYqzgeZP.wxcsm.cn
http://YwvVjYfE.wxcsm.cn
http://HXPrpTLz.wxcsm.cn
http://iCDI7iQL.wxcsm.cn
http://lY51CO1E.wxcsm.cn
http://etIsiLki.wxcsm.cn
http://uogOSREU.wxcsm.cn
http://J4Z4kwma.wxcsm.cn
http://uWi31NDM.wxcsm.cn
http://www.dtcms.com/a/378320.html

相关文章:

  • 计算机毕业设计 基于Hadoop豆瓣电影数据可视化分析设计与实现 Python 大数据毕业设计 Hadoop毕业设计选题【附源码+文档报告+安装调试
  • 25.9.11 QTday1作业
  • unity 陶艺制作模拟
  • Unity 三维数学方法
  • 【氮化镓】GaN基半导体器件电离辐射损伤基可靠性综述
  • 音视频demo
  • 相机Camera日志分析之三十六:相机Camera常见日志注释
  • 250911算法练习:递归
  • 双目相机原理
  • AI教育白皮书解读 | 医学教育数智化转型新机遇,“人工智能+”行动实践正当时
  • vue3自定义无缝轮播组件
  • 【每日算法】合并两个有序链表 LeetCode
  • 瑞萨RA家族新成员RA4C1,符合DLMS SUITE2表计安全规范、超低功耗、支持段码显示,专为智能表计应用开发
  • 【maxscript】矩阵对齐-武器残影
  • Java 黑马程序员学习笔记(进阶篇4)
  • XR 和 AI 在 Siggraph 2025 上主导图形的未来,获取gltf/glb格式
  • TikTok矩阵有哪些运营支撑方案?
  • 《基于深度学习的近红外条纹投影三维测量》-论文总结
  • 优选算法 100 题 —— 2 滑动窗口
  • MongoDB 在线安装-一键安装脚本(CentOS 7.9)
  • DeepSeek辅助编写的利用quick_xml把xml转为csv的rust程序
  • Rider中的Run/Debug配置对应的本地文件
  • 综合项目实践:基于基础语法核心的Python项目
  • 开始 ComfyUI 的 AI 绘图之旅-Flux.1图生图(八)
  • 供应商管理系统包含哪些模块?
  • MongoDB Atlas 云数据库实战:从零搭建全球多节点集群
  • Apache服务——搭建实验
  • “一半是火焰,一半是海水”,金融大模型的爆发与困局
  • 开源 C++ QT Widget 开发(十六)程序发布
  • MPC控制器C语言实现:基于一阶RL系统