当前位置: 首页 > news >正文

8.20 打卡 DAY 47 注意力热图可视化

DAY 47: 注意力热图可视化——让模型的决策“看得见”

欢迎来到第47天的学习!在昨天的课程中,我们为CNN模型引入了通道注意力 (Channel Attention) 机制,并看到了它对模型性能的提升。今天,我们将聚焦于一项强大的可视化技术——注意力热力图 (Attention Heatmap),深入探讨它的原理、代码实现,并分析如何通过它来揭示模型的“思考过程”。

这项技术不仅能帮助我们调试模型,更能让我们直观地理解模型在做出决策时,究竟关注了图像的哪些部分。

1. 知识点回顾:什么是热力图 (Heatmap)?

在我们深入代码之前,先来回顾一下什么是热力图。

热力图是一种数据可视化技术,它通过颜色来表示数据的强度或密度。通常,我们用暖色调(如红色、黄色)代表数值较高或密度较大的区域,用冷色调(如蓝色、绿色)代表数值较低的区域。

在深度学习中,热力图可以用来可视化:

  1. 特征图的激活强度:在卷积神经网络中,特征图上的每个像素值都代表了模型在该位置对某种特定特征(如边缘、纹理、物体部件)的响应强度。将这个二维的特征图用颜色渲染,就形成了一张热力图,高亮区域(红色)即为模型认为该特征最显著的位置。
  2. 注意力权重:对于注意力机制,我们可以将学习到的空间权重或通道权重映射回原始图像上,生成热力图。红色区域就代表模型“注意力”最集中的地方。
  3. 类别激活映射 (Class Activation Mapping, CAM/Grad-CAM):这是一种更高级的技术,它能生成一张热力图,明确指示出模型在判断某个特定类别时,主要依据了图像的哪些区域。例如,在判断图片为“猫”时,热力图会在猫的轮廓上呈现红色。

我们今天的注意力热力图,本质上就是第一种和第二种的结合:我们找出被通道注意力认为最重要的几个特征通道,然后将这些通道的特征激活图以热力图的形式叠加在原图上进行可视化。


2. 注意力热力图可视化代码详解

我们将详细解析昨天代码中用于生成注意力热力图的函数 visualize_attention_map

目标:对于一张输入的测试图片,找出其在最后一个卷积层中,被模型最关注的几个特征通道,并将这些通道的激活图以热力图的形式展示出来。

# (函数定义)
def visualize_attention_map(model, test_loader, device, class_names, num_samples=3):model.eval() # 切换到评估模式with torch.no_grad():for i, (images, labels) in enumerate(test_loader):if i >= num_samples: breakimages, labels = images.to(device), labels.to(device)# --- 步骤1: 注册钩子以捕获特征图 ---activation_maps = []def hook(module, input, output):activation_maps.append(output.cpu())# 目标层:最后一个卷积层hook_handle = model.conv3.register_forward_hook(hook)# --- 步骤2: 前向传播 ---outputs = model(images)hook_handle.remove() # 完成后立即移除钩子_, predicted = torch.max(outputs, 1)# --- 步骤3: 准备原始图像和特征图 ---# (反标准化代码,将图像恢复为原始像素值)img = ... # 取出第一个样本的特征图feature_map = activation_maps[0][0].cpu()# --- 步骤4: 计算通道权重并生成热力图 ---# 全局平均池化,得到每个通道的“重要性”分数channel_weights = torch.mean(feature_map, dim=(1, 2))# 找出权重最高的通道索引sorted_indices = torch.argsort(channel_weights, descending=True)# --- 步骤5: 可视化 ---fig, axes = plt.subplots(1, 4, figsize=(16, 4))# (显示原始图像)axes[0].imshow(img)# 显示权重最高的3个通道的热力图for j in range(3):channel_idx = sorted_indices[j]channel_map = feature_map[channel_idx].numpy()# (归一化到0-1范围)channel_map = ...# 上采样热力图以匹配原图尺寸from scipy.ndimage import zoomheatmap = zoom(channel_map, (32/feature_map.shape[1], 32/feature_map.shape[2]))# 叠加显示axes[j+1].imshow(img)axes[j+1].imshow(heatmap, alpha=0.5, cmap='jet')axes[j+1].set_title(f'注意力热力图 - 通道 {channel_idx}')plt.show()

代码逻辑分解:

  1. 注册钩子 (register_forward_hook):我们依然使用钩子函数来“拦截”模型内部的数据。这次我们选择拦截conv3(最后一个卷积层)的输出,因为这一层的特征图包含最高级的语义信息。
  2. 前向传播:执行model(images),当数据流经conv3层后,我们的hook函数会被触发,将conv3的输出特征图保存到activation_maps列表中。
  3. 计算通道权重torch.mean(feature_map, dim=(1, 2)) 是一个关键步骤。它对每个通道的特征图(尺寸为H x W)进行全局平均池化,得到一个单一的数值。这个数值可以看作是该通道对整张图片的平均激活强度。激活强度越高的通道,通常代表它提取到的特征对最终分类越重要。
  4. 排序 (torch.argsort):通过对权重进行降序排序,我们就能找到那些“最重要”的通道。
  5. 上采样与叠加 (scipy.ndimage.zoom):从重要通道中取出的特征图尺寸很小(例如4x4),无法直接与原始图像(32x32)叠加。zoom函数通过插值算法,将小尺寸的特征图放大到与原始图像相同的尺寸。最后,通过设置alpha透明度,将热力图叠加在原图上。
