当前位置: 首页 > news >正文

钩子函数的作用(register_hook)

钩子函数仅在backward()时才会触发。其中,钩子函数接受梯度作为输入,返回操作后的梯度,操作后的梯度必须要输入的梯度同类型、同形状,否则报错。

主要功能包括:

  • 监控当前的梯度(不返回值);
  • 对当前的梯度进行操作,返回新的梯度以覆盖原梯度;
  • 在模型中对梯度进行监控或者修改。

案例 1:监控梯度值

import torch# 创建一个张量,并启用梯度追踪
x = torch.tensor([1.0], requires_grad=True)
y = x * 2# 定义钩子函数
def hook_fn(grad):'''作用:打印梯度'''print("Hook triggered, gradient:", grad)# 注册钩子:将钩子函数注册到x上,反向传播计算x梯度时自动触发钩子函数
x.register_hook(hook_fn)# 触发反向传播和钩子函数
y.backward()             

结果:

Hook triggered, gradient: tensor([2.])

案例 2:修改梯度值

import torch# 创建一个张量,并启用梯度追踪
x = torch.tensor([1.0], requires_grad=True)
y = x * 2# 定义钩子函数
def hook_fn(grad):'''作用:修改输入的梯度'''print('原梯度:',grad)return grad * 3# 注册钩子:将钩子函数注册到x上,反向传播计算x梯度时自动触发钩子函数
x.register_hook(hook_fn)# 触发反向传播和钩子函数
y.backward()          print("修改后的梯度:", x.grad)            

结果:

原梯度: tensor([2.])
修改后的梯度: tensor([6.])

案例 3:在模型中使用 register_hook

import torch
import torch.nn as nnmodel = nn.Linear(1, 1)
weight = model.weight # 模型权重# 定义钩子函数
def hook_fn(grad):'''作用:打印梯度'''print("Gradient of weight:", grad)# 注册钩子:将钩子函数注册到weight上,反向传播计算weight梯度时自动触发钩子函数
weight.register_hook(hook_fn)# 输入数据
x = torch.tensor([[1.0]])
target = torch.tensor([[3.0]])# 前向传播
output = model(x)
print(output)# 损失函数
loss = (output - target).pow(2)# 触发反向传播和钩子函数
loss.backward()           

结果:

Gradient of weight: tensor([[-6.1532]])

注意:
在实际使用中,必须使用clone()来确保梯度操作的安全性和计算图完整性,例如:

def hook_fn(grad):return grad.clone() * 3
  • 通过 grad.clone() 创建梯度副本后进行操作,所有修改仅作用于副本,不会触碰原始梯度存储。不采用克隆,直接对原始梯度进行操作,PyTorch 会检测到对计算图中张量的潜在原地修改(in-place operation),并抛出异常。
  • 不采用克隆,会破坏计算图路径,导致梯度回传中断或错误。

相关文章:

  • 2025-05-28 Python深度学习8——优化器
  • 破能所,入不二
  • GNU AS汇编器的.align对齐
  • 端午节互动网站
  • 力扣 215 .数组中的第K个最大元素
  • AMBA-AHB总线是怎么不依赖三态总线的?
  • 11.14 LangGraph检查点系统实战:AI Agent会话恢复率提升287%的企业级方案
  • 【网络编程】十八、Reactor模式
  • 2025年05月28日Github流行趋势
  • 农业光合参数反演专栏
  • kubernate解决 “cni0“ already has an IP address different from 10.244.0.1/24问题
  • Caddy如何在测试环境中使用IP地址配置HTTPS服务
  • bug: uniCloud 查询数组字段失败
  • HTTP Accept简介
  • linux系统(centos7为例)将jar配置成服务操作教程
  • 浏览器之禁止打开控制台【F12】
  • 网页前端开发(基础进阶1)
  • Transformer核心技术解析LCPO方法:精准控制推理长度的新突破
  • 计算机内存管理全解析:从基础原理到前沿技术(含分页/分段/置换算法/大页/NVM/CXL等技术详解
  • LVS的DR模式部署
  • 实时定量引物设计网站怎么做/seo推广是什么意思呢
  • redis做缓存的网站并发数/广告投放的方式有哪些
  • wordpress建立多站点/国际新闻直播
  • 西安西郊网站建设/百度题库
  • 济南网站建设公司制作/免费发广告的网站
  • 办公室装修设计费标准/百度seo排名优化公司