python打卡day42
Grad-CAM与Hook函数
知识点回顾
- 回调函数
- lambda函数
- hook函数的模块钩子和张量钩子
- Grad-CAM的示例
在深度学习中,我们经常需要查看或修改模型中间层的输出或梯度,但标准的前向传播和反向传播过程通常是一个黑盒,很难直接访问中间层的信息。PyTorch 提供了一种强大的工具——hook 函数,它允许我们在不修改模型结构的情况下,获取或修改中间层的信息。常用场景如下:
- 调试与可视化中间层输出
- 特征提取:如在图像分类模型中提取高层语义特征用于下游任务
- 梯度分析与修改: 在训练过程中,对某些层进行梯度裁剪或缩放,以改变模型训练的动态
- 模型压缩:在推理阶段对特定层的输出应用掩码(如剪枝后的模型权重掩码),实现轻量化推理
1、回调函数
Hook本质是回调函数,所以我们先介绍一下回调函数。回调函数是作为参数传递给其他函数的函数,其目的是在某个特定事件发生时被调用执行。这种机制允许代码在运行时动态指定需要执行的逻辑,其中回调函数作为参数传入,所以在定义的时候一般用callback来命名
在 PyTorch 的 Hook API 中,回调参数通常命名为 hook,PyTorch 的 Hook 机制基于其动态计算图系统:
- 当你注册一个 Hook 时,PyTorch 会在计算图的特定节点(如模块或张量)上添加一个回调函数
- 当计算图执行到该节点时(前向或反向传播),自动触发对应的 Hook 函数
- Hook 函数可以访问或修改流经该节点的数据(如输入、输出或梯度)
2、lambda函数
在hook中常常用到lambda函数,它是一种匿名函数(没有正式名称的函数),最大特点是用完即弃,无需提前命名和定义。它的语法形式非常简约,仅需一行即可完成定义,格式:lambda 参数列表: 表达式
- 参数列表:可以是单个参数、多个参数或无参数
- 表达式:函数的返回值(无需 return 语句,表达式结果直接返回)
举个例子
# 定义匿名函数:计算平方
square = lambda x: x ** 2# 调用
print(square(5)) # 输出: 25
3、hook函数
PyTorch 提供了两种主要的 hook:
- Module Hooks(模块钩子):用于监听整个模块的输入和输出
- Tensor Hooks:用于监听张量的梯度
(1)模块钩子
允许我们在模块的输入或输出经过时进行监听。PyTorch 提供了两种模块钩子:
- register_forward_hook:在模块的前向传播完成后立即被调用,这个函数可以访问模块的输入和输出,但不能修改
- register_backward_hook:在反向传播过程中被调用的,可以用来获取或修改梯度信息
前向钩子举个例子
# 创建模型实例
model = SimpleModel()# 创建一个列表用于存储中间层的输出
conv_outputs = []# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):print(f"钩子被调用!模块类型: {type(module)}")print(f"输入形状: {input[0].shape}") # input是一个元组,对应 (image, label)print(f"输出形状: {output.shape}")# 保存卷积层的输出用于后续分析# 使用detach()避免追踪梯度,防止内存泄漏conv_outputs.append(output.detach())# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)# 执行前向传播 - 此时会自动触发钩子函数
output = model(x)# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()
反向钩子
# 定义一个存储梯度的列表
conv_gradients = []# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):print(f"反向钩子被调用!模块类型: {type(module)}")print(f"输入梯度数量: {len(grad_input)}")print(f"输出梯度数量: {len(grad_output)}")# 保存梯度供后续分析conv_gradients.append((grad_input, grad_output))# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()# 释放钩子
hook_handle.remove()
(2)张量钩子
PyTorch 还提供了张量钩子,允许我们直接监听和修改张量的梯度。张量钩子有两种:
- register_hook:用于监听张量的梯度
- register_full_backward_hook:用于在完整的反向传播过程中监听张量的梯度
# 创建一个需要计算梯度的张量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3# 定义一个钩子函数,用于修改梯度
def tensor_hook(grad):print(f"原始梯度: {grad}")# 修改梯度,例如将梯度减半return grad / 2# 在y上注册钩子
hook_handle = y.register_hook(tensor_hook)# 计算梯度,梯度会从z反向传播经过y到x,此时调用钩子函数
z.backward()print(f"x的梯度: {x.grad}")# 释放钩子
hook_handle.remove()
4、Grad-CAM
一个可视化算法,通过梯度信息用热力图显示图片中哪些区域让CNN做出了某个分类决定(比如为什么认为这是“猫”),原理:
- 梯度计算:看最后几层特征图的梯度,哪个特征图对预测“猫”的贡献大
- 加权融合:把重要的特征图合并成一张热力图(重要区域更亮)
- 叠加显示:把热力图盖在原图上,一眼看出猫的脸/耳朵等关键部位被高亮了
# Grad-CAM实现
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注册钩子,用于获取目标层的前向传播输出和反向传播梯度self.register_hooks()def register_hooks(self):# 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)def forward_hook(module, input, output):self.activations = output.detach()# 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()# 在目标层注册前向钩子和反向钩子self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):# 前向传播,得到模型输出model_output = self.model(input_image)if target_class is None:# 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别target_class = torch.argmax(model_output, dim=1).item()# 清除模型梯度,避免之前的梯度影响self.model.zero_grad()# 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)# 获取之前保存的目标层的梯度和激活值gradients = self.gradientsactivations = self.activations# 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果cam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响cam = F.relu(cam)# 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_classidx = 102 # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")# 转换图像以便可视化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.5, 0.5, 0.5])std = np.array([0.5, 0.5, 0.5])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可视化
plt.figure(figsize=(12, 4))# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()
@浙大疏锦行