@once_differentiable 自定义算子的用处
背景:在看大佬们pytorch 自定义算子时,发现在backward()用来once_differentiable修饰。
说明一下:
@once_differentiable是 PyTorch 中一个用于修饰自定义自动求导函数(Autograd Function)的 backward方法的装饰器。
from torch.autograd.function import once_differentiableclass GOF_Function(Function):@staticmethoddef forward(ctx, weight, gaborFilterBank):ctx.save_for_backward(weight, gaborFilterBank)output = _C.gof_forward(weight, gaborFilterBank) # 调用C++扩展return output@staticmethod@once_differentiable # 修饰backward方法def backward(ctx, grad_output):weight, gaborFilterBank = ctx.saved_tensorsgrad_weight = _C.gof_backward(grad_output, gaborFilterBank) # 调用C++扩展return grad_weight, None
💡 简单来说:@once_differentiable就像一个声明,告诉 PyTorch:“这个 backward方法到此为止,一阶导数(_C.gof_backward)我提供了,但别再试图对它求高阶导了”。这在自定义算子涉及不可微的 C++ 代码、复杂操作或仅为推理优化时非常有用。
但是一阶导数(_C.gof_backward)仍需正确实现, backward方法仍然必须正确计算输出梯度相对于所有可微分输入的一阶导数。
这使得 backward方法就像一个计算梯度的“黑盒”:PyTorch 会调用它获取一阶梯度,但不会尝试对其内部过程再次进行求导(即计算二阶导数)。