当前位置: 首页 > news >正文

深度学习篇---DenseNet网络结构

在 PyTorch 中实现 DenseNet(以经典的 DenseNet-121 为例),核心是实现它的 "密集连接" 机制 —— 每一层都与前面所有层通过通道拼接(Concatenate)直接连接。我们从基础模块开始,一步步搭建,确保你能理解每个部分的作用。

一、先明确 DenseNet 的核心结构

DenseNet 的结构可以概括为:

输入(224×224彩色图) → 
初始卷积层 → 初始池化层 → 
4个Dense块(每个Dense块包含多个密集连接的卷积层) → 
每个Dense块后接过渡层(下采样+通道压缩) → 
全局平均池化 → 全连接层(输出1000类)

其中,Dense 块(含密集连接)和过渡层是核心组件。

二、PyTorch 实现 DenseNet 的步骤

步骤 1:导入必要的库

和之前实现其他 CNN 一样,先准备好工具:

import torch  # 核心库
import torch.nn as nn  # 神经网络层
import torch.optim as optim  # 优化器
from torch.utils.data import DataLoader  # 数据加载器
from torchvision import datasets, transforms  # 图像数据处理

步骤 2:实现 DenseNet 的基础组件 —— 瓶颈层(Bottleneck)

DenseNet 的每一层都用 "瓶颈层" 设计,通过 1×1 卷积降维,避免通道数过多导致计算量爆炸:

class Bottleneck(nn.Module):def __init__(self, in_channels, growth_rate, dropout_rate=0.0):super(Bottleneck, self).__init__()# 瓶颈层结构:BN → ReLU → 1×1 Conv → BN → ReLU → 3×3 Convself.bn1 = nn.BatchNorm2d(in_channels)self.relu = nn.ReLU(inplace=True)# 1×1卷积:降维到4×growth_rate(论文推荐)self.conv1 = nn.Conv2d(in_channels, 4 * growth_rate,kernel_size=1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(4 * growth_rate)# 3×3卷积:输出growth_rate个通道(控制每一层的新增通道数)self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate,kernel_size=3, stride=1, padding=1, bias=False)self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else Nonedef forward(self, x):# x是前面所有层输出的拼接(密集连接的输入)out = self.bn1(x)out = self.relu(out)out = self.conv1(out)  # 1×1降维if self.dropout is not None:out = self.dropout(out)out = self.bn2(out)out = self.relu(out)out = self.conv2(out)  # 3×3提取特征if self.dropout is not None:out = self.dropout(out)# 关键:将当前层输出与输入拼接(密集连接的核心)# 输入x是前面所有层的拼接,这里再拼上当前层输出return torch.cat([x, out], dim=1)

通俗解释

  • growth_rate(增长率 k)是 DenseNet 的核心参数,控制每一层新增的通道数(比如 k=32,每一层输出 32 个新通道);
  • 1×1 卷积先将输入通道数降到4×k(降维,减少计算量),3×3 卷积再输出k个通道;
  • 最终通过torch.cat将当前层输出与输入(前面所有层的特征)拼接,实现密集连接。

步骤 3:实现 Dense 块(Dense Block)

一个 Dense 块由多个 Bottleneck 层组成,所有层通过密集连接串联:

def _make_dense_block(in_channels, num_layers, growth_rate, dropout_rate):"""创建一个Dense块in_channels: 输入通道数num_layers: 块内的Bottleneck层数growth_rate: 增长率k"""layers = []for _ in range(num_layers):# 每个Bottleneck的输入通道数会随层数增加(因为密集连接不断拼接)layers.append(Bottleneck(in_channels, growth_rate, dropout_rate))# 更新输入通道数(加上新增的growth_rate个通道)in_channels += growth_ratereturn nn.Sequential(*layers), in_channels

举例
如果输入通道 = 64,num_layers=6,growth_rate=32:

  • 第 1 层输入 = 64 → 输出拼接后 = 64+32=96
  • 第 2 层输入 = 96 → 输出拼接后 = 96+32=128
  • ...
  • 第 6 层输入 = 64+5×32=224 → 输出拼接后 = 224+32=256
    最终 Dense 块输出通道 = 256

步骤 4:实现过渡层(Transition Layer)

过渡层用于连接两个 Dense 块,作用是 "下采样(尺寸减半)+ 通道压缩":

