day49 python 注意力热图
目录
一、注意力热图简介
二、类激活图(CAM)原理
三、基于PyTorch的CAM实现
(一)加载库与模型
(二)输入图片预处理与模型预测
(三)生成注意力热图
(四)可视化与保存结果
四、实验结果与分析
一、注意力热图简介
注意力热图是一种强大的可视化工具,能够直观地展示神经网络在处理输入图像时的关注区域。它可以帮助我们理解模型是如何做出决策的,从而更好地优化和改进模型。在实际应用中,注意力热图广泛应用于图像分类、目标检测等领域,为研究人员提供了宝贵的洞察。
通过阅读大量相关资料,我发现大多数方法都是基于神经网络的输出特征图来生成注意力热图。具体来说,可以使用任意层的特征图,但通常选择最后一个卷积层的输出特征图。将特征图调整到输入图像的大小后,通过特定的函数将其叠加到原图像上,即可得到注意力热图。虽然这个过程看似简单,但在实际操作中,仍有许多细节需要注意。在本次实验中,我主要采用了类激活图(CAM)方法来生成注意力热图。
二、类激活图(CAM)原理
类激活图(CAM)方法是由论文《Learning Deep Features for Discriminative Localization》提出的一种经典方法。其核心思想是利用神经网络的卷积层特征图和分类层权重来生成注意力热图。以下是CAM方法的具体步骤:
-
获取输出特征图:从神经网络中提取输出特征图,其形状为[B, C, H, W],其中B为批量大小,C为最后一个卷积层的输出通道数,H和W分别为特征图的宽度和高度。如果输入一张图片,则B=1。
-
获取分类层权重:提取训练好的模型的分类头权重
classifier.weight
,注意分类层的输入通道数必须与输出特征图的通道数匹配。 -
加权求和生成注意力热图:将每个通道的特征图与分类层权重进行加权求和,最终得到每一类的注意力热图。
三、基于PyTorch的CAM实现
以下是使用PyTorch实现CAM的完整代码,代码中包含了详细的注释,方便读者理解每个步骤的具体操作。
(一)加载库与模型
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from PIL import Image# 加载自己的网络
from model import modelclass_num = 5model_ft = model(num_classes=class_num)
model_ft.load_state_dict(torch.load('pretrain.pth', map_location=lambda storage, loc: storage))model_features = nn.Sequential(*list(model_ft.children())[:-2])
fc_weights = model_ft.state_dict()['classifier.weight'].cpu().numpy()
class_ = {0: 'car', 1: 'bird', 2: 'tree', 3: 'sky', 4: 'person'}
model_ft.eval()
model_features.eval()
(二)输入图片预处理与模型预测
data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}img_path = '/data/test.jpg' # 单张测试
_, img_name = os.path.split(img_path)
features_blobs = []
img = Image.open(img_path).convert('RGB')
img_tensor = data_transform['val'](img).unsqueeze(0) # [1,3,224,224]
features = model_features(img_tensor).detach().cpu().numpy() # [1,960,7,7]logit = model_ft(img_tensor) # [1,2] -> [ 3.3207, -2.9495]
h_x = torch.nn.functional.softmax(logit, dim=1).data.squeeze() # tensor([0.9981, 0.0019])probs, idx = h_x.sort(0, True) # 按概率从大到小排列
probs = probs.cpu().numpy()
idx = idx.cpu().numpy()for i in range(class_num):print('{:.3f} -> {}'.format(probs[i], class_[idx[i]])) # 打印预测结果
(三)生成注意力热图
def returnCAM(feature_conv, weight_softmax, class_idx):bz, nc, h, w = feature_conv.shapeoutput_cam = []for idx in class_idx:feature_conv = feature_conv.reshape((nc, h * w))cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h * w)))cam = cam.reshape(h, w)cam_img = (cam - cam.min()) / (cam.max() - cam.min())cam_img = np.uint8(255 * cam_img)output_cam.append(cam_img)return output_camCAMs = returnCAM(features, fc_weights, idx) # 输出预测概率最大的特征图集对应的CAM
print(img_name + ' output for the top1 prediction: %s' % class_[idx[0]])
(四)可视化与保存结果
img = cv2.imread(img_path)
height, width, _ = img.shape
heatmap = cv2.applyColorMap(cv2.resize(CAMs[0], (width, height)), cv2.COLORMAP_JET)
result = heatmap * 0.3 + img * 0.5text = '%s %.2f%%' % (class_[idx[0]], probs[0] * 100)
cv2.putText(result, text, (210, 40), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.9,color=(123, 222, 238), thickness=2, lineType=cv2.LINE_AA)CAM_RESULT_PATH = r'/data/heatmap/'
if not os.path.exists(CAM_RESULT_PATH):os.mkdir(CAM_RESULT_PATH)
image_name_ = img_name.split(".")[-2]
cv2.imwrite(os.path.join(CAM_RESULT_PATH, image_name_ + '_heatmap.jpg'), result)
四、实验结果与分析
通过上述代码,我成功生成了输入图像的注意力热图,并将其与原图叠加显示。从结果可以看出,注意力热图清晰地标注出了模型在做出预测时关注的区域。例如,在对“car”类别进行预测时,热图主要集中在车辆的轮廓和关键部位,这表明模型能够准确地识别出车辆的特征区域。
@浙大疏锦行