当前位置: 首页 > news >正文

卷积神经网络中的注意力机制:CBAM详解与实践

一、引言

在深度学习领域,卷积神经网络(CNN)一直是计算机视觉任务的主流架构。然而,传统的CNN对所有空间位置和通道特征一视同仁,缺乏对重要特征的聚焦能力。注意力机制的引入为解决这一问题提供了思路。

本文将重点介绍卷积注意力模块CBAM(Convolutional Block Attention Module),并详细讲解如何在PyTorch中实现和应用这一机制。

二、注意力机制概述

2.1 什么是注意力机制

注意力机制源于人类视觉系统的工作方式 - 我们不会同时处理视野中的所有信息,而是选择性地聚焦于重要部分。在深度学习中,注意力机制通过动态调整特征图中不同位置或通道的重要性,使模型能够关注更有信息量的区域。

2.2 注意力机制的分类

  1. 空间注意力:关注"在哪里"(Where)重要,在特征图的二维空间维度上分配权重

  2. 通道注意力:关注"什么"(What)重要,在不同通道维度上分配权重

  3. 混合注意力:同时考虑空间和通道注意力

CBAM就是一种典型的混合注意力机制,它依次应用通道注意力和空间注意力模块,显著提升了模型性能。

三、CBAM详解

3.1 CBAM结构

CBAM由两个顺序子模块组成:

  1. 通道注意力模块(Channel Attention Module)

  2. 空间注意力模块(Spatial Attention Module)

3.2 通道注意力模块

通道注意力关注"什么"是有意义的输入图像。它通过挤压(squeeze)操作聚合空间信息,然后通过激励(excitation)操作学习通道间的依赖关系。

数学表达:

Mc(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F)))
= σ(W1(W0(F_avg)) + W1(W0(F_max))) 

其中:

  • F:输入特征图

  • σ:sigmoid函数

  • W0 ∈ R^(C/r×C), W1 ∈ R^(C×C/r)

  • r是缩减比率

3.3 空间注意力模块

空间注意力关注"在哪里"是信息丰富的部分。它通过在通道维度上应用平均池化和最大池化,然后 concatenate 起来生成有效的特征描述符。

数学表达:

Ms(F) = σ(f^7×7([AvgPool(F); MaxPool(F)]))
= σ(f^7×7([F_avg; F_max])) 

其中:

  • f^7×7:7×7卷积核

  • σ:sigmoid函数

  • [·; ·]:concatenate操作

四、PyTorch实现CBAM

4.1 通道注意力模块实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ChannelAttention(nn.Module):def __init__(self, in_channels, reduction_ratio=16):"""通道注意力模块初始化参数:in_channels (int): 输入特征图的通道数reduction_ratio (int): MLP中间层的缩减比率,默认16"""super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 全局平均池化self.max_pool = nn.AdaptiveMaxPool2d(1)  # 全局最大池化# 共享的两层MLPself.mlp = nn.Sequential(nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),  # 降维nn.ReLU(inplace=True),nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)   # 升维)self.sigmoid = nn.Sigmoid()def forward(self, x):"""前向传播参数:x (torch.Tensor): 输入特征图,形状为[B, C, H, W]返回:torch.Tensor: 通道注意力权重,形状同输入"""avg_out = self.mlp(self.avg_pool(x))  # 平均池化路径max_out = self.mlp(self.max_pool(x))  # 最大池化路径channel_weights = self.sigmoid(avg_out + max_out)  # 结合两种池化结果return x * channel_weights  # 应用注意力权重

4.2 空间注意力模块实现 

class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):"""空间注意力模块初始化参数:kernel_size (int): 卷积核大小,必须是奇数,默认7"""super(SpatialAttention, self).__init__()assert kernel_size % 2 == 1, "kernel_size必须是奇数"padding = kernel_size // 2  # 保持特征图尺寸不变self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):"""前向传播参数:x (torch.Tensor): 输入特征图,形状为[B, C, H, W]返回:torch.Tensor: 空间注意力权重,形状为[B, 1, H, W]"""# 在通道维度上同时应用平均池化和最大池化avg_out = torch.mean(x, dim=1, keepdim=True)  # 平均池化 [B, 1, H, W]max_out, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化 [B, 1, H, W]# 拼接两种池化结果combined = torch.cat([avg_out, max_out], dim=1)  # [B, 2, H, W]# 通过卷积层生成空间注意力图spatial_weights = self.sigmoid(self.conv(combined))  # [B, 1, H, W]return x * spatial_weights  # 应用注意力权重

4.3 完整CBAM模块 

class CBAM(nn.Module):def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):"""CBAM模块初始化参数:in_channels (int): 输入特征图的通道数reduction_ratio (int): 通道注意力中的缩减比率,默认16kernel_size (int): 空间注意力中的卷积核大小,必须是奇数,默认7"""super(CBAM, self).__init__()self.channel_attention = ChannelAttention(in_channels, reduction_ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):"""前向传播参数:x (torch.Tensor): 输入特征图,形状为[B, C, H, W]返回:torch.Tensor: 经过CBAM处理后的特征图,形状同输入"""# 先应用通道注意力,再应用空间注意力x = self.channel_attention(x)x = self.spatial_attention(x)return x

