当前位置: 首页 > news >正文

DAY 42 Grad-CAM与Hook函数-2025.10.6

Grad-CAM与Hook函数

知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例

笔记:

1. 回调函数(Callback Function)

定义
回调函数是在特定事件触发时自动执行的函数,通常作为参数传递给另一个函数 / 框架,用于在流程中插入自定义逻辑(如日志记录、模型保存、早停等),核心是 “解耦主流程与自定义操作”。
核心作用

  • 不修改主流程代码,灵活扩展功能(如训练过程中记录日志、保存最优模型);
  • 响应特定事件(如 “训练 epoch 结束”“验证损失下降”“异常报错”)。

典型场景与代码示例(以 PyTorch 训练为例)
PyTorch 无内置 Callback 类,需自定义实现,常见于训练循环中触发:

import torch
import torch.nn as nn
from torch.optim import SGD
from torchvision.models import resnet18# 1. 定义回调函数:保存验证损失最小的模型
def save_best_model(model, val_loss, best_loss, save_path):if val_loss < best_loss:best_loss = val_losstorch.save(model.state_dict(), save_path)print(f"✅ 保存最优模型(验证损失:{val_loss:.4f})")return best_loss# 2. 定义回调函数:打印训练日志
def print_train_log(epoch, train_loss, val_loss, train_acc, val_acc):print(f"Epoch {epoch:3d} | 训练损失:{train_loss:.4f} | 验证损失:{val_loss:.4f}")print(f"          | 训练准确率:{train_acc:.2f}% | 验证准确率:{val_acc:.2f}%")# 3. 主训练流程(触发回调函数)
def train(model, train_loader, val_loader, epochs, lr=0.01):criterion = nn.CrossEntropyLoss()optimizer = SGD(model.parameters(), lr=lr)best_loss = float('inf')  # 初始化最优损失for epoch in range(1, epochs+1):# 训练阶段(省略细节,仅模拟损失/准确率)train_loss, train_acc = 0.5 - epoch*0.01, 60 + epoch*2  # 模拟下降/上升# 验证阶段val_loss, val_acc = 0.6 - epoch*0.008, 58 + epoch*1.8    # 模拟下降/上升# 触发回调函数1:打印日志print_train_log(epoch, train_loss, val_loss, train_acc, val_acc)# 触发回调函数2:保存最优模型best_loss = save_best_model(model, val_loss, best_loss, "best_model.pth")# 测试
model = resnet18(pretrained=False, num_classes=10)
train(model, train_loader=None, val_loader=None, epochs=5)  # 简化示例,实际需传入真实DataLoader

2. Lambda 函数(匿名函数)

定义
Lambda 函数是Python 中的匿名函数,通过lambda 参数: 表达式定义,仅能包含一行表达式,返回表达式结果,适合处理简单逻辑(无需单独定义函数)。

核心特点

  • 匿名性:无需指定函数名,仅用于临时调用;
  • 简洁性:一行代码完成简单运算(如加减乘除、排序 key);
  • 局限性:仅支持单个表达式,无法包含循环、条件判断(复杂逻辑需用def定义函数)。

