day52 ResNet18 CBAM
在深度学习的旅程中,我们不断探索如何提升模型的性能。今天,我将分享我在 ResNet18 模型中插入 CBAM(Convolutional Block Attention Module)模块,并采用分阶段微调策略的实践过程。通过这个过程,我不仅提升了模型的性能,还对深度学习中的预训练和微调有了更深刻的理解。
一、背景知识
ResNet18 是一种经典的卷积神经网络架构,广泛应用于图像分类任务。CBAM 是一种注意力机制模块,能够同时关注特征图的通道和空间维度,提升模型对关键特征的关注能力。将 CBAM 模块插入 ResNet18 中,可以增强模型的特征表达能力。
二、研究方法
1. CBAM 模块的插入位置
- CBAM 模块被插入到 ResNet18 的每个残差块(BasicBlock)之后。这样可以在每个特征提取阶段都引入注意力机制,让模型在提取特征的同时学会关注重要的特征。
- CBAM 模块的初始状态接近“直通”,即在训练初期,CBAM 模块对特征图的影响较小,不会破坏预训练模型的权重。
2. 预训练策略
- 阶段 1(Epoch 1-5):仅解冻分类头(fc)和所有 CBAM 模块,冻结 ResNet18 的主干卷积层。目标是让模型快速学习新任务的分类边界,同时让 CBAM 模块找到初步的关注点。学习率设置为 1e-3。
- 阶段 2(Epoch 6-20):解冻高层卷积层(layer3, layer4),保持低层卷积层(layer1, layer2)冻结。目标是让模型的高层特征提取能力适应新任务的抽象概念。学习率设置为 1e-4。
- 阶段 3(Epoch 21-50):解冻所有层,进行端到端微调。目标是让模型的底层特征也与新任务对齐,提升整体性能。学习率设置为 1e-5。
三、实验过程
1. 数据预处理
- 使用 CIFAR-10 数据集,包含 10 个类别的 60,000 张 32x32 的彩色图像。
- 数据增强包括随机裁剪、水平翻转、颜色抖动等。
2. 模型定义
- 定义了 ResNet18_CBAM 模型,继承自 PyTorch 的 nn.Module。
- 在每个残差块后插入 CBAM 模块,调整通道数和空间维度的注意力权重。
3. 训练过程
- 使用 Adam 优化器,动态调整学习率。
- 每个阶段的训练过程都有详细的日志输出,包括每个 batch 的损失和每个 epoch 的训练准确率和测试准确率。
四、关键结论
1. 训练过程中的损失和准确率变化
- 在阶段 1,模型的训练准确率从 37.31% 提升到 49.86%,测试准确率从 47.48% 提升到 54.98%。
- 在阶段 2,模型的训练准确率从 61.34% 提升到 86.26%,测试准确率从 71.71% 提升到 85.99%。
- 在阶段 3,模型的训练准确率从 88.75% 提升到 95.15%,测试准确率从 87.58% 提升到 90.15%。
2. 最终性能
- 经过 50 个 epoch 的训练,模型的最终测试准确率达到了 90.15%。这表明 CBAM 模块显著提升了模型的性能,尤其是在高层特征提取和全局微调阶段。
五、代码实现
以下是 ResNet18_CBAM 模型的定义和训练过程的代码实现:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# 定义 CBAM 模块
class ChannelAttention(nn.Module):
def __init__(self, in_channels, ratio=16):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // ratio, bias=False),
nn.ReLU(),
nn.Linear(in_channels // ratio, in_channels, bias=False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, _, _ = x.shape
avg_out = self.fc(self.avg_pool(x).view(b, c))
max_out = self.fc(self.max_pool(x).view(b, c))
attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
return x * attention
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super().__init__()
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
pool_out = torch.cat([avg_out, max_out], dim=1)
attention = self.conv(pool_out)
return x * self.sigmoid(attention)
class CBAM(nn.Module):
def __init__(self, in_channels, ratio=16, kernel_size=7):
super().__init__()
self.channel_attn = ChannelAttention(in_channels, ratio)
self.spatial_attn = SpatialAttention(kernel_size)
def forward(self, x):
x = self.channel_attn(x)
x = self.spatial_attn(x)
return x
# 定义 ResNet18_CBAM 模型
class ResNet18_CBAM(nn.Module):
def __init__(self, num_classes=10, pretrained=True, cbam_ratio=16, cbam_kernel=7):
super().__init__()
self.backbone = models.resnet18(pretrained=pretrained)
self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.backbone.maxpool = nn.Identity()
self.cbam_layer1 = CBAM(in_channels=64, ratio=cbam_ratio, kernel_size=cbam_kernel)
self.cbam_layer2 = CBAM(in_channels=128, ratio=cbam_ratio, kernel_size=cbam_kernel)
self.cbam_layer3 = CBAM(in_channels=256, ratio=cbam_ratio, kernel_size=cbam_kernel)
self.cbam_layer4 = CBAM(in_channels=512, ratio=cbam_ratio, kernel_size=cbam_kernel)
self.backbone.fc = nn.Linear(in_features=512, out_features=num_classes)
def forward(self, x):
x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.layer1(x)
x = self.cbam_layer1(x)
x = self.backbone.layer2(x)
x = self.cbam_layer2(x)
x = self.backbone.layer3(x)
x = self.cbam_layer3(x)
x = self.backbone.layer4(x)
x = self.cbam_layer4(x)
x = self.backbone.avgpool(x)
x = torch.flatten(x, 1)
x = self.backbone.fc(x)
return x
# 数据预处理
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# 加载数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 训练函数
def train(model, device, train_loader, optimizer, criterion):
model.train()
running_loss = 0.0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
if (batch_idx + 1) % 100 == 0:
print(f'Batch: {batch_idx+1}/{len(train_loader)} | 单Batch损失: {loss.item():.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')
epoch_loss = running_loss / len(train_loader)
epoch_acc = 100. * correct / total
return epoch_loss, epoch_acc
# 测试函数
def test(model, device, test_loader, criterion):
model.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
epoch_loss = test_loss / len(test_loader)
epoch_acc = 100. * correct / total
return epoch_loss, epoch_acc
# 主函数
def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
model = ResNet18_CBAM(num_classes=10, pretrained=True).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
epochs = 50
for epoch in range(1, epochs + 1):
epoch_start_time = time.time()
train_loss, train_acc = train(model, device, train_loader, optimizer, criterion)
test_loss, test_acc = test(model, device, test_loader, criterion)
scheduler.step()
print(f'Epoch {epoch}/{epochs} 完成 | 耗时: {time.time() - epoch_start_time:.2f}s | 训练准确率: {train_acc:.2f}% | 测试准确率: {test_acc:.2f}%')
torch.save(model.state_dict(), 'resnet18_cbam_finetuned.pth')
print("模型已保存为: resnet18_cbam_finetuned.pth")
if __name__ == "__main__":
main()
@浙大疏锦行