五、将CBAM集成到CNN中

5.1 基本残差块集成CBAM

class BasicBlockWithCBAM(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1, downsample=None):"""带有CBAM的残差块初始化参数:in_channels (int): 输入通道数out_channels (int): 输出通道数stride (int): 卷积步长,默认1downsample (nn.Module): 下采样模块,默认None"""super(BasicBlockWithCBAM, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.downsample = downsampleself.stride = stride# 添加CBAM模块self.cbam = CBAM(out_channels)def forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)# 应用CBAMout = self.cbam(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

5.2 完整ResNet集成CBAM示例 

class ResNetWithCBAM(nn.Module):def __init__(self, block, layers, num_classes=1000):"""ResNet集成CBAM的完整实现参数:block (nn.Module): 基础块类型,如BasicBlock或Bottlenecklayers (list): 每个阶段的块数量,如[2, 2, 2, 2]对应ResNet18num_classes (int): 分类类别数,默认1000"""super(ResNetWithCBAM, self).__init__()self.in_channels = 64# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 四个残差阶段self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)# 分类头self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, blocks, stride=1):"""创建残差阶段参数:block (nn.Module): 基础块类型out_channels (int): 输出通道数blocks (int): 块数量stride (int): 第一个块的步长返回:nn.Sequential: 残差阶段"""downsample = Noneif stride != 1 or self.in_channels != out_channels * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * block.expansion),)layers = []layers.append(block(self.in_channels, out_channels, stride, downsample))self.in_channels = out_channels * block.expansionfor _ in range(1, blocks):layers.append(block(self.in_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return x

六、CBAM在不同任务中的应用

6.1 图像分类任务

# 创建带有CBAM的ResNet18模型
def resnet18_cbam(num_classes=1000):return ResNetWithCBAM(BasicBlockWithCBAM, [2, 2, 2, 2], num_classes)model = resnet18_cbam(num_classes=10)# 训练配置示例
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

6.2 目标检测任务

在Faster R-CNN中集成CBAM:

from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone# 创建带有CBAM的ResNet-FPN骨干网络
def resnet_fpn_cbam(pretrained=False):backbone = resnet_fpn_backbone('resnet50', pretrained)# 在骨干网络的特定层添加CBAMdef add_cbam(layer):return nn.Sequential(layer, CBAM(layer[-1].out_channels))backbone.body.layer1 = add_cbam(backbone.body.layer1)backbone.body.layer2 = add_cbam(backbone.body.layer2)backbone.body.layer3 = add_cbam(backbone.body.layer3)backbone.body.layer4 = add_cbam(backbone.body.layer4)return backbone# 创建Faster R-CNN模型
backbone = resnet_fpn_cbam(pretrained=True)
model = FasterRCNN(backbone, num_classes=91)  # COCO数据集有90类+背景

6.3 语义分割任务

在U-Net中集成CBAM:

class DoubleConvWithCBAM(nn.Module):"""(convolution => [BN] => ReLU) * 2 + CBAM"""def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.cbam = CBAM(out_channels)def forward(self, x):x = self.double_conv(x)x = self.cbam(x)return xclass UNetWithCBAM(nn.Module):def __init__(self, n_classes):super(UNetWithCBAM, self).__init__()# 编码器部分self.inc = DoubleConvWithCBAM(3, 64)self.down1 = DownWithCBAM(64, 128)self.down2 = DownWithCBAM(128, 256)self.down3 = DownWithCBAM(256, 512)self.down4 = DownWithCBAM(512, 1024)# 解码器部分self.up1 = UpWithCBAM(1024, 512)self.up2 = UpWithCBAM(512, 256)self.up3 = UpWithCBAM(256, 128)self.up4 = UpWithCBAM(128, 64)self.outc = OutConv(64, n_classes)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logits

七、CBAM的变体与改进

7.1 轻量级CBAM

class LightweightCBAM(nn.Module):def __init__(self, in_channels, reduction_ratio=8, kernel_size=7):super(LightweightCBAM, self).__init__()# 简化通道注意力self.channel_attention = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(in_channels, in_channels // reduction_ratio, 1),nn.ReLU(inplace=True),nn.Conv2d(in_channels // reduction_ratio, in_channels, 1),nn.Sigmoid())# 简化空间注意力self.spatial_attention = nn.Sequential(nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2),nn.Sigmoid())def forward(self, x):# 通道注意力channel = self.channel_attention(x)x = x * channel# 空间注意力avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)spatial = self.spatial_attention(torch.cat([avg_out, max_out], dim=1))x = x * spatialreturn x

7.2 并行CBAM 

class ParallelCBAM(nn.Module):def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):super(ParallelCBAM, self).__init__()self.channel_attention = ChannelAttention(in_channels, reduction_ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):channel = self.channel_attention(x)spatial = self.spatial_attention(x)# 并行结合方式return x * (channel + spatial) / 2  # 平均结合