典型场景与代码示例
场景 1:作为高阶函数参数(如sortedmap

# 示例1:sorted排序的key(按元组第二个元素排序)
data = [(1, 3), (4, 1), (2, 5)]
sorted_data = sorted(data, key=lambda x: x[1])  # key为lambda函数,提取第二个元素
print(sorted_data)  # 输出:[(4, 1), (1, 3), (2, 5)]# 示例2:map映射(对列表元素平方)
nums = [1, 2, 3, 4]
squared_nums = list(map(lambda x: x**2, nums))  # lambda函数实现平方运算
print(squared_nums)  # 输出:[1, 4, 9, 16]

场景 2:临时定义简单逻辑(如字典值过滤)

# 过滤字典中值大于10的键值对
scores = {"Alice": 8, "Bob": 12, "Charlie": 15}
high_scores = {k: v for k, v in scores.items() if v > 10}  # 结合字典推导式
# 或用filter(需转换为字典)
high_scores_filtered = dict(filter(lambda item: item[1] > 10, scores.items()))
print(high_scores_filtered)  # 输出:{'Bob': 12, 'Charlie': 15}

3. Hook 函数(模块钩子与张量钩子)

Hook 函数是PyTorch 中用于 “拦截” 模型中间结果(特征图、梯度)的工具,无需修改模型结构即可获取 / 修改中间变量,核心用于调试、可视化(如特征图提取)、梯度分析。
分为两类:模块钩子(Module Hook) 和 张量钩子(Tensor Hook)

3.1 模块钩子(Module Hook)

作用于nn.Module(如卷积层、全连接层),用于获取模块的输入、输出(前向钩子)或梯度(反向钩子)。

常用类型

钩子类型函数定义作用
前向钩子(Forward Hook)register_forward_hook(hook_fn)获取模块的输入张量、输出张量
反向钩子(Backward Hook)register_backward_hook(hook_fn)获取模块的输入梯度、输出梯度

代码示例:用前向钩子提取卷积层特征图

import torch
import torch.nn as nn
import matplotlib.pyplot as plt# 1. 定义简单CNN模型
class SimpleCNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, 3, padding=1)  # 卷积层1self.relu = nn.ReLU()self.conv2 = nn.Conv2d(16, 32, 3, padding=1) # 卷积层2self.fc = nn.Linear(32*8*8, 10)  # 假设输入图像28x28(实际需匹配)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.conv2(x)x = self.relu(x)x = x.view(x.size(0), -1)x = self.fc(x)return x# 2. 定义前向钩子函数:保存conv2的输出(特征图)
def get_conv2_output(module, input_tensor, output_tensor):"""module: 当前模块(此处为conv2)input_tensor: 模块输入(tuple类型,含一个张量)output_tensor: 模块输出(张量)"""global conv2_features  # 全局变量存储特征图conv2_features = output_tensor.detach()  #  detach()避免梯度追踪# 3. 注册钩子并测试
model = SimpleCNN()
# 给conv2层注册前向钩子
hook_handle = model.conv2.register_forward_hook(get_conv2_output)# 模拟输入(batch_size=1, 3通道, 28x28图像)
x = torch.randn(1, 3, 28, 28)
model(x)  # 前向传播,触发钩子,保存conv2_features# 可视化特征图(取前8个通道)
plt.figure(figsize=(12, 2))
for i in range(8):plt.subplot(1, 8, i+1)plt.imshow(conv2_features[0, i].numpy(), cmap='gray')  # 第0个样本,第i个通道plt.axis('off')
plt.title("Conv2层特征图(前8通道)")
plt.show()# 重要:移除钩子(避免内存泄漏)
hook_handle.remove()

3.2 张量钩子(Tensor Hook)

作用于张量(Tensor),用于获取 / 修改张量的梯度(仅反向传播时触发),常见于梯度裁剪、梯度分析。
常用类型

  • register_hook(hook_fn):为张量注册钩子,hook_fn接收张量的梯度作为参数,可返回修改后的梯度。

代码示例:用张量钩子打印梯度

import torch# 1. 定义张量并注册钩子
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # 需要计算梯度# 钩子函数:打印张量的梯度
def print_gradient(grad):print(f"✅ 张量梯度:{grad}")# 注册钩子
grad_hook = x.register_hook(print_gradient)# 2. 反向传播(触发钩子)
y = x.sum()  # 简单运算:求和
y.backward()  # 反向传播,计算x的梯度# 3. 移除钩子
grad_hook.remove()

输出
✅ 张量梯度:tensor([1., 1., 1.])(sum 函数对每个元素的梯度均为 1)

4. Grad-CAM(Gradient-weighted Class Activation Mapping)

定义
Grad-CAM 是模型可视化技术,通过计算目标类对最后一个卷积层特征图的梯度,加权得到 “类激活图(CAM)”,从而定位模型关注的图像区域(解释模型为何预测该类)。
核心原理

  1. 前向传播:获取最后一个卷积层的特征图 A(形状:[C, H, W],C 为通道数);
  2. 反向传播:计算目标类对特征图的梯度 dy/dA(形状:[C, H, W]);
  3. 计算通道权重:对每个通道的梯度全局平均池化(GAP),得到权重 α_c = mean(dy/dA_c)(形状:[C]);
  4. 生成 CAM:特征图按权重加权求和,再通过 ReLU(仅保留正贡献区域),得到 CAM = ReLU(Σ(α_c * A_c))
  5. 可视化:将 CAM 上采样至原图尺寸,与原图叠加,展示模型关注区域。

