当前位置: 首页 > news >正文

python学习打卡day47

DAY 47 注意力热图可视化

昨天代码中注意力热图的部分顺移至今天

知识点回顾:

热力图

作业:对比不同卷积层热图可视化的结果

# 可视化空间注意力热力图(显示模型关注的图像区域)
def visualize_attention_map(model, test_loader, device, class_names, num_samples=3):"""可视化模型的注意力热力图,展示模型关注的图像区域"""model.eval()  # 设置为评估模式with torch.no_grad():for i, (images, labels) in enumerate(test_loader):if i >= num_samples:  # 只可视化前几个样本breakimages, labels = images.to(device), labels.to(device)# 创建一个钩子,捕获中间特征图activation_maps = []def hook(module, input, output):activation_maps.append(output.cpu())# 为最后一个卷积层注册钩子(获取特征图)hook_handle = model.conv3.register_forward_hook(hook)# 前向传播,触发钩子outputs = model(images)# 移除钩子hook_handle.remove()# 获取预测结果_, predicted = torch.max(outputs, 1)# 获取原始图像img = images[0].cpu().permute(1, 2, 0).numpy()# 反标准化处理img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)img = np.clip(img, 0, 1)# 获取激活图(最后一个卷积层的输出)feature_map = activation_maps[0][0].cpu()  # 取第一个样本# 计算通道注意力权重(使用SE模块的全局平均池化)channel_weights = torch.mean(feature_map, dim=(1, 2))  # [C]# 按权重对通道排序sorted_indices = torch.argsort(channel_weights, descending=True)# 创建子图fig, axes = plt.subplots(1, 4, figsize=(16, 4))# 显示原始图像axes[0].imshow(img)axes[0].set_title(f'原始图像\n真实: {class_names[labels[0]]}\n预测: {class_names[predicted[0]]}')axes[0].axis('off')# 显示前3个最活跃通道的热力图for j in range(3):channel_idx = sorted_indices[j]# 获取对应通道的特征图channel_map = feature_map[channel_idx].numpy()# 归一化到[0,1]channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min() + 1e-8)# 调整热力图大小以匹配原始图像from scipy.ndimage import zoomheatmap = zoom(channel_map, (32/feature_map.shape[1], 32/feature_map.shape[2]))# 显示热力图axes[j+1].imshow(img)axes[j+1].imshow(heatmap, alpha=0.5, cmap='jet')axes[j+1].set_title(f'注意力热力图 - 通道 {channel_idx}')axes[j+1].axis('off')plt.tight_layout()plt.show()# 调用可视化函数
visualize_attention_map(model, test_loader, device, class_names, num_samples=3)

@浙大疏锦行


文章转载自:

http://Vo9k7Z8a.mzzqs.cn
http://5bTiW5N3.mzzqs.cn
http://R8sWq9E4.mzzqs.cn
http://EQ1moJtZ.mzzqs.cn
http://a4ZanpY6.mzzqs.cn
http://BhYwdQLH.mzzqs.cn
http://Nq66JMJZ.mzzqs.cn
http://CoRiGkLK.mzzqs.cn
http://kBh33Ghz.mzzqs.cn
http://cTi9dQan.mzzqs.cn
http://nnnjZ6s6.mzzqs.cn
http://v32aPSIU.mzzqs.cn
http://Vx5XT61g.mzzqs.cn
http://mdFH5O0U.mzzqs.cn
http://olv3svby.mzzqs.cn
http://7ieZkqnH.mzzqs.cn
http://icVeiDH1.mzzqs.cn
http://Dpi8Gnz7.mzzqs.cn
http://WzE7xut4.mzzqs.cn
http://pi7TxA44.mzzqs.cn
http://bXNjJ7iS.mzzqs.cn
http://5MdTdf3W.mzzqs.cn
http://nw2O0Jkr.mzzqs.cn
http://Cn95gqI6.mzzqs.cn
http://1HyeoGDw.mzzqs.cn
http://fOXlz047.mzzqs.cn
http://kPRUnJc5.mzzqs.cn
http://7KdYvdSA.mzzqs.cn
http://rcN2bk0V.mzzqs.cn
http://H9BUHqFH.mzzqs.cn
http://www.dtcms.com/a/236622.html

相关文章:

  • PCDF (Progressive Continuous Discrimination Filter)模块构建
  • 基于深度学习的金枪鱼各类别目标检测含完整数据集
  • 如何配置 MySQL 允许远程连接
  • 从内存角度透视现代C++关键特性
  • 一些因子的解释
  • Python控制台输出彩色字体指南
  • Playwright自动化测试全栈指南:从基础到企业级实践(2025终极版)
  • Redis :String类型
  • iOS 门店营收表格功能的实现
  • 《Vuejs设计与实现》第 8 章(挂载与更新)
  • SUSE Linux 发行版全面解析:从开源先驱到企业级支柱
  • 青少年编程与数学 01-011 系统软件简介 07 iOS操作系统
  • Srping Cloud Gateway 跨域配置 CorsWebFilter
  • # 主流大语言模型安全性测试(二):英文越狱提示词下的表现与分析
  • C# 类和继承(扩展方法)
  • 【基础算法】枚举(普通枚举、二进制枚举)
  • redis分片集群架构
  • Python60日基础学习打卡Day46
  • 物联网协议之MQTT(二)服务端
  • Qt Test功能及架构
  • Python Cookbook-7.12 在 SQLite 中储存 BLOB
  • 【Java学习笔记】StringBuilder类(重点)
  • 以SMMUv2为例,使用Trace32可视化操作SMMU的常用命令详解
  • stm32内存踩踏一例
  • DeepSeek-R1-0528:开源推理模型的革新与突破
  • AI开发 | 生成式AI在企业软件中的演进形态:从嵌入式到智能体
  • SQL-事务(2025.6.6-2025.6.7学习篇)
  • 零基础玩转物联网-串口转以太网模块如何快速实现与TCP服务器通信
  • Android学习总结-GetX库常见问题和解决方案
  • 安卓基础(Java 和 Gradle 版本)