当前位置: 首页 > news >正文

python打卡day42

Grad-CAM与Hook函数

知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例

在深度学习中,我们经常需要查看或修改模型中间层的输出或梯度,但标准的前向传播和反向传播过程通常是一个黑盒,很难直接访问中间层的信息。PyTorch 提供了一种强大的工具——hook 函数,它允许我们在不修改模型结构的情况下,获取或修改中间层的信息。常用场景如下:

  1. 调试与可视化中间层输出
  2. 特征提取:如在图像分类模型中提取高层语义特征用于下游任务
  3. 梯度分析与修改: 在训练过程中,对某些层进行梯度裁剪或缩放,以改变模型训练的动态
  4. 模型压缩:在推理阶段对特定层的输出应用掩码(如剪枝后的模型权重掩码),实现轻量化推理

1、回调函数

Hook本质是回调函数,所以我们先介绍一下回调函数。回调函数是作为参数传递给其他函数的函数,其目的是在某个特定事件发生时被调用执行。这种机制允许代码在运行时动态指定需要执行的逻辑,其中回调函数作为参数传入,所以在定义的时候一般用callback来命名

在 PyTorch 的 Hook API 中,回调参数通常命名为 hook,PyTorch 的 Hook 机制基于其动态计算图系统:

  1. 当你注册一个 Hook 时,PyTorch 会在计算图的特定节点(如模块或张量)上添加一个回调函数
  2. 当计算图执行到该节点时(前向或反向传播),自动触发对应的 Hook 函数
  3. Hook 函数可以访问或修改流经该节点的数据(如输入、输出或梯度)

2、lambda函数

在hook中常常用到lambda函数,它是一种匿名函数(没有正式名称的函数),最大特点是用完即弃,无需提前命名和定义。它的语法形式非常简约,仅需一行即可完成定义,格式:lambda 参数列表: 表达式

  • 参数列表:可以是单个参数、多个参数或无参数
  • 表达式:函数的返回值(无需 return 语句,表达式结果直接返回)

举个例子

# 定义匿名函数:计算平方
square = lambda x: x ** 2# 调用
print(square(5))  # 输出: 25

3、hook函数

PyTorch 提供了两种主要的 hook:

  • Module Hooks(模块钩子):用于监听整个模块的输入和输出
  • Tensor Hooks:用于监听张量的梯度

(1)模块钩子

允许我们在模块的输入或输出经过时进行监听。PyTorch 提供了两种模块钩子:

  1. register_forward_hook:在模块的前向传播完成后立即被调用,这个函数可以访问模块的输入和输出,但不能修改
  2. register_backward_hook:在反向传播过程中被调用的,可以用来获取或修改梯度信息

前向钩子举个例子

# 创建模型实例
model = SimpleModel()# 创建一个列表用于存储中间层的输出
conv_outputs = []# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):print(f"钩子被调用!模块类型: {type(module)}")print(f"输入形状: {input[0].shape}") #  input是一个元组,对应 (image, label)print(f"输出形状: {output.shape}")# 保存卷积层的输出用于后续分析# 使用detach()避免追踪梯度,防止内存泄漏conv_outputs.append(output.detach())# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)# 执行前向传播 - 此时会自动触发钩子函数
output = model(x)# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()

反向钩子

# 定义一个存储梯度的列表
conv_gradients = []# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):print(f"反向钩子被调用!模块类型: {type(module)}")print(f"输入梯度数量: {len(grad_input)}")print(f"输出梯度数量: {len(grad_output)}")# 保存梯度供后续分析conv_gradients.append((grad_input, grad_output))# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()# 释放钩子
hook_handle.remove()

(2)张量钩子

PyTorch 还提供了张量钩子,允许我们直接监听和修改张量的梯度。张量钩子有两种:

  1. register_hook:用于监听张量的梯度
  2. register_full_backward_hook:用于在完整的反向传播过程中监听张量的梯度
# 创建一个需要计算梯度的张量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3# 定义一个钩子函数,用于修改梯度
def tensor_hook(grad):print(f"原始梯度: {grad}")# 修改梯度,例如将梯度减半return grad / 2# 在y上注册钩子
hook_handle = y.register_hook(tensor_hook)# 计算梯度,梯度会从z反向传播经过y到x,此时调用钩子函数
z.backward()print(f"x的梯度: {x.grad}")# 释放钩子
hook_handle.remove()

4、Grad-CAM

一个可视化算法,通过梯度信息用热力图显示图片中哪些区域让CNN做出了某个分类决定(比如为什么认为这是“猫”),原理:

  • 梯度计算:看最后几层特征图的梯度,哪个特征图对预测“猫”的贡献大
  • 加权融合:把重要的特征图合并成一张热力图(重要区域更亮)
  • 叠加显示:把热力图盖在原图上,一眼看出猫的脸/耳朵等关键部位被高亮了
# Grad-CAM实现
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注册钩子,用于获取目标层的前向传播输出和反向传播梯度self.register_hooks()def register_hooks(self):# 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)def forward_hook(module, input, output):self.activations = output.detach()# 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()# 在目标层注册前向钩子和反向钩子self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):# 前向传播,得到模型输出model_output = self.model(input_image)if target_class is None:# 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别target_class = torch.argmax(model_output, dim=1).item()# 清除模型梯度,避免之前的梯度影响self.model.zero_grad()# 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)# 获取之前保存的目标层的梯度和激活值gradients = self.gradientsactivations = self.activations# 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果cam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响cam = F.relu(cam)# 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_classidx = 102  # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")# 转换图像以便可视化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.5, 0.5, 0.5])std = np.array([0.5, 0.5, 0.5])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可视化
plt.figure(figsize=(12, 4))# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()

@浙大疏锦行

相关文章:

  • Mask_RCNN 环境配置及训练
  • leetcode hot100 二叉树(一)
  • 第七部分:第四节 - 在 NestJS 应用中集成 MySQL (使用 TypeORM):结构化厨房的原材料管理系统
  • 剑指offer hot100 第三周
  • 查看make命令执行后涉及的预编译宏定义的值
  • java synchronized关键字用法
  • io流2——字节输入流,文件拷贝
  • Codeforces 1027 Div3(ABCDEF)
  • Java网络编程基础:从阻塞式I/O到线程池模型
  • DAY 34 超大力王爱学Python
  • C++ —— STL容器——string类
  • ps中通过拷贝的图层和通过剪切的图层
  • java多线程与JUC
  • ck-editor5的研究 (4):初步使用 CKEditor5 的插件功能
  • Cesium快速入门到精通系列教程三
  • 高速串行接口
  • Spring Boot 4.0实战:构建高并发电商系统
  • ArkTS基础
  • spining-lidar的电机和激光雷达体(lidar-imu)之间的标定
  • VMware-VMRC-12.0.1-18113358安装包下载安装与使用(附下载)
  • 海口做网站公司/站长统计app进入网址新版
  • 网页在线代理服务器/湖南seo推广多少钱
  • 简单的cms建站系统/外贸推广哪个公司好
  • 建设局发公告的网站/关联词有哪些
  • 免费做印章的网站/怎么制作一个简单的网页
  • 郑州网站建设外包业务/aso优化分析