当前位置: 首页 > news >正文

pytorch_grad_cam 库学习笔记——基类ActivationsAndGradient

pytorch_grad_cam 是一个包含用于计算机视觉的可解释 AI 的最先进方法的软件包。 这可用于诊断模型预测,无论是在生产中还是在 开发模型。 其目的还在于作为研究新可解释性方法的算法和指标的基准。
pytorch_grad_cam 官方源码 https://github.com/jacobgil/pytorch-grad-cam
pytorch_grad_cam 官方教程 https://jacobgil.github.io/pytorch-gradcam-book/introduction.html
在./pytorch-grad-cam/pytorch_grad_cam/activations_and_gradients.py里定义了名为ActivationsAndGradients的基类,是实现类激活映射(CAM)算法的核心组件之一,其主要功能是利用 PyTorch 的 Hook 机制,在模型的前向和反向传播过程中,捕获指定目标层(target_layers)的激活值(activations)和梯度(gradients)。
本篇文章主要在这里对基类ActivationsAndGradients进行逐步分析,以理解库函数原理。

ActivationsAndGradients 类

ActivationsAndGradients 类是实现类激活映射(CAM)算法的核心组件之一,其主要功能是利用 PyTorch 的 Hook 机制,在模型的前向和反向传播过程中,捕获指定目标层(target_layers)的激活值(activations)和梯度(gradients)。
以下是该类的详细解析:

1. init(self, model, target_layers, reshape_transform, detach=True)

class ActivationsAndGradients:""" Class for extracting activations andregistering gradients from targetted intermediate layers """def __init__(self, model, target_layers, reshape_transform, detach=True):self.model = modelself.gradients = []self.activations = []self.reshape_transform = reshape_transformself.detach = detachself.handles = []for target_layer in target_layers:self.handles.append(target_layer.register_forward_hook(self.save_activation))# Because of https://github.com/pytorch/pytorch/issues/61519,# we don't use backward hook to record gradients.self.handles.append(target_layer.register_forward_hook(self.save_gradient))

功能:初始化 ActivationsAndGradients 实例。

参数:

  • model: 要分析的 PyTorch 模型。
  • target_layers: 一个包含目标 torch.nn.Module 层的列表(如 model.layer4 或 model.features)。这些层的激活和梯度将被捕获。
  • reshape_transform: 一个可选的函数,用于重塑激活和梯度。这在处理某些模型(如 Vision Transformers)时非常关键,因为它们的特征图形状(如 [batch, num_patches, features])与标准卷积网络的 [batch, channels, height, width] 不同,需要转换以便后续处理。
  • detach: 布尔值。如果为 True,则在捕获后将激活和梯度从计算图中分离(detach())并转移到 CPU。这可以节省 GPU 内存,并防止意外的梯度累积。如果为 False,则保留原始的张量(在 GPU 上且与计算图相连)。