代码示例(基于 PyTorch+ResNet18)

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np# 1. 加载预训练模型(ResNet18)和预处理
model = models.resnet18(pretrained=True)
model.eval()  # 评估模式,禁用Dropout/BatchNorm随机化# 图像预处理(匹配ResNet输入要求)
preprocess = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 2. 定义Grad-CAM类(获取最后一个卷积层特征图和梯度)
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layer  # 最后一个卷积层(如resnet18的layer4)self.features = None  # 存储特征图self.grads = None     # 存储梯度# 注册前向钩子:获取特征图self.forward_hook = target_layer.register_forward_hook(self.save_features)# 注册反向钩子:获取梯度self.backward_hook = target_layer.register_backward_hook(self.save_grads)def save_features(self, module, input, output):self.features = output.detach()  # 特征图:[B, C, H, W]def save_grads(self, module, grad_input, grad_output):self.grads = grad_output[0].detach()  # 梯度:[B, C, H, W]def get_cam(self, target_class=None):# 1. 计算通道权重:梯度全局平均池化alpha = self.grads.mean(dim=(2, 3), keepdim=True)  # [B, C, 1, 1]# 2. 特征图加权求和cam = (alpha * self.features).sum(dim=1, keepdim=True)  # [B, 1, H, W]# 3. ReLU:仅保留正贡献cam = torch.relu(cam)# 4. 归一化到[0,1]cam = (cam - cam.min()) / (cam.max() - cam.min())return camdef remove_hooks(self):self.forward_hook.remove()self.backward_hook.remove()# 3. 加载图像并预处理
img_path = "dog.jpg"  # 替换为你的图像路径
img = Image.open(img_path).convert("RGB")
img_tensor = preprocess(img).unsqueeze(0)  # [1, 3, 224, 224]# 4. 初始化Grad-CAM(目标层为resnet18的layer4)
grad_cam = GradCAM(model, model.layer4)# 5. 前向传播:获取预测类别
output = model(img_tensor)
pred_class = output.argmax(dim=1).item()  # 预测类别(默认取概率最大类)
print(f"预测类别:{pred_class}(可替换为自定义目标类)")# 6. 反向传播:计算梯度(针对预测类别)
model.zero_grad()
output[:, pred_class].backward()  # 仅对预测类别的输出求导# 7. 生成CAM并可视化
cam = grad_cam.get_cam(target_class=pred_class)  # [1, 1, 7, 7](resnet18 layer4输出尺寸7x7)
grad_cam.remove_hooks()  # 移除钩子# 8. CAM上采样至原图尺寸(224x224)
from torch.nn.functional import interpolate
cam_up = interpolate(cam, size=(224, 224), mode='bilinear', align_corners=False)
cam_up = cam_up.squeeze().numpy()  # [224, 224]# 9. 原图与CAM叠加显示
img_np = np.array(img.resize((224, 224))) / 255.0  # 原图归一化
cam_heatmap = plt.cm.jet(cam_up)[:, :, :3]  # CAM转为热力图
overlay = 0.6 * img_np + 0.4 * cam_heatmap  # 叠加(原图60% + 热力图40%)# 绘制结果
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(img_np)
plt.title("原图")
plt.axis('off')plt.subplot(1, 3, 2)
plt.imshow(cam_up, cmap='jet')
plt.title("Grad-CAM(7x7→224x224)")
plt.axis('off')plt.subplot(1, 3, 3)
plt.imshow(overlay)
plt.title(f"叠加结果(预测类:{pred_class})")
plt.axis('off')plt.show()

关键说明

  • 目标层选择:需选择最后一个卷积层(如 ResNet 的layer4、VGG 的features[-1]),该层特征图既保留空间信息,又包含高层语义;

  • 自定义目标类:若需分析特定类别(非预测类),可将output[:, pred_class].backward()改为output[:, target_class].backward()(如target_class=100);

  • 可视化效果:热力图红色区域为模型关注的关键区域(如狗的头部、身体),解释模型预测依据。

@浙大疏锦行

http://www.dtcms.com/a/450092.html

相关文章:

  • 绵阳网站建设培训学校隐私空间
  • 淮安网站建设做北京电梯招标的网站
  • 专业企业网站建设定制百度如何做网站
  • Net-Tools工具包详解:Linux网络管理经典工具集
  • 极路由做网站无锡网站推广公司排名
  • registrateAPI——非空函数
  • 环境设计案例网站基于html5动画的网站
  • CCF编程能力等级认证GESP—C++4级—20250927
  • 网站收录率怎样建立自己网站多少钱
  • 电商平台网站设计公司企业建站搭建
  • 【数据结构】链栈的基本操作
  • 实战分享:股票数据API接口在量化分析中的应用与体验
  • 个人建设网站还要备案么wordpress建站详细教程视频
  • Vue2 和 Vue3 View
  • 乐趣做网站厦门做网站的公司
  • 使用jmeter做压力测试
  • [工作流节点15] 推送消息节点在企业内部通知中的应用实践
  • 热转印 东莞网站建设ui界面设计英文
  • 【数据结构学习篇】--树
  • Linux中驱动程序通过fasync异步通知应用程序的实现
  • MySQL索引优化:让查询快如闪电
  • 什么是营销型网站呢什么网站做新产品代理
  • 海沧建设网站多少jetpack报错 wordpress
  • 从零起步学习Redis || 第九章:缓存雪崩,缓存击穿,缓存穿透三大问题的成因及实战解决方案
  • 手机网站 微信链接网站建设工具
  • 网站建设年度总结客源通app下载
  • 欧美做暧网站jsp可以做网站吗
  • Variational Quantum Eigensolver笔记
  • 操作系统应用开发(二十四)RustDesk 404错误—东方仙盟筑基期
  • 网站菜单样式关于网站策划的文章