class Transition(nn.Module):def __init__(self, in_channels, out_channels, dropout_rate=0.0):super(Transition, self).__init__()# 过渡层结构:BN → ReLU → 1×1 Conv(通道压缩) → 2×2 AvgPool(下采样)self.bn = nn.BatchNorm2d(in_channels)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(in_channels, out_channels,kernel_size=1, stride=1, padding=0, bias=False)self.pool = nn.AvgPool2d(kernel_size=2, stride=2)self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else Nonedef forward(self, x):out = self.bn(x)out = self.relu(out)out = self.conv(out)  # 通道压缩if self.dropout is not None:out = self.dropout(out)out = self.pool(out)  # 下采样(尺寸减半)return out

通道压缩逻辑
过渡层的输出通道数 = 输入通道数 × 压缩因子 θ(论文中 θ=0.5)。例如输入 256 通道,过渡层输出 128 通道(256×0.5)。

步骤 5:搭建 DenseNet 完整网络(以 DenseNet-121 为例)

DenseNet-121 的结构是:4 个 Dense 块,分别包含 6、12、24、16 层 Bottleneck,增长率 k=32:

class DenseNet(nn.Module):def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, dropout_rate=0.0, compression=0.5):super(DenseNet, self).__init__()# 初始卷积层:输出通道数=2×growth_rate(论文推荐)in_channels = 2 * growth_rateself.features = nn.Sequential(nn.Conv2d(3, in_channels, kernel_size=7, stride=2, padding=3, bias=False),nn.BatchNorm2d(in_channels),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 初始池化)# 构建4个Dense块和过渡层for i, num_layers in enumerate(block_config):# 1. 添加Dense块dense_block, in_channels = _make_dense_block(in_channels, num_layers, growth_rate, dropout_rate)self.features.add_module(f'denseblock{i+1}', dense_block)# 2. 添加过渡层(最后一个Dense块后没有过渡层)if i != len(block_config) - 1:# 过渡层输出通道数=输入通道数×压缩因子out_channels = int(in_channels * compression)trans = Transition(in_channels, out_channels, dropout_rate)self.features.add_module(f'transition{i+1}', trans)in_channels = out_channels  # 更新输入通道数# 最终的BN和ReLUself.features.add_module('bn_final', nn.BatchNorm2d(in_channels))self.features.add_module('relu_final', nn.ReLU(inplace=True))# 全局平均池化和全连接层self.global_pool = nn.AdaptiveAvgPool2d((1, 1))self.classifier = nn.Linear(in_channels, num_classes)def forward(self, x):features = self.features(x)  # 经过所有Dense块和过渡层out = self.global_pool(features)  # 全局池化out = out.view(out.size(0), -1)  # 拉平成向量out = self.classifier(out)  # 全连接层输出return out

结构解释

  • block_config=(6,12,24,16):4 个 Dense 块的层数,总和 6+12+24+16=58,加上初始卷积、过渡层等,总层数约 121(DenseNet-121 名称由来);
  • compression=0.5:过渡层的压缩因子 θ,控制通道数减少比例;
  • 每个 Dense 块后接过渡层(最后一个除外),实现特征图尺寸从 56×56→28×28→14×14→7×7 的逐步缩减。

步骤 6:准备数据(用 CIFAR-10 演示)

DenseNet 适合高精度分类任务,我们用 CIFAR-10(10 类)演示,输入尺寸调整为 224×224:

# 数据预处理:缩放+裁剪+翻转+标准化
transform = transforms.Compose([transforms.Resize(256),  # 缩放为256×256transforms.RandomCrop(224),  # 随机裁剪成224×224transforms.RandomHorizontalFlip(),  # 随机翻转(数据增强)transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet标准化
])# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform
)# 批量加载数据(DenseNet内存占用较高,batch_size适当减小)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

步骤 7:初始化模型、损失函数和优化器

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DenseNet-121配置:增长率32,4个Dense块,输出10类(CIFAR-10)
model = DenseNet(growth_rate=32,block_config=(6, 12, 24, 16),num_classes=10,dropout_rate=0.2  # 可选:添加dropout防止过拟合
).to(device)criterion = nn.CrossEntropyLoss()  # 交叉熵损失
# 优化器:推荐用Adam,学习率0.001
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

步骤 8:训练和测试函数

DenseNet 训练时内存占用较高,训练逻辑和之前类似但需注意显存使用:

def train(model, train_loader, criterion, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()  # 清空梯度output = model(data)   # 模型预测loss = criterion(output, target)  # 计算损失loss.backward()        # 反向传播optimizer.step()       # 更新参数# 打印进度if batch_idx % 50 == 0:print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')def test(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()print(f'Test Accuracy: {100 * correct / total:.2f}%')

步骤 9:开始训练和测试

DenseNet 收敛较慢,建议训练 30-50 轮:

for epoch in range(1, 31):train(model, train_loader, criterion, optimizer, epoch)test(model, test_loader)

在 CIFAR-10 上,DenseNet-121 训练充分后准确率能达到 94% 以上,远高于 MobileNet 和 ShuffleNet,体现了其强大的特征融合能力。

三、完整代码总结

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 1. 实现瓶颈层(Bottleneck)
class Bottleneck(nn.Module):def __init__(self, in_channels, growth_rate, dropout_rate=0.0):super(Bottleneck, self).__init__()self.bn1 = nn.BatchNorm2d(in_channels)self.relu = nn.ReLU(inplace=True)# 1×1卷积降维到4×growth_rateself.conv1 = nn.Conv2d(in_channels, 4 * growth_rate,kernel_size=1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(4 * growth_rate)# 3×3卷积输出growth_rate个通道self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate,kernel_size=3, stride=1, padding=1, bias=False)self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else Nonedef forward(self, x):out = self.bn1(x)out = self.relu(out)out = self.conv1(out)if self.dropout is not None:out = self.dropout(out)out = self.bn2(out)out = self.relu(out)out = self.conv2(out)if self.dropout is not None:out = self.dropout(out)# 密集连接:拼接输入和当前层输出return torch.cat([x, out], dim=1)# 2. 实现Dense块
def _make_dense_block(in_channels, num_layers, growth_rate, dropout_rate):layers = []for _ in range(num_layers):layers.append(Bottleneck(in_channels, growth_rate, dropout_rate))in_channels += growth_rate  # 更新输入通道数(累加增长率)return nn.Sequential(*layers), in_channels# 3. 实现过渡层
class Transition(nn.Module):def __init__(self, in_channels, out_channels, dropout_rate=0.0):super(Transition, self).__init__()self.bn = nn.BatchNorm2d(in_channels)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(in_channels, out_channels,kernel_size=1, stride=1, padding=0, bias=False)self.pool = nn.AvgPool2d(kernel_size=2, stride=2)  # 下采样self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else Nonedef forward(self, x):out = self.bn(x)out = self.relu(out)out = self.conv(out)  # 通道压缩if self.dropout is not None:out = self.dropout(out)out = self.pool(out)  # 尺寸减半return out# 4. 搭建DenseNet完整网络(DenseNet-121)
class DenseNet(nn.Module):def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, dropout_rate=0.0, compression=0.5):super(DenseNet, self).__init__()# 初始卷积层in_channels = 2 * growth_rateself.features = nn.Sequential(nn.Conv2d(3, in_channels, kernel_size=7, stride=2, padding=3, bias=False),nn.BatchNorm2d(in_channels),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))# 构建4个Dense块和过渡层for i, num_layers in enumerate(block_config):# 添加Dense块dense_block, in_channels = _make_dense_block(in_channels, num_layers, growth_rate, dropout_rate)self.features.add_module(f'denseblock{i+1}', dense_block)# 添加过渡层(最后一个块除外)if i != len(block_config) - 1:out_channels = int(in_channels * compression)trans = Transition(in_channels, out_channels, dropout_rate)self.features.add_module(f'transition{i+1}', trans)in_channels = out_channels# 最终的BN和ReLUself.features.add_module('bn_final', nn.BatchNorm2d(in_channels))self.features.add_module('relu_final', nn.ReLU(inplace=True))# 分类部分self.global_pool = nn.AdaptiveAvgPool2d((1, 1))self.classifier = nn.Linear(in_channels, num_classes)def forward(self, x):features = self.features(x)out = self.global_pool(features)out = out.view(out.size(0), -1)  # 拉平特征out = self.classifier(out)return out# 5. 准备CIFAR-10数据
transform = transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform
)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)# 6. 初始化模型、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DenseNet(growth_rate=32,block_config=(6, 12, 24, 16),num_classes=10,dropout_rate=0.2
).to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)# 7. 训练函数
def train(model, train_loader, criterion, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 50 == 0:print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.6f}')# 8. 测试函数
def test(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()print(f'Test Accuracy: {100 * correct / total:.2f}%')# 9. 开始训练和测试
for epoch in range(1, 31):train(model, train_loader, criterion, optimizer, epoch)test(model, test_loader)

四、关键知识点回顾

  1. 核心机制:密集连接通过torch.cat将每一层输出与前面所有层的特征拼接,实现特征的高效复用,这是 DenseNet 精度高的关键;
  2. 瓶颈层作用:1×1 卷积先降维(到 4×k)再用 3×3 卷积,避免密集连接导致的通道数爆炸,大幅减少计算量;
  3. 过渡层作用:通过 1×1 卷积压缩通道(×0.5)和 2×2 池化下采样,控制模型整体复杂度;
  4. 参数配置
    • growth_rate(k):每一层新增通道数,k=32 是 DenseNet-121 的标准配置;
    • block_config:4 个 Dense 块的层数,(6,12,24,16) 对应 121 层;
  5. 优缺点:精度高、参数量少,但训练时内存占用大(需存储大量中间特征),推理速度稍慢。

通过这段代码,你能亲手实现这个 "特征融合大师",感受密集连接带来的强大特征表达能力!


文章转载自:

http://dW65csaM.fbdtd.cn
http://RcbKYgXJ.fbdtd.cn
http://3Q7TPvSK.fbdtd.cn
http://BTYkJK06.fbdtd.cn
http://tv3LNc6g.fbdtd.cn
http://HC8xfiGs.fbdtd.cn
http://SoNsvU2j.fbdtd.cn
http://IMricRdL.fbdtd.cn
http://pqtON6Qw.fbdtd.cn
http://7wUTwCjA.fbdtd.cn
http://2vMwGyUM.fbdtd.cn
http://4MPVIp3n.fbdtd.cn
http://7hvVAddI.fbdtd.cn
http://7lqWaFuM.fbdtd.cn
http://s3FSFsbb.fbdtd.cn
http://isOQ70E2.fbdtd.cn
http://yFIub1PP.fbdtd.cn
http://jGzD4CT2.fbdtd.cn
http://ujLM4PIq.fbdtd.cn
http://oo2rPTZv.fbdtd.cn
http://lBimiq4T.fbdtd.cn
http://hd1jH6Xb.fbdtd.cn
http://65GKLDe7.fbdtd.cn
http://jzr21bCp.fbdtd.cn
http://y28Aff3d.fbdtd.cn
http://kQFfLd1f.fbdtd.cn
http://koRzfypD.fbdtd.cn
http://Ywi9Kc2U.fbdtd.cn
http://LW8j93Lv.fbdtd.cn
http://ALpywsTI.fbdtd.cn
http://www.dtcms.com/a/363318.html

相关文章:

  • Spring Boot手写10万敏感词检查程序
  • C#----异步编程
  • 基于Django的论坛系统设计与实现(代码+数据库+LW)
  • Qt模型/视图编程详解:QStringListModel与多视图数据同步
  • 链表题类型注解解惑:理解Optional,理解ListNode
  • 前端实现解析【导入】数据后调用批量处理接口
  • GaussDB 等待事件为LockMgrLock处理方法
  • 为什么程序员总是发现不了自己的Bug?
  • flutter踩坑插件:Swift架构不兼容
  • 疯狂星期四文案网第58天运营日记
  • 手撕Redis底层2-网络模型深度剖析
  • 【3D 入门-4】trimesh 极速上手之 3D Mesh 数据结构解析(Vertices / Faces)
  • Valkey vs Redis详解
  • 基于若依框架开发WebSocket接口
  • 计算机Python毕业设计推荐:基于Django+Vue用户评论挖掘旅游系统
  • 【交易系统系列36】揭秘币安(Binance)技术心脏:从公开信息拼凑“MatchBox”撮合引擎架构
  • 海康摄像头开发---标准配置结构体(NET_DVR_STD_CONFIG)
  • End-To-End 之于推荐-kuaishou OneRec2 笔记
  • css中 ,有哪些⽅式可以隐藏页⾯元素? 区别?
  • 03_网关ip和端口映射(路由器转发)操作和原理
  • Telnet 原理与配置
  • 基于STM32单片机智能家居wifi远程监控系统机智云app设计
  • Replit在线编程工具:支持多语言环境免配置与实时协作,助力编程学习调试与社区项目复用
  • Spring Security的@PreAuthorize注解为什么会知道用户角色?
  • 0902 C++类的匿名对象
  • Nano Banana 复刻分镜,多图结合片刻生成想要的视频
  • 适配第一性原理与分子动力学研究的高性能工作站解析
  • 信息安全各类加密算法解析
  • LDR6600:2C1A适配器协议方案芯片
  • 综合诊断板CAN时间戳稳定性测试报告8.28