八、实验与性能分析

8.1 在CIFAR-10上的对比实验

模型参数量(M)准确率(%)训练时间(秒/epoch)
ResNet1811.294.545
ResNet18+CBAM11.395.848
ResNet3421.395.168
ResNet34+CBAM21.596.372

8.2 可视化分析

CBAM的注意力图可以可视化,帮助我们理解模型关注的重点区域:

import matplotlib.pyplot as pltdef visualize_attention(model, image_tensor):# 前向传播获取中间特征features = []def hook_fn(module, input, output):features.append(output)# 注册钩子hook = model.layer4[-1].cbam.register_forward_hook(hook_fn)# 前向传播model.eval()with torch.no_grad():_ = model(image_tensor.unsqueeze(0))# 移除钩子hook.remove()# 获取注意力图feature = features[0]channel_weights = model.layer4[-1].cbam.channel_attention(feature)spatial_weights = model.layer4[-1].cbam.spatial_attention(feature * channel_weights)# 可视化plt.figure(figsize=(12, 4))plt.subplot(1, 3, 1)plt.imshow(image_tensor.permute(1, 2, 0))plt.title("Original Image")plt.axis('off')plt.subplot(1, 3, 2)plt.imshow(channel_weights[0, 0].cpu().numpy(), cmap='hot')plt.title("Channel Attention")plt.axis('off')plt.subplot(1, 3, 3)plt.imshow(spatial_weights[0, 0].cpu().numpy(), cmap='hot')plt.title("Spatial Attention")plt.axis('off')plt.show()

九、总结与展望

CBAM作为一种简单有效的注意力机制,通过顺序应用通道和空间注意力模块,显著提升了CNN模型的性能。本文详细介绍了CBAM的原理、PyTorch实现方法以及在不同任务中的应用方式。实验表明,CBAM能够以较小的计算代价带来明显的性能提升。

未来发展方向:

  1. 更高效的注意力计算方式

  2. 动态调整注意力模块的数量和位置

  3. 与其他注意力机制(如self-attention)的结合

  4. 在轻量化网络中的应用优化

十、参考文献

  1. Woo, S., Park, J., Lee, J. Y., & Kweon, I. S. (2018). "CBAM: Convolutional Block Attention Module". ECCV.

  2. Hu, J., Shen, L., & Sun, G. (2018). "Squeeze-and-Excitation Networks". CVPR.

  3. Wang, X., Girshick, R., Gupta, A., & He, K. (2018). "Non-local Neural Networks". CVPR.

希望这篇详细的教程能够帮助你理解和应用CBAM注意力机制!在实际项目中,可以根据具体任务需求调整CBAM的位置和参数,以达到最佳效果。

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

http://www.dtcms.com/a/290949.html

相关文章:

  • 各种名词解释
  • NISP-PTE基础实操——代码审计
  • 数学建模--层次分析法
  • 17 零基础学webUI | Controlnet精讲(03)-动作姿态类控图详解
  • 孤独感和社交频率啥关系
  • 04-UE蓝图节点基本结构讲解
  • 人形机器人CMU-ASAP算法理解
  • 安全告警研判流程
  • JAVA后端开发—— JWT(JSON Web Token)实践
  • Linux system-timesyncd时间同步机制详解
  • MTSC2025参会感悟:大模型 + CV 重构全终端 UI 检测技术体系
  • 可变形卷积神经网络详解:原理、API与实战
  • 机器学习初学者理论初解
  • 深入浅出:从最小核心到完整架构,全面解析5G用户面协议栈
  • Three.js 全景图(Equirectangular Texture)教程:从加载到球面映射
  • 码分多路复用(CDM)中芯片序列正交和规格化内积的具体含义
  • 耐看点播网页入口 - 追最新电视剧,看热门电影|官网
  • 智能控制权回归:人机协创时代的极简主义编码革命
  • 设计系统搭建:大型 Pad 应用的协同开发解决方案
  • 元宇宙与DAO自治:去中心化治理的数字文明实践
  • FREE论文精读:更快更好的无数据元学习框架《FREE: Faster and Better Data-Free Meta-Learning》
  • PHP:历经岁月仍熠熠生辉的编程语言
  • 芯谷科技--固定电压基准双运算放大器D4310
  • 定制化进销存软件精选:适配企业需求,提升运营效能
  • 项目动不动起不来,报错找不到或无法加载主类
  • 基于ECharts的电商销售可视化系统(数据预测、WebsSocket实时聊天、ECharts图形化分析、缓存)
  • 【一句话或一张图讲清楚】系列——AXI总线
  • 复习博客:JVM
  • 【CNN】卷积神经网络- part1
  • vLLM 基准测试与性能测试框架:全面解析LLM推理性能评估体系