Pytorch的梯度控制
在之前的实验中遇到一些问题,因为之前计算资源有限,我就想着微调其中一部分参数做,于是我误打误撞使用了with torch.no_grad
,可是发现梯度传递不了,于是写下此文来记录梯度控制的两个方法与区别。
在PyTorch中,控制梯度计算对于模型训练和微调至关重要。这里区分两个常用方法:
1. tensor.requires_grad = False
- 目标: 单个张量(通常是模型参数
nn.Parameter
)。 - 行为:
- “参数冻结”:这个张量本身不会计算梯度 (
.grad
为None
)。 - “参数不更新”:优化器不会更新这个张量。
- “梯度可穿透”:如果它参与的运算的输入是
requires_grad=True
的,梯度仍然会通过这个运算传递给输入。它不阻碍梯度流向更早的可训练层。
- “参数冻结”:这个张量本身不会计算梯度 (
- 场景:
- 微调:冻结预训练模型的某些层,只训练其他层。
- 例子:
pretrained_layer.weight.requires_grad = False
2. with torch.no_grad():
- 目标: 一个代码块 (
with
语句块内部)。 - 行为:
- “全局梯度关闭”(块内):块内所有新创建的张量默认
requires_grad=False
。 - “不记录计算图”:块内的运算不被追踪,不构建反向传播所需的计算图。
- “梯度截断”:梯度流到这个块的边界就会停止,无法通过块内的操作继续反向传播。
- “全局梯度关闭”(块内):块内所有新创建的张量默认
- 场景:
- 模型评估/推理 (Inference/Evaluation):不需要梯度,节省内存和计算。
- 执行不需要梯度的任何计算。
- 例子:
with torch.no_grad():outputs = model(inputs)# ...其他评估代码
核心区别速记:
特性 | requires_grad=False | with torch.no_grad(): |
---|---|---|
谁不更新? | 这个参数自己 | (块内)没人更新 |
梯度能过吗? | 能过! | 不能过! (被截断) |
影响范围? | 单个张量 | 整个代码块 |
一句话总结:
- 想让某个参数不更新但梯度能流过,用
requires_grad=False
。 - 想让一段代码完全不计算梯度也不让梯度流过,用
with torch.no_grad()
。
搞清楚这两者的区别,能在PyTorch中更灵活地控制模型的训练过程!