当前位置: 首页 > news >正文

pytorch_grad_cam 库学习笔记——基类BaseCAM

pytorch_grad_cam 是一个包含用于计算机视觉的可解释 AI 的最先进方法的软件包。 这可用于诊断模型预测,无论是在生产中还是在 开发模型。 其目的还在于作为研究新可解释性方法的算法和指标的基准。
pytorch_grad_cam 官方源码 https://github.com/jacobgil/pytorch-grad-cam
pytorch_grad_cam 官方教程 https://jacobgil.github.io/pytorch-gradcam-book/introduction.html
在./pytorch-grad-cam/pytorch_grad_cam/base_cam.py里定义了名为BaseCAM的基类,用于实现不同种类的类激活映射(Class Activation Mapping, CAM)算法。这些算法通常被用来可视化深度学习模型中卷积神经网络在进行分类时所关注的图像区域。
本篇文章主要在这里对基类BaseCAM进行逐步分析,以理解库函数原理。

BaseCAM类

BaseCAM 是一个设计精良的基类,为各种 CAM(Class Activation Mapping)算法(如 Grad-CAM, Eigen-CAM, XGrad-CAM 等)提供了一个通用的框架。它的核心思想是:获取目标层的激活(activations)和梯度(gradients),然后根据特定算法计算权重,最后将权重与激活相乘并求和,生成热力图。

以下是该基类的主要功能和结构:

1. 初始化函数 (init):接受模型、目标层列表、reshape变换函数等参数,并初始化了与模型激活和梯度相关的组件。

class BaseCAM:def __init__(self,model: torch.nn.Module,target_layers: List[torch.nn.Module],reshape_transform: Callable = None,compute_input_gradient: bool = False,uses_gradients: bool = True,tta_transforms: Optional[tta.Compose] = None,detach: bool = True,) -> None:self.model = model.eval()self.target_layers = target_layers# Use the same device as the model.self.device = next(self.model.parameters()).deviceif 'hpu' in str(self.device):try:import habana_frameworks.torch.core as htcoreexcept ImportError as error:error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."raise errorself.__htcore = htcoreself.reshape_transform = reshape_transformself.compute_input_gradient = compute_input_gradientself.uses_gradients = uses_gradientsif tta_transforms is None:self.tta_transforms = tta.Compose([tta.HorizontalFlip(),tta.Multiply(factors=[0.9, 1, 1.1]),])else:self.tta_transforms = tta_transformsself.detach = detachself.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform, self.detach)""" Get a vector of weights for every channel in the target layer.Methods that return weights channels,will typically need to only implement this function. """
功能:

初始化 BaseCAM 实例。

参数:

model: 要解释的 PyTorch 模型。
target_layers: 一个包含目标 torch.nn.Module 层的列表。CAM 热力图将基于这些层的激活和梯度生成。通常是最深层的卷积层。
reshape_transform: 一个可选的函数。在某些模型(如 Vision Transformers)中,特征图的形状可能与输入图像不直接对应,需要此函数进行重塑。例如,将 ViT 的 [batch, num_patches, features] 重塑回 [batch, channels, height, width]。
compute_input_gradient: 布尔值。如果为 True,则会计算输入张量的梯度。这主要用于像 InputGradCAM 这样的变体,它需要输入的梯度。
uses_gradients: 布尔值。如果为 True,则算法需要计算梯度(如 Grad-CAM)。如果为 False,则算法不依赖梯度(如原始的 CAM)。这决定了在 forward 方法中是否执行 backward()。
tta_transforms: 一个 ttach.Compose 对象,包含用于测试时增强(Test Time Augmentation, TTA)的变换(如水平翻转、缩放)。如果为 None,则使用默认的翻转和乘法因子增强。TTA 用于平滑最终的热力图。
detach: 布尔值。控制是否在反向传播后从计算图中分离梯度。True 更安全,避免意外的梯度累积。

关键操作:

将模型设置为评估模式 (model.eval())。
获取模型所在的设备(CPU/GPU)。
初始化 ActivationsAndGradients 对象(来自 pytorch-grad-cam 库)。这个对象是核心,它通过 PyTorch 的 Hook 机制,在前向传播时捕获 target_layers 的激活,在反向传播时捕获它们的梯度。
设置 tta_transforms。

2. 权重计算函数 (get_cam_weights):这是一个需要子类实现的抽象方法,目的是为每个通道计算权重,从而生成CAM热图。

    def get_cam_weights(self,input_tensor: torch.Tensor,target_layers: List[torch.nn.Module],targets: List[torch.nn.Module],activations: torch.Tensor,grads: torch.Tensor,) -> np.ndarray:raise Exception("Not Implemented")
功能:

这是一个抽象方法,必须由继承 BaseCAM 的子类实现。它是不同 CAM 算法的核心区别所在。

作用:

根据输入张量、目标层、目标类别、目标层的激活值和梯度值,计算出每个通道的权重。

输入:

input_tensor: 输入的图像张量。
target_layers: 目标层列表。
targets: 目标类别列表(ClassifierOutputTarget 对象)。
activations: 目标层的激活值(torch.Tensor 或 np.ndarray)。
grads: 目标层的梯度值(torch.Tensor 或 np.ndarray)。

输出:

一个 np.ndarray,形状为 (B, C),其中 B 是批次大小,C 是目标层的通道数。每个元素代表对应通道的权重。

示例:

Grad-CAM: 权重是梯度在空间维度(H, W)上的平均值 (weights = grads.mean(dim=[2, 3]))。
Eigen-CAM: 权重是激活值进行 SVD 后第一个主成分的方向向量。
XGrad-CAM: 权重是梯度与激活的逐元素乘积再取平均。

3. 热图生成函数 (get_cam_image):根据给定的输入张量、目标层、激活值和梯度来计算CAM热图。

    def get_cam_image(self,input_tensor: torch.Tensor,target_layer: torch.nn.Module,targets: List[torch.nn.Module],activations: torch.Tensor,grads: torch.Tensor,eigen_smooth: bool = False,) -> np.ndarray:weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)if isinstance(activations, torch.Tensor):activations = activations.cpu().detach().numpy()# 2D convif len(activations.shape) == 4:weighted_activations = weights[:, :, None, None] * activations# 3D convelif len(activations.shape) == 5:weighted_activations = weights[:, :, None, None, None] * activationselse:raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.")if eigen_smooth:cam = get_2d_projection(weighted_activations)else:cam = weighted_activations.sum(axis=1)return cam

功能:

根据 get_cam_weights 计算出的权重和激活值,生成单个目标层的 CAM 热力图。

流程:

  1. 调用 get_cam_weights 得到权重 weights。
  2. 将 activations 转换为 NumPy 数组。
  3. 加权激活:将权重与激活值相乘。
    对于 2D 卷积 (形状 (B, C, H, W)):weights[:, :, None, None] * activations。None 用于扩展维度,使广播(broadcasting)能够将每个通道的权重应用到该通道的所有空间位置。
    对于 3D 卷积 (形状 (B, C, D, H, W)):weights[:, :, None, None, None] * activations。
  4. 生成热力图:
    如果 eigen_smooth=True:调用 get_2d_projection(即 Eigen-CAM 的方法),使用 SVD 对加权后的激活进行降维,生成更平滑的热力图。
    否则:直接在通道维度(axis=1)上求和,得到 (B, H, W) 或 (B, D, H, W) 的热力图。

输出:

一个 np.ndarray,形状为 (B, H, W) 或 (B, D, H, W),表示单个目标层的 CAM 热力图。

4. 前向传播函数 (forward):执行模型的前向传播,计算输出并基于目标类别计算损失,然后计算梯度。

    def forward(self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False) -> np.ndarray:input_tensor = input_tensor.to(self.device)if self.compute_input_gradient:input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)self.outputs = outputs = self.activations_and_grads(input_tensor)if targets is None:target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)targets = [ClassifierOutputTarget(category) for category in target_categories]if self.uses_gradients:self.model.zero_grad()loss = sum([target(output) for target, output in zip(targets, outputs)])if self.detach:loss.backward(retain_graph=True)else:# keep the computational graph, create_graph = True is needed for hvptorch.autograd.grad(loss, input_tensor, retain_graph = True, create_graph = True)# When using the following loss.backward() method, a warning is raised: "UserWarning: Using backward() with create_graph=True will create a reference cycle"# loss.backward(retain_graph=True, create_graph=True)if 'hpu' in str(self.device):self.__htcore.mark_step()# In most of the saliency attribution papers, the saliency is# computed with a single target layer.# Commonly it is the last convolutional layer.# Here we support passing a list with multiple target layers.# It will compute the saliency image for every image,# and then aggregate them (with a default mean aggregation).# This gives you more flexibility in case you just want to# use all conv layers for example, all Batchnorm layers,# or something else.cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)return self.aggregate_multi_layers(cam_per_layer)

功能:

执行 CAM 热力图生成的主要流程。

流程:

  1. 前向传播:将 input_tensor 移动到模型设备上。如果 compute_input_gradient=True,则将输入设置为需要梯度。调用 self.activations_and_grads(input_tensor) 执行前向传播,并捕获 target_layers 的激活。同时得到模型的输出 outputs。
  2. 确定目标:如果 targets 为 None,则自动选择模型预测得分最高的类别作为目标。
  3. 反向传播:如果 uses_gradients=True,则执行反向传播。
  • 计算损失:loss = sum([target(output) for …]),即所有目标类别得分的总和。
  • 清零梯度:model.zero_grad()。
  • 执行 backward() 或 torch.autograd.grad()(取决于detach 参数),计算梯度。HPU (Habana Gaudi) 设备需要特殊处理 (htcore.mark_step())。
  1. 生成各层热力图:调用 compute_cam_per_layer 为每个 target_layer 生成热力图。
  2. 聚合热力图:调用 aggregate_multi_layers 将所有目标层生成的热力图聚合为一个最终的热力图。

输出:

一个 np.ndarray,形状为 (B, H, W) 或 (B, D, H, W),表示最终的 CAM 热力图。

5. get_target_width_height(self, input_tensor)

    def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]:if len(input_tensor.shape) == 4:width, height = input_tensor.size(-1), input_tensor.size(-2)return width, heightelif len(input_tensor.shape) == 5:depth, width, height = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3)return depth, width, heightelse:raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")

功能:

获取输入张量的空间维度(宽度、高度或深度、宽度、高度),用于后续热力图的缩放。

输入:

input_tensor (形状 (B, C, H, W) 或 (B, C, D, H, W))。

输出:

一个元组,对于 2D 图像是 (width, height),对于 3D 图像是 (depth, width, height)。

6. compute_cam_per_layer(self, input_tensor, targets, eigen_smooth)

    def compute_cam_per_layer(self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool) -> np.ndarray:if self.detach:activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]else:activations_list = [a for a in self.activations_and_grads.activations]grads_list = [g for g in self.activations_and_grads.gradients]target_size = self.get_target_width_height(input_tensor)cam_per_target_layer = []# Loop over the saliency image from every layerfor i in range(len(self.target_layers)):target_layer = self.target_layers[i]layer_activations = Nonelayer_grads = Noneif i < len(activations_list):layer_activations = activations_list[i]if i < len(grads_list):layer_grads = grads_list[i]cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)cam = np.maximum(cam, 0)scaled = scale_cam_image(cam, target_size)cam_per_target_layer.append(scaled[:, None, :])return cam_per_target_layer

功能:

为 target_layers 列表中的每一个目标层计算 CAM 热力图。

流程:

从 self.activations_and_grads 中获取所有 target_layers 的激活和梯度(如果 detach=True,则转换为 NumPy 数组)。
获取目标空间尺寸。
遍历 target_layers:
获取当前层的激活和梯度。
调用 get_cam_image 生成该层的热力图 cam。
调用 scale_cam_image(cam, target_size) 将热力图归一化到 [0, 1] 并缩放到目标尺寸。
将处理后的热力图添加到列表 cam_per_target_layer 中(添加一个新维度 [:, None, :] 以便后续拼接)。

输出:

一个列表,包含每个目标层生成的、已缩放和归一化的热力图。

7. 层聚合函数 (aggregate_multi_layers):对多个目标层生成的CAM热图进行聚合处理,通常是通过取平均的方式。

    def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray:cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)cam_per_target_layer = np.maximum(cam_per_target_layer, 0)result = np.mean(cam_per_target_layer, axis=1)return scale_cam_image(result)

功能:

将 compute_cam_per_layer 生成的多个热力图聚合成一个最终的热力图。

流程:

将 cam_per_target_layer 列表中的热力图在通道维度(axis=1)上拼接起来,得到一个形状为 (B, N, H, W) 的数组,其中 N 是目标层的数量。
将所有负值置为 0(np.maximum(…, 0)),确保热力图非负。
在通道维度(axis=1)上取平均值,得到 (B, H, W) 的热力图。
再次调用 scale_cam_image(result) 进行最终的归一化。

输出:

一个 (B, H, W) 的 np.ndarray,表示聚合后的最终热力图。

设计思想:

虽然大多数 CAM 只使用一个目标层,但此设计允许用户指定多个层(如所有卷积层),然后通过平均等方式聚合它们的结果,提供更大的灵活性。

8. 平滑处理函数 (forward_augmentation_smoothing):利用测试时数据增强(TTA)技术,对多个增强样本的CAM结果进行平滑处理,以得到更稳定的结果。

    def forward_augmentation_smoothing(self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False) -> np.ndarray:cams = []for transform in self.tta_transforms:augmented_tensor = transform.augment_image(input_tensor)cam = self.forward(augmented_tensor, targets, eigen_smooth)# The ttach library expects a tensor of size BxCxHxWcam = cam[:, None, :, :]cam = torch.from_numpy(cam)cam = transform.deaugment_mask(cam)# Back to numpy float32, HxWcam = cam.numpy()cam = cam[:, 0, :, :]cams.append(cam)cam = np.mean(np.float32(cams), axis=0)return cam

功能:

使用测试时增强 (TTA) 来平滑和稳定 CAM 热力图。这是 forward 的增强版本。

流程:

遍历 self.tta_transforms 中的每一个增强变换(如水平翻转)。
对 input_tensor 应用该变换得到 augmented_tensor。
调用 self.forward(augmented_tensor, targets, eigen_smooth) 生成增强后图像的 CAM 热力图 cam。
将 cam 转换为张量并添加一个维度 [:, None, :, :](符合 ttach 库要求)。
使用 transform.deaugment_mask(cam) 将生成的热力图“反变换”回原始图像的坐标系。
将反变换后的热力图转换回 NumPy 数组,并移除添加的维度。
将所有反变换后的热力图收集到 cams 列表中。
对 cams 列表中的所有热力图在批次维度上取平均值,得到最终平滑的热力图。

输出:

一个 (B, H, W) 的 np.ndarray,表示经过 TTA 平滑后的最终热力图。

优点:

能有效减少热力图的噪声,使其更稳定、更鲁棒。

9. 调用接口 (call):提供一个简便的方法来调用 forward 或者 forward_augmentation_smoothing 方法,取决于是否启用了平滑选项。

    def __call__(self,input_tensor: torch.Tensor,targets: List[torch.nn.Module] = None,aug_smooth: bool = False,eigen_smooth: bool = False,) -> np.ndarray:# Smooth the CAM result with test time augmentationif aug_smooth is True:return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)return self.forward(input_tensor, targets, eigen_smooth)

功能:

BaseCAM 类的主调用接口。用户通常通过 cam(input_tensor, targets) 来使用它。

逻辑:

如果 aug_smooth=True,则调用 forward_augmentation_smoothing。
否则,调用 forward。

参数:

aug_smooth: 是否启用 TTA 平滑。
eigen_smooth: 是否在 get_cam_image 中使用 SVD 降维(Eigen-CAM 方式)。

10. del(self) 和 enter, exit

    def __del__(self):self.activations_and_grads.release()def __enter__(self):return selfdef __exit__(self, exc_type, exc_value, exc_tb):self.activations_and_grads.release()if isinstance(exc_value, IndexError):# Handle IndexError here...print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")return True

del: 析构函数。当 BaseCAM 对象被销毁时,调用 self.activations_and_grads.release() 释放 Hook 资源,避免内存泄漏。
enter, exit: 实现上下文管理器协议。允许使用 with BaseCAM(…) as cam: 语法。exit 会确保在退出 with 块时调用 release(),并捕获 IndexError 异常(尽管只是打印警告并返回 True 表示已处理,这可能不是最佳实践)。

总结

BaseCAM 是一个高度模块化和可扩展的基类:

  1. 核心流程:forward 定义了标准流程(前向->反向->生成各层热力图->聚合)。
  2. 算法核心:get_cam_weights 是抽象方法,子类通过重写它来实现不同的 CAM 算法。
  3. 灵活性:
    支持多个目标层,并通过 aggregate_multi_layers 聚合。
    支持 reshape_transform 以适应非标准模型(如 ViT)。
    支持 TTA 平滑 (forward_augmentation_smoothing)。
    支持 Eigen-CAM 的 SVD 降维 (eigen_smooth)。
  4. 资源管理:通过 del 和上下文管理器确保 Hook 资源被正确释放。
  5. 用户接口:call 提供了简洁的调用方式。

这个设计使得开发者可以轻松地通过继承 BaseCAM 并实现 get_cam_weights 方法来创建新的 CAM 变体,同时复用其强大的前处理、后处理和资源管理功能。

http://www.dtcms.com/a/352163.html

相关文章:

  • 使用 Docker、Jenkins、Harbor 和 GitLab 构建 CI/CD 流水线
  • Unity:游戏性能优化!之把分散在各个游戏角色GameObject上的脚本修改为在一个脚本中运行。这样做会让游戏运行更高效?
  • Caddy + CoreDNS 深度解析:从功能架构到性能优化实践(下)
  • 【BurpSuite 插件开发】实战篇(十六-终章)性能优化实践:线程管理到正则匹配的全方位提升
  • Python爬虫实战:研究开源的高性能代理池,构建电商数据采集和分析系统
  • STM32物联网项目---ESP8266微信小程序结合OneNET平台MQTT实现STM32单片机远程智能控制---云平台篇(一)
  • 深度学习——神经网络(PyTorch 实现 MNIST 手写数字识别案例)
  • 数据集数量与神经网络参数关系分析
  • Vibe 编程:下一代开发者范式的深度解析
  • 扩展现有的多模块 Starter
  • 2025本地部署overleaf
  • 售价3499美元,英伟达Jetson Thor实现机器人与物理世界的实时智能交互
  • 09-SpringBoot入门案例
  • 嵌入式学习笔记-LINUX系统编程阶段-DAY01脚本
  • 第四章:条件判断
  • VueFlow画布可视化——js技能提升
  • 安全测试、web探测、httpx
  • vue2和vue3的对比
  • Android 属性系统
  • 蓝思科技中报:深耕业务增量,AI硬件打开想象空间
  • Pandas vs Polars Excel 数据加载对比报告
  • Coze Studio系统架构深度剖析:从分层设计到领域驱动的技术实践- 第二篇
  • vue实现拖拉拽效果,类似于禅道首页可拖拽排布展示内容(插件-Grid Layout)
  • 用 Allure 生成 pytest 测试报告:从安装到使用全流程
  • STM32 定时器(互补输出+刹车)
  • yggjs_rbutton React按钮组件v1.0.0 多主题系统使用指南
  • 什么叫API对接HR系统?
  • 2025年8月技术问答第3期
  • 03MySQL——DCL权限控制,四种常用函数解析
  • SSM入门到实战: 3.6 SpringMVC RESTful API开发