3. 结果分析

上图展示了对一张青蛙图片的注意力热力图可视化结果。左边是原始图像和模型的预测结果,右边三张图分别是模型认为最重要的三个通道(通道106, 126, 85)的激活热力图。

解读:

  • 红色区域代表高激活值,即模型“注意力”最集中的地方。
  • 我们可以清晰地看到,这三个最重要的通道都准确地将注意力聚焦在了青蛙的身体轮廓上,而背景区域则基本是蓝色(低关注)。
  • 这说明,我们加入的通道注意力模块成功地让模型学会了给不同的特征通道分配权重。在识别这张图片时,模型放大了那些能够提取出“青蛙”特征的通道,而这些通道的激活区域也正好对应了图片中的青蛙本身。

通过这种方式,我们不仅验证了模型的有效性,还让模型不再是一个“黑箱”,它的决策过程变得直观可见。


4. 作业与参考答案

作业:对比不同卷积层(conv1, conv2, conv3)热图可视化的结果。

参考答案:

这个作业旨在让我们更深入地理解CNN特征逐层抽象的过程。要完成这个作业,我们需要修改visualize_attention_map函数,使其能够同时捕获和显示conv1, conv2, conv3三层的特征图。

修改思路:

  1. layer_names列表中传入['conv1', 'conv2', 'conv3']
  2. 循环为这三个层都注册钩子。
  3. 在可视化部分,为每一层都绘制其权重最高的热力图。

预期结果与分析:

当我们对比不同卷积层的注意力热力图时,会观察到一个非常清晰的模式:特征从具体到抽象,注意力从分散到集中

  • conv1 热力图 (浅层)

    • 特点:热力图的高亮区域(红色)会非常分散,遍布在图像的各个角落,包括青蛙的轮廓和背景中的树叶边缘。
    • 原因conv1负责提取的是边缘、颜色、纹理等低级通用特征。在这一阶段,模型还无法区分哪个是主体、哪个是背景,它只是忠实地标出所有它能识别出的基础特征。
  • conv2 热力图 (中层)

    • 特点:热力图会开始变得聚焦。高亮区域会更多地集中在青蛙的身体、腿部等有意义的局部结构上,而背景区域的激活会减弱。
    • 原因conv2基于conv1的低级特征,开始学习组合成更复杂的中级特征,如物体的部件和形状。模型的注意力开始从“哪里有边缘”转向“哪里有像青蛙身体一部分的形状”。
  • conv3 热力图 (深层)

    • 特点:热力图会变得高度集中,几乎所有的高亮区域都精确地覆盖在青蛙的主体上,背景则完全变成蓝色(低关注)。
    • 原因conv3负责提取的是最抽象的高级语义特征。在这一层,模型已经整合了前面的所有信息,形成了对“青蛙”这个概念的整体认知。它的注意力完全集中在那些最能区分出“青蛙”与其他类别的决定性特征上。

总结对比表格:

层级特征级别注意力焦点抽象程度
conv1低级(边缘、纹理)分散,遍布整个图像低,接近原图细节
conv2中级(物体部件)开始聚焦于主体区域中等,识别局部结构
conv3高级(语义概念)高度集中于主体高,形成整体概念

通过这个对比,我们能生动地看到一个CNN模型是如何一步步地从像素中提炼出语义信息,并最终做出准确分类的。


@浙大疏锦行

http://www.dtcms.com/a/341719.html

相关文章:

  • 不会写 SQL 也能出报表?积木报表 + AI 30 秒自动生成报表和图表
  • JVM讲解
  • leetcode7二分查找_69 and 34
  • Linux正则表达式
  • 2D水平目标检测数据增强——旋转任意指定角度
  • RK3568 Linux驱动学习——设备树下 LED 驱动
  • Redisson最新版本(3.50.0左右)启动时提示Netty的某些类找不到
  • PowerShell脚本检查业务健康状态
  • 解决Docker 无法连接到官方镜像仓库
  • Lecture 6 Kernels, Triton 课程笔记
  • JVM基础知识总结
  • Docker 核心技术:Linux Cgroups
  • GDB 的多线程调试
  • 针对具有下垂控制光伏逆变器的主动配电网络的多目标分层协调电压/无功控制方法的复现
  • 音频读写速度优化 音频格式
  • Transformer内容详解(通透版)
  • pip install -e中e 参数解释
  • 八辊矫平机·第三篇
  • 卸载win10/win11系统里导致磁盘故障的补丁
  • 广东省省考备考(第八十二天8.20)——资料分析、数量、言语(强化训练)
  • 【蒸蒸日上】军八武将篇——标1
  • 8 webUI中-Controlnet(控制与约束)的应用分类与使用方法
  • 【语法】markdown非常用场景
  • Netty HashedWheelTimer设计原理:从时间轮算法到源码实现
  • 跨平台 RTSP/RTMP 播放器工程化实践:低延迟与高稳定性的挑战与突破
  • 【数据分享】东北大鼠疫传播与死亡空间数据
  • Vue透传 Attributes(详细解析)2
  • 恶补DSP:2.F28335的定时器系统
  • 买返商城网站源码多平台购物返现搭建图解源码二开
  • 万象生鲜配送系统 2025 年 8 月 15 日更新日志