当前位置: 首页 > news >正文

day49 python 注意力热图

目录

一、注意力热图简介

二、类激活图(CAM)原理

三、基于PyTorch的CAM实现

(一)加载库与模型

(二)输入图片预处理与模型预测

(三)生成注意力热图

(四)可视化与保存结果

四、实验结果与分析


一、注意力热图简介

注意力热图是一种强大的可视化工具,能够直观地展示神经网络在处理输入图像时的关注区域。它可以帮助我们理解模型是如何做出决策的,从而更好地优化和改进模型。在实际应用中,注意力热图广泛应用于图像分类、目标检测等领域,为研究人员提供了宝贵的洞察。

通过阅读大量相关资料,我发现大多数方法都是基于神经网络的输出特征图来生成注意力热图。具体来说,可以使用任意层的特征图,但通常选择最后一个卷积层的输出特征图。将特征图调整到输入图像的大小后,通过特定的函数将其叠加到原图像上,即可得到注意力热图。虽然这个过程看似简单,但在实际操作中,仍有许多细节需要注意。在本次实验中,我主要采用了类激活图(CAM)方法来生成注意力热图。

二、类激活图(CAM)原理

类激活图(CAM)方法是由论文《Learning Deep Features for Discriminative Localization》提出的一种经典方法。其核心思想是利用神经网络的卷积层特征图和分类层权重来生成注意力热图。以下是CAM方法的具体步骤:

  1. 获取输出特征图:从神经网络中提取输出特征图,其形状为[B, C, H, W],其中B为批量大小,C为最后一个卷积层的输出通道数,H和W分别为特征图的宽度和高度。如果输入一张图片,则B=1。

  2. 获取分类层权重:提取训练好的模型的分类头权重classifier.weight,注意分类层的输入通道数必须与输出特征图的通道数匹配。

  3. 加权求和生成注意力热图:将每个通道的特征图与分类层权重进行加权求和,最终得到每一类的注意力热图。

三、基于PyTorch的CAM实现

以下是使用PyTorch实现CAM的完整代码,代码中包含了详细的注释,方便读者理解每个步骤的具体操作。

(一)加载库与模型

import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from PIL import Image# 加载自己的网络
from model import modelclass_num = 5model_ft = model(num_classes=class_num)
model_ft.load_state_dict(torch.load('pretrain.pth', map_location=lambda storage, loc: storage))model_features = nn.Sequential(*list(model_ft.children())[:-2])
fc_weights = model_ft.state_dict()['classifier.weight'].cpu().numpy()
class_ = {0: 'car', 1: 'bird', 2: 'tree', 3: 'sky', 4: 'person'}
model_ft.eval()
model_features.eval()

(二)输入图片预处理与模型预测

data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}img_path = '/data/test.jpg'  # 单张测试
_, img_name = os.path.split(img_path)
features_blobs = []
img = Image.open(img_path).convert('RGB')
img_tensor = data_transform['val'](img).unsqueeze(0)  # [1,3,224,224]
features = model_features(img_tensor).detach().cpu().numpy()  # [1,960,7,7]logit = model_ft(img_tensor)  # [1,2] -> [ 3.3207, -2.9495]
h_x = torch.nn.functional.softmax(logit, dim=1).data.squeeze()  # tensor([0.9981, 0.0019])probs, idx = h_x.sort(0, True)  # 按概率从大到小排列
probs = probs.cpu().numpy()
idx = idx.cpu().numpy()for i in range(class_num):print('{:.3f} -> {}'.format(probs[i], class_[idx[i]]))  # 打印预测结果

(三)生成注意力热图

def returnCAM(feature_conv, weight_softmax, class_idx):bz, nc, h, w = feature_conv.shapeoutput_cam = []for idx in class_idx:feature_conv = feature_conv.reshape((nc, h * w))cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h * w)))cam = cam.reshape(h, w)cam_img = (cam - cam.min()) / (cam.max() - cam.min())cam_img = np.uint8(255 * cam_img)output_cam.append(cam_img)return output_camCAMs = returnCAM(features, fc_weights, idx)  # 输出预测概率最大的特征图集对应的CAM
print(img_name + ' output for the top1 prediction: %s' % class_[idx[0]])

(四)可视化与保存结果

img = cv2.imread(img_path)
height, width, _ = img.shape
heatmap = cv2.applyColorMap(cv2.resize(CAMs[0], (width, height)), cv2.COLORMAP_JET)
result = heatmap * 0.3 + img * 0.5text = '%s %.2f%%' % (class_[idx[0]], probs[0] * 100)
cv2.putText(result, text, (210, 40), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.9,color=(123, 222, 238), thickness=2, lineType=cv2.LINE_AA)CAM_RESULT_PATH = r'/data/heatmap/'
if not os.path.exists(CAM_RESULT_PATH):os.mkdir(CAM_RESULT_PATH)
image_name_ = img_name.split(".")[-2]
cv2.imwrite(os.path.join(CAM_RESULT_PATH, image_name_ + '_heatmap.jpg'), result)

四、实验结果与分析

通过上述代码,我成功生成了输入图像的注意力热图,并将其与原图叠加显示。从结果可以看出,注意力热图清晰地标注出了模型在做出预测时关注的区域。例如,在对“car”类别进行预测时,热图主要集中在车辆的轮廓和关键部位,这表明模型能够准确地识别出车辆的特征区域。

@浙大疏锦行

相关文章:

  • 银行卡二三四要素实名接口如何用PHP实现调用?
  • 抖去推--短视频矩阵系统源码开发
  • 十(1). 强制类型转换
  • 【C/C++】实现固定地址函数调用
  • OSCP靶机练习 mantis
  • FlashAttention 公式推导
  • OD 算法题 B卷【全排列】
  • Supersonic 新一代AI数据分析平台
  • JS有哪些迭代器,该如何使用?
  • 【HarmonyOS 5.0】DevEco Testing:鸿蒙应用质量保障的终极武器
  • vue中的派发事件与广播事件,及广播事件应用于哪些场景和一个表单验证例子
  • 5.4.2 Spring Boot整合Redis
  • oracle 11g ADG备库报错ORA-00449 lgwr unexpectedly分析处理
  • C++刷题:日期模拟(1)
  • react菜单,动态绑定点击事件,菜单分离出去单独的js文件,Ant框架
  • 【Docker 01】Docker 简介
  • 数学:花括号在数学中的应用详解
  • Strong Baseline: Multi-UAV Tracking via YOLOv12 with BoT-SORT-ReID 2025最新无人机跟踪
  • Scrapy爬虫教程(新手)
  • 论文阅读:Matting by Generation
  • 多语言网站怎么实现的/seo优化快速排名
  • 如何做中介网站/网站首页制作网站
  • wordpress怎样加快访问/汽车seo是什么意思
  • wordpress导航站模版/防城港网站seo
  • 无限动力网站/面点培训学校哪里有
  • 什么网站可以做直播/建网站找谁