当前位置: 首页 > news >正文

python学习打卡day47

DAY 47 注意力热图可视化

昨天代码中注意力热图的部分顺移至今天

知识点回顾:

热力图

作业:对比不同卷积层热图可视化的结果

# 可视化空间注意力热力图(显示模型关注的图像区域)
def visualize_attention_map(model, test_loader, device, class_names, num_samples=3):"""可视化模型的注意力热力图,展示模型关注的图像区域"""model.eval()  # 设置为评估模式with torch.no_grad():for i, (images, labels) in enumerate(test_loader):if i >= num_samples:  # 只可视化前几个样本breakimages, labels = images.to(device), labels.to(device)# 创建一个钩子,捕获中间特征图activation_maps = []def hook(module, input, output):activation_maps.append(output.cpu())# 为最后一个卷积层注册钩子(获取特征图)hook_handle = model.conv3.register_forward_hook(hook)# 前向传播,触发钩子outputs = model(images)# 移除钩子hook_handle.remove()# 获取预测结果_, predicted = torch.max(outputs, 1)# 获取原始图像img = images[0].cpu().permute(1, 2, 0).numpy()# 反标准化处理img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)img = np.clip(img, 0, 1)# 获取激活图(最后一个卷积层的输出)feature_map = activation_maps[0][0].cpu()  # 取第一个样本# 计算通道注意力权重(使用SE模块的全局平均池化)channel_weights = torch.mean(feature_map, dim=(1, 2))  # [C]# 按权重对通道排序sorted_indices = torch.argsort(channel_weights, descending=True)# 创建子图fig, axes = plt.subplots(1, 4, figsize=(16, 4))# 显示原始图像axes[0].imshow(img)axes[0].set_title(f'原始图像\n真实: {class_names[labels[0]]}\n预测: {class_names[predicted[0]]}')axes[0].axis('off')# 显示前3个最活跃通道的热力图for j in range(3):channel_idx = sorted_indices[j]# 获取对应通道的特征图channel_map = feature_map[channel_idx].numpy()# 归一化到[0,1]channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min() + 1e-8)# 调整热力图大小以匹配原始图像from scipy.ndimage import zoomheatmap = zoom(channel_map, (32/feature_map.shape[1], 32/feature_map.shape[2]))# 显示热力图axes[j+1].imshow(img)axes[j+1].imshow(heatmap, alpha=0.5, cmap='jet')axes[j+1].set_title(f'注意力热力图 - 通道 {channel_idx}')axes[j+1].axis('off')plt.tight_layout()plt.show()# 调用可视化函数
visualize_attention_map(model, test_loader, device, class_names, num_samples=3)

@浙大疏锦行

相关文章:

  • PCDF (Progressive Continuous Discrimination Filter)模块构建
  • 基于深度学习的金枪鱼各类别目标检测含完整数据集
  • 如何配置 MySQL 允许远程连接
  • 从内存角度透视现代C++关键特性
  • 一些因子的解释
  • Python控制台输出彩色字体指南
  • Playwright自动化测试全栈指南:从基础到企业级实践(2025终极版)
  • Redis :String类型
  • iOS 门店营收表格功能的实现
  • 《Vuejs设计与实现》第 8 章(挂载与更新)
  • SUSE Linux 发行版全面解析:从开源先驱到企业级支柱
  • 青少年编程与数学 01-011 系统软件简介 07 iOS操作系统
  • Srping Cloud Gateway 跨域配置 CorsWebFilter
  • # 主流大语言模型安全性测试(二):英文越狱提示词下的表现与分析
  • C# 类和继承(扩展方法)
  • 【基础算法】枚举(普通枚举、二进制枚举)
  • redis分片集群架构
  • Python60日基础学习打卡Day46
  • 物联网协议之MQTT(二)服务端
  • Qt Test功能及架构
  • 江西建设工程信息网站/口碑营销案例简短
  • 微管家平台/武汉seo关键词排名优化
  • 网站记录登录账号怎么做/网页链接制作生成
  • 自己建设网站的利弊/百度知道在线问答
  • 做一个电商网站多少钱/高质量内容的重要性
  • 百度网站前三名权重一般在多少/自己建网站需要钱吗