钩子函数的作用(register_hook)
钩子函数仅在backward()
时才会触发。其中,钩子函数接受梯度作为输入,返回操作后的梯度,操作后的梯度必须要输入的梯度同类型、同形状,否则报错。
主要功能包括:
- 监控当前的梯度(不返回值);
- 对当前的梯度进行操作,返回新的梯度以覆盖原梯度;
- 在模型中对梯度进行监控或者修改。
案例 1:监控梯度值
import torch# 创建一个张量,并启用梯度追踪
x = torch.tensor([1.0], requires_grad=True)
y = x * 2# 定义钩子函数
def hook_fn(grad):'''作用:打印梯度'''print("Hook triggered, gradient:", grad)# 注册钩子:将钩子函数注册到x上,反向传播计算x梯度时自动触发钩子函数
x.register_hook(hook_fn)# 触发反向传播和钩子函数
y.backward()
结果:
Hook triggered, gradient: tensor([2.])
案例 2:修改梯度值
import torch# 创建一个张量,并启用梯度追踪
x = torch.tensor([1.0], requires_grad=True)
y = x * 2# 定义钩子函数
def hook_fn(grad):'''作用:修改输入的梯度'''print('原梯度:',grad)return grad * 3# 注册钩子:将钩子函数注册到x上,反向传播计算x梯度时自动触发钩子函数
x.register_hook(hook_fn)# 触发反向传播和钩子函数
y.backward() print("修改后的梯度:", x.grad)
结果:
原梯度: tensor([2.])
修改后的梯度: tensor([6.])
案例 3:在模型中使用 register_hook
import torch
import torch.nn as nnmodel = nn.Linear(1, 1)
weight = model.weight # 模型权重# 定义钩子函数
def hook_fn(grad):'''作用:打印梯度'''print("Gradient of weight:", grad)# 注册钩子:将钩子函数注册到weight上,反向传播计算weight梯度时自动触发钩子函数
weight.register_hook(hook_fn)# 输入数据
x = torch.tensor([[1.0]])
target = torch.tensor([[3.0]])# 前向传播
output = model(x)
print(output)# 损失函数
loss = (output - target).pow(2)# 触发反向传播和钩子函数
loss.backward()
结果:
Gradient of weight: tensor([[-6.1532]])
注意:
在实际使用中,必须使用clone()
来确保梯度操作的安全性和计算图完整性,例如:
def hook_fn(grad):return grad.clone() * 3
- 通过
grad.clone()
创建梯度副本后进行操作,所有修改仅作用于副本,不会触碰原始梯度存储。不采用克隆,直接对原始梯度进行操作,PyTorch 会检测到对计算图中张量的潜在原地修改(in-place operation),并抛出异常。 - 不采用克隆,会破坏计算图路径,导致梯度回传中断或错误。