关键操作:

  1. 初始化存储列表:self.gradients = [] 和 self.activations = []。
  2. 初始化 self.handles = [] 以存储注册的 Hook 句柄,便于后续移除。
  3. 注册 Hook:
  • 为 target_layers 中的每一个层,调用 register_forward_hook(self.save_activation)。这会在该层的前向传播结束后,自动调用 save_activation 方法,捕获其输出(即激活值)。
  • 为同一个层,再次调用 register_forward_hook(self.save_gradient)。注意:这里没有使用 register_backward_hook。注释中提到了一个 PyTorch 的 issue (#61519),暗示使用后向 Hook 可能存在问题。因此,这里采用了另一种方法:在前向 Hook 中,为层的输出张量 output 注册一个梯度 Hook (output.register_hook(_store_grad))。当反向传播到达该 output 张量时,_store_grad 函数就会被调用。

2. save_activation(self, module, input, output)

    def save_activation(self, module, input, output):activation = outputif self.detach:if self.reshape_transform is not None:activation = self.reshape_transform(activation)self.activations.append(activation.cpu().detach())else:self.activations.append(activation)

功能:

前向 Hook 的回调函数,用于捕获目标层的激活值。

参数(由 PyTorch 自动提供):

module: 调用此 Hook 的层(即 target_layer)。
input: 该层的输入张量(通常用不到)。
output: 该层的输出张量(即激活值)。

流程:

  • 将 output 赋值给 activation。
  • 如果提供了 reshape_transform 函数,则对 activation 进行重塑。
  • 根据 detach 参数决定如何存储:
    如果 detach=True:将 activation 移动到 CPU 并从计算图中分离,然后添加到 self.activations 列表。
    否则:直接将原始的 activation 添加到列表。

结果:

每次前向传播后,self.activations 列表会按顺序存储所有 target_layers 的激活值。

3. save_gradient(self, module, input, output)

    def save_gradient(self, module, input, output):if not hasattr(output, "requires_grad") or not output.requires_grad:# You can only register hooks on tensor requires grad.return# Gradients are computed in reverse orderdef _store_grad(grad):if self.detach:if self.reshape_transform is not None:grad = self.reshape_transform(grad)self.gradients = [grad.cpu().detach()] + self.gradientselse:self.gradients = [grad] + self.gradientsoutput.register_hook(_store_grad)

功能:前向 Hook 的回调函数,用于为目标层的输出张量注册一个梯度 Hook。这个梯度 Hook 会在反向传播时被触发。

参数:同 save_activation。

流程:

  1. 检查 output 是否需要梯度 (requires_grad)。如果不需要(例如,某些层的输出是整数或布尔值),则直接返回,不注册 Hook。
  2. 定义一个内部函数 _store_grad(grad):
  • 这个函数是真正的梯度 Hook,它接收反向传播计算出的梯度 grad 作为输入。
  • 如果提供了 reshape_transform,则对 grad 进行重塑。
  • 根据 detach 参数决定如何存储:
    ** 如果 detach=True:将 grad 移动到 CPU 并分离,然后插入到 * * self.gradients 列表的开头 ([grad.cpu().detach()] + self.gradients)。
    ** 否则:直接将 grad 插入到列表开头。
  • 为什么插入开头? 因为反向传播是从后往前进行的。最后层的梯度先计算,最先被捕获。为了保持 self.gradients 列表的顺序与 self.activations 和 target_layers 的顺序一致,需要将新捕获的梯度放在列表前面。
  1. 调用 output.register_hook(_store_grad),将 _store_grad 函数注册为 output 张量的梯度 Hook。

结果:在反向传播过程中,每当计算到某个 target_layer 的输出梯度时,_store_grad 就会被调用,该梯度被处理后按正确的顺序存储在 self.gradients 列表中。

4. call(self, x)

    def __call__(self, x):self.gradients = []self.activations = []return self.model(x)

功能:使 ActivationsAndGradients 对象可以像函数一样被调用。

流程:

  1. 在每次调用前,清空self.gradients 和 self.activations 列表。这是非常重要的,确保了每次调用捕获的都是本次前向/反向传播的数据,不会与之前的结果混合。
  2. 调用 self.model(x) 执行模型的前向传播。在此过程中,所有注册的 Hook 都会被触发,save_activation 会捕获激活值,save_gradient 会为输出张量注册梯度 Hook。

返回值:模型的前向输出(self.model(x) 的结果)。

副作用:self.activations 和 self.gradients 列表被填充。

5. release(self)

    def release(self):for handle in self.handles:handle.remove()

功能:

移除所有已注册的 Hook。

流程:

遍历 self.handles 列表,调用每个 handle.remove()。

重要性:

这是资源管理的关键步骤。如果不移除 Hook,它们会一直存在于模型中,导致:

  1. 内存泄漏:捕获的激活和梯度会持续累积。
  2. 性能下降:每次前向/反向传播都会执行不必要的 Hook 函数。
  3. 错误:可能干扰模型的其他操作。

调用时机:通常在 BaseCAM 的 delexit 方法中调用。

总结

ActivationsAndGradients 类巧妙地利用了 PyTorch 的 Hook 机制:

  1. 捕获激活:通过 register_forward_hook 在前向传播后直接捕获目标层的输出。
  2. 捕获梯度:通过在前向 Hook 中为输出张量注册 register_hook,在反向传播时捕获其梯度,并通过将新梯度插入列表开头来保证顺序正确。
  3. 灵活性:支持 reshape_transform 以适应不同模型架构。
  4. 内存管理:通过 detach 选项控制是否保留计算图,并通过 release 方法确保 Hook 被正确移除,防止资源泄漏。
  5. 易用性:提供 call 接口,使得用户只需调用一次即可完成前向传播并自动捕获所需数据。

这个类是 BaseCAM 及其所有子类能够工作的基石,它透明地拦截了模型内部的计算过程,为 CAM 算法提供了必需的中间数据。

http://www.dtcms.com/a/351731.html

相关文章:

  • vue2 和 vue3 生命周期的区别
  • 【Android】不同系统API版本_如何进行兼容性配置
  • 2014-2024高教社杯全国大学生数学建模竞赛赛题汇总预览分析
  • VMDK 文件
  • 软考-系统架构设计师 计算机系统基础知识详细讲解二
  • springcloud篇5-微服务保护(Sentinel)
  • Spring Boot mybatis-plus 多数据源配置
  • 【CVE-2025-5419】(内附EXP) Google Chrome 越界读写漏洞【内附EXP】
  • Kafka面试精讲 Day 1:Kafka核心概念与分布式架构
  • Elasticsearch中的协调节点
  • 详解kafka基础(一)
  • JavaScript常用的算法详解
  • Cherry-pick冲突与Git回滚
  • Oracle跟踪及分析方法
  • 力扣100+补充大完结
  • MySql 事务 锁
  • 推荐系统学习笔记(十四)-粗排三塔模型
  • 庖丁解牛:深入解析Oracle SQL语言的四大分类——DML、DDL、DCL、TCL
  • KubeBlocks for Oracle 容器化之路
  • 高校党建系统设计与实现(代码+数据库+LW)
  • 从零开始的 Docker 之旅
  • HIVE的高频面试UDTF函数
  • 【软考论文】论面向对象建模方法(动态、静态)
  • 无人机倾斜摄影农田航线规划
  • HTML应用指南:利用GET请求获取中国银行人民币存款利率数据
  • SciPy科学计算与应用:SciPy线性代数模块入门-矩阵运算与应用
  • 精确位置定位,AR交互助力高效作业流程​
  • 余承东:鸿蒙智行累计交付突破90万辆
  • 机器人视频感知架构深度解析:7条技术法则,打造低延迟实时感知与交互
  • 【ROS2】 忽略局域网多机通信导致数据接收的bug