当前位置: 首页 > news >正文

模型训练和推理

    • 训练时需要梯度,推理时不需要
    • 怎么理解“梯度”?
    • 计算图以及前向后向传播

训练时需要梯度,推理时不需要

阶段是否计算梯度是否反向传播是否更新参数用例写法
训练loss 训练默认即可,requires_grad=True
推理采样、预测、部署@torch.inference_mode()with torch.no_grad()
  • 训练阶段必须开启梯度计算

    • 计算 loss(损失函数)
    • 然后通过 loss.backward()反向传播(backpropagation)
    • 更新模型参数(optimizer.step()
  • 推理阶段(inference)不需要梯度计算,关闭它可以节省内存、提高速度

    • 只需要执行 forward,得到模型输出(如预测轨迹、采样结果)
    • 不再需要 loss,也不需要更新模型参数

@torch.inference_mode() 是 PyTorch 中用于 推理模式(inference mode) 的一个装饰器,主要功能是:临时关闭梯度计算(比 torch.no_grad() 更高效),用于模型推理阶段,加快速度、降低显存占用。
它和 @torch.no_grad() 类似,但更彻底:

  • torch.no_grad() 禁用梯度计算(不会构建计算图)
  • torch.inference_mode() 也禁用梯度计算,但还能避免某些内部缓冲区的额外开销,性能更好
@torch.inference_mode()
def predict(model, inputs):
   return model(inputs)

在这里插入图片描述

  • PyTorch 会在每次 forward 过程中,构建一棵 计算图(computation graph),记录每一步的操作,方便后面 loss.backward() 自动求导。
    • 一旦调用 loss.backward(),它会从最后一层反推回去,自动算出所有参数的梯度。
    • @torch.inference_mode()with torch.no_grad() 会告诉 PyTorch:我只是 forward 看看结果,不要帮我建计算图了!

扩展:How Computational Graphs are Executed in PyTorch

怎么理解“梯度”?

可以用一个简单直觉的比喻,把模型看成一个“函数机器”:它输入是数据(如图片、状态),输出是预测结果(如轨迹、控制信号)。

梯度 = 模型输出对参数的敏感程度(变化率)

比如:模型预测错了,就会计算:

loss = 模型输出 - 真实值

此时我们想知道:

如果我改变模型的参数,loss 会变大还是变小?

这就需要计算 loss 对模型参数的导数 —— 这就是梯度

举个例子

loss = (y_pred - y_true)**2

我们希望让 loss 趋近于 0。那我们就问:

  • loss模型参数 θ 的梯度是多少?
  • 梯度大 -> 表示参数的变化对 loss 影响大
  • 梯度小 -> 表示参数已经趋于最优了

然后 用这些梯度反过来更新模型

θ_new = θ_old - learning_rate * gradient

这就是 “梯度下降” 的核心思想。

用 PyTorch 来举个最简单的例子:

import torch

# 模拟一个参数 θ
theta = torch.tensor([2.0], requires_grad=True)

# 输入数据
x = torch.tensor([3.0])

# Forward:计算 y = theta * x
y = theta * x

# 假设目标输出是 y_true
y_true = torch.tensor([10.0])

# 计算 loss
loss = (y - y_true) ** 2

# 反向传播
loss.backward()

# 查看梯度
print("梯度:", theta.grad)  # 显示 ∂loss/∂theta  # 梯度: tensor([-48.])

说明此时:

  • 当前 θ 是 2 时,lossθ 的导数是 -48,代表 “要往更大的方向调 θ

计算图以及前向后向传播

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

扩展:PyTorch – Computational graph and Autograd with Pytorch

相关文章:

  • mysql8安装后没有自动生成登录密码
  • frameworks 之屏幕旋转
  • 【从零开始学习计算机科学】操作系统(五)处理器调度
  • JAVASE(五)
  • 垃圾收集算法与收集器
  • vue2:表单的动态校验和静态校验
  • 前端开发中的常见设计模式:全面解析与实践
  • Linux Shell 脚本编程极简入门指南
  • 服务器数据恢复—预防服务器故障,搞定服务器故障数据恢复
  • BT-Basic函数之首字母D
  • git commit messege 模板设置 (规范化管理git)
  • Python学习第十二天
  • 大模型在甲状腺癌诊疗全流程预测及方案制定中的应用研究
  • 台风信息查询API:数据赋能,守护安全
  • css中的浮动
  • 【QT5 Widgets示例】记事本:(三)功能实现
  • 2012. 数组美丽值求和【动态规划】
  • 学习threejs,使用LatheGeometry旋转体(榫卯体)几何体
  • texstudio: 编辑器显示行号+给PDF增加行号
  • 大数据实时分析:ClickHouse、Doris、TiDB 对比分析
  • 外交部回应西班牙未来外交战略:愿与之一道继续深化开放合作
  • 陕南多地供水形势严峻:有的已呼吁启用自备水井
  • 正荣地产:公司控股股东已获委任联合清盘人
  • 商务部:自5月7日起对原产于印度的进口氯氰菊酯征收反倾销税
  • 罗马尼亚总理乔拉库宣布辞职
  • 厦大历史系教授林汀水辞世,曾参编《中国历史地图集》