Deep Learning: DenseNet Dense Connectivity -- Alleviating Vanishing Gradients
In a traditional convolutional neural network, layers are connected sequentially: the input of layer $l$ comes only from the output of layer $l-1$. This connectivity pattern has two main problems:
- Vanishing gradients: in deep networks, the gradient shrinks as it propagates backward, so the weights of the early layers update slowly.
- Insufficient feature reuse: the network struggles to retain and reuse the features extracted by earlier layers.

In 2017, a research team from Cornell University, Tsinghua University, and Facebook proposed DenseNet (Dense Convolutional Network), an architecture whose connectivity pattern fundamentally changed how we think about wiring neural networks. DenseNet addresses the problems above by introducing dense connectivity: every layer is directly connected to all subsequent layers, forming a dense connection pattern.
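To see how much denser this wiring is: an $L$-layer dense block contains $L(L+1)/2$ direct connections, versus $L$ in a sequential chain. A tiny sketch of that count (a hypothetical helper for illustration, not part of DenseNet itself):

```python
def num_connections(num_layers):
    # In a dense block, layer i receives inputs from all i earlier
    # feature maps (including the block input), so the total number of
    # direct connections is 1 + 2 + ... + L = L * (L + 1) / 2.
    return num_layers * (num_layers + 1) // 2

for L in (1, 5, 12):
    print(L, "layers ->", num_connections(L), "direct connections")
# e.g. a 5-layer dense block has 15 direct connections, not 5
```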
1. Introduction to DenseNet
1.1 The DenseNet Formula
$$x_l = H_l([x_0, x_1, \ldots, x_{l-1}])$$
where:
- $x_l$ is the output of layer $l$
- $[x_0, x_1, \ldots, x_{l-1}]$ denotes the concatenation of the outputs of all preceding layers along the channel dimension
- $H_l$ is the nonlinear transformation of layer $l$ (typically a composite of BN, ReLU, and Conv)
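A minimal PyTorch sketch of this formula, assuming plain 3x3 convolutions as stand-ins for $H_l$ (the real $H_l$ used later in this article is the BN-ReLU-Conv composite):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

growth_rate = 4
x0 = torch.randn(1, 8, 16, 16)           # x_0: input with 8 channels

# Stand-ins for H_1 and H_2; each emits growth_rate new feature maps.
h1 = nn.Conv2d(8, growth_rate, kernel_size=3, padding=1)
h2 = nn.Conv2d(8 + growth_rate, growth_rate, kernel_size=3, padding=1)

x1 = h1(x0)                               # x_1 = H_1(x_0)
x2 = h2(torch.cat([x0, x1], dim=1))       # x_2 = H_2([x_0, x_1])

# The block output keeps every feature map produced so far.
out = torch.cat([x0, x1, x2], dim=1)
print(out.shape)                          # torch.Size([1, 16, 16, 16])
```

Each layer adds `growth_rate` channels (8 → 12 → 16 here), which is exactly why the channel count grows linearly inside a dense block.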
1.2 A Worked Example
Suppose we have a simple input vector $x$:
$$x = \begin{bmatrix} 1 \\ 2 \end{bmatrix}$$
Consider a dense block with two dense layers, each producing one feature map (growth_rate=1).
- First dense layer
The first layer receives the input $x$ and produces a new feature $f_1$:
$$f_1 = H_1(x) = \sigma(W_1 \cdot x + b_1)$$
Assume the weights and bias are:
$$W_1 = \begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix}, \quad b_1 = \begin{bmatrix} 0.5 \\ 0.5 \end{bmatrix}$$
with $\sigma$ the ReLU activation. The computation:
$$f_1 = \text{ReLU}\left(\begin{bmatrix} 0.5 & 0.5 \\ 0.5 & 0.5 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \end{bmatrix} + \begin{bmatrix} 0.5 \\ 0.5 \end{bmatrix}\right) = \text{ReLU}\left(\begin{bmatrix} 0.5 \cdot 1 + 0.5 \cdot 2 + 0.5 \\ 0.5 \cdot 1 + 0.5 \cdot 2 + 0.5 \end{bmatrix}\right) = \begin{bmatrix} 2.0 \\ 2.0 \end{bmatrix}$$
- Second dense layer
The second layer takes the concatenation of all previous outputs ($x$ and $f_1$) as its input:
$$f_2 = H_2([x, f_1]) = \sigma(W_2 \cdot [x, f_1] + b_2)$$
Assume the weights and bias are:
$$W_2 = \begin{bmatrix} 0.3 & 0.3 & 0.4 & 0.4 \end{bmatrix}, \quad b_2 = \begin{bmatrix} -0.5 \end{bmatrix}$$
($W_2$ has four columns because the concatenated input is a 4-dimensional vector.)
First concatenate the inputs:
$$[x, f_1] = \begin{bmatrix} 1 & 2 & 2.0 & 2.0 \end{bmatrix}^T$$
(in a real network the feature maps keep their spatial shape; here they are flattened to simplify the arithmetic)
The computation:
$$f_2 = \text{ReLU}\left(\begin{bmatrix} 0.3 & 0.3 & 0.4 & 0.4 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 2.0 \\ 2.0 \end{bmatrix} - 0.5\right) = \text{ReLU}(0.3 + 0.6 + 0.8 + 0.8 - 0.5) = \text{ReLU}(2.0) = 2.0$$
- Block output
The output of the dense block is the concatenation of all layer outputs:
$$\text{out} = [x, f_1, f_2] = \begin{bmatrix} 1 & 2 & 2.0 & 2.0 & 2.0 \end{bmatrix}^T$$
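The per-element arithmetic above can be checked in a few lines of NumPy, treating the features as flat vectors as in the simplified example:

```python
import numpy as np

relu = lambda v: np.maximum(v, 0.0)

x = np.array([1.0, 2.0])                       # input vector
W1 = np.array([[0.5, 0.5], [0.5, 0.5]])
b1 = np.array([0.5, 0.5])
f1 = relu(W1 @ x + b1)                         # first dense layer -> [2.0, 2.0]

W2 = np.array([0.3, 0.3, 0.4, 0.4])
b2 = -0.5
f2 = relu(W2 @ np.concatenate([x, f1]) + b2)   # second layer sees [x, f1] -> 2.0

out = np.concatenate([x, f1, [f2]])            # block output keeps everything
print(f1, f2, out)
```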
1.3 Architecture Components of DenseNet
DenseLayer class
- Function: implements one dense layer, consisting of two convolutions with batch normalization.
- Inputs: in_channels is the number of input channels; growth_rate is the number of channels each layer outputs.
- Output: the input and the layer's output concatenated along the channel dimension, forming a new feature map.
DenseBlock class
- Function: implements one dense block, composed of multiple DenseLayer modules.
- Inputs: in_channels is the number of input channels, num_layers is the number of layers in the block, and growth_rate is the growth rate of each layer.
- Output: the feature map after passing through all layers.
TransitionLayer class
- Function: implements a transition layer, used to reduce the size of the feature maps.
- Inputs: in_channels is the number of input channels; out_channels is the number of output channels.
- Output: the feature map after the convolution and pooling operations.
DenseNet class
- Function: implements the full DenseNet model.
- Parameters: growth_rate is the growth rate of each dense layer; block_config gives the number of layers in each dense block; num_classes is the number of output classes.
- Output: the result after feature extraction and the classification layer.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseLayer(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(DenseLayer, self).__init__()
        # BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(4 * growth_rate)
        self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        return torch.cat([x, out], 1)  # concatenate along the channel dimension

class DenseBlock(nn.Module):
    def __init__(self, in_channels, num_layers, growth_rate):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(DenseLayer(in_channels + i * growth_rate, growth_rate))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class TransitionLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TransitionLayer, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(F.relu(self.bn(x)))
        x = self.pool(x)
        return x

class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000):
        super(DenseNet, self).__init__()
        # initial convolution
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # build dense blocks and transition layers
        in_channels = 64
        self.blocks = nn.ModuleList()
        for i, num_layers in enumerate(block_config):
            # add a dense block
            block = DenseBlock(in_channels, num_layers, growth_rate)
            self.blocks.append(block)
            in_channels += num_layers * growth_rate
            # add a transition layer after every block except the last
            if i != len(block_config) - 1:
                trans = TransitionLayer(in_channels, in_channels // 2)
                self.blocks.append(trans)
                in_channels = in_channels // 2
        # final classification layer
        self.classifier = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.features(x)
        for block in self.blocks:
            x = block(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

if __name__ == "__main__":
    # input tensor of shape (batch_size, channels, height, width)
    input_tensor = torch.randn(1, 3, 64, 64)  # 1 sample, 3 channels, 64x64 image
    # create a DenseNet model instance
    model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_classes=10)
    # forward pass
    output_tensor = model(input_tensor)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")
```
Input shape: torch.Size([1, 3, 64, 64])
Output shape: torch.Size([1, 10])
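As a side check (a small standalone sketch, not part of the model code), the channel bookkeeping in `DenseNet.__init__` can be traced by hand for the default `block_config`:

```python
# Trace how the channel count evolves through a DenseNet-121-style config.
growth_rate = 32
block_config = (6, 12, 24, 16)

channels = 64                                  # after the initial 7x7 conv
for i, num_layers in enumerate(block_config):
    channels += num_layers * growth_rate       # each layer adds growth_rate maps
    if i != len(block_config) - 1:
        channels //= 2                         # transition layer halves channels
    print(f"after stage {i + 1}: {channels} channels")
# 128, 256, 512, 1024 -- the final 1024 is the classifier's input size
```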
2. Code Example
Use the DenseNet network to perform handwritten digit recognition on the MNIST dataset.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt

# define the DenseNet model
class DenseLayer(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(DenseLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(4 * growth_rate)
        self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x)))
        out = self.conv2(F.relu(self.bn2(out)))
        return torch.cat([x, out], 1)  # concatenate along the channel dimension

class DenseBlock(nn.Module):
    def __init__(self, in_channels, num_layers, growth_rate):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(DenseLayer(in_channels + i * growth_rate, growth_rate))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class TransitionLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TransitionLayer, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(F.relu(self.bn(x)))
        x = self.pool(x)
        return x

class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=10):
        super(DenseNet, self).__init__()
        # initial convolution (1 input channel for grayscale MNIST images)
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # build dense blocks and transition layers
        in_channels = 64
        self.blocks = nn.ModuleList()
        for i, num_layers in enumerate(block_config):
            # add a dense block
            block = DenseBlock(in_channels, num_layers, growth_rate)
            self.blocks.append(block)
            in_channels += num_layers * growth_rate
            # add a transition layer after every block except the last
            if i != len(block_config) - 1:
                trans = TransitionLayer(in_channels, in_channels // 2)
                self.blocks.append(trans)
                in_channels = in_channels // 2
        # final classification layer
        self.classifier = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.features(x)
        for block in self.blocks:
            x = block(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# data preprocessing and loading
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# training setup
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')  # use MPS if available
model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        # forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')
    print(f'Epoch [{epoch + 1}/{num_epochs}], Average Loss: {running_loss / len(train_loader):.4f}')

# evaluate the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

# visualize a few test predictions
def visualize_results(test_loader, model, num_images=5):
    model.eval()
    images, labels = next(iter(test_loader))
    images, labels = images[:num_images].to(device), labels[:num_images].to(device)
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    plt.figure(figsize=(10, 2))
    for i in range(num_images):
        plt.subplot(1, num_images, i + 1)
        plt.imshow(images[i].cpu().squeeze(), cmap='gray')
        plt.title(f'Pred: {predicted[i].item()}\nTrue: {labels[i].item()}')
        plt.axis('off')
    plt.show()

visualize_results(test_loader, model)
```
Accuracy of the model on the test images: 99.03%