当前位置: 首页 > news >正文

PyTorch 中.backward() 详解使用

1️、背景:自动求导机制(Autograd)

  • PyTorch 使用 动态图机制 (Dynamic Computation Graph)

    • 每一次 forward 操作都会在后台构建一张计算图(autograd graph)。
    • 计算图的节点:Tensor
    • 边:运算(Function
  • 调用 .backward() 时,PyTorch 会从 标量(loss)开始,沿着计算图反向传播,自动计算各个参数的梯度,并存储在 tensor.grad 中。


2️、基本用法

import torch# 定义张量,开启梯度跟踪
x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3*x + 1  # y = x^2 + 3x + 1# 反向传播
y.backward()print(x.grad)  # dy/dx = 2x + 3 = 7

3️、.backward() 的常见参数

(1) gradient=None

  • 适用于 标量张量(只含一个值)。
  • 如果 tensor 不是标量(多元素张量),必须传入 gradient 参数(同 shape 的权重),告诉 PyTorch 怎么把向量 → 标量。

例子:

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x**2
# y = [1,4,9],不是标量
# 如果直接 y.backward() 会报错# 传入 gradient 向量,等价于对 sum(y) 求导
y.backward(torch.tensor([1.0, 1.0, 1.0]))
print(x.grad)  # [2,4,6]

(2) retain_graph=False

  • 默认情况下,backward() 执行后会 释放计算图(节省内存)。
  • 如果你需要对 同一个图多次反向传播,必须设 retain_graph=True
x = torch.tensor(2.0, requires_grad=True)
y = x**3y.backward(retain_graph=True)
print(x.grad)  # 12 (dy/dx=3x^2=12)y.backward()   # 第二次 backward,如果没有 retain_graph 会报错
print(x.grad)  # 24 (累积梯度)

(3) create_graph=False

  • 是否在反向传播时 构建计算图(用于高阶导数)。
  • 默认 False(只算一次梯度)。
  • 如果要继续对梯度求导,就需要 create_graph=True
x = torch.tensor(2.0, requires_grad=True)
y = x**3dy_dx = torch.autograd.grad(y, x, create_graph=True)[0]
# dy/dx = 12d2y_dx2 = torch.autograd.grad(dy_dx, x)[0]
print(d2y_dx2)  # 12

4️、梯度累积机制

PyTorch 中 .grad累积的,不会自动清零。

x = torch.tensor(2.0, requires_grad=True)y1 = x**2
y1.backward()
print(x.grad)  # 4y2 = 3*x
y2.backward()
print(x.grad)  # 4 + 3 = 7# 解决方法:每次反向传播前清零
x.grad.zero_()

在训练循环中,通常这样写:

optimizer.zero_grad()  # 清空旧梯度
loss.backward()        # 计算新梯度
optimizer.step()       # 更新参数

5️、.backward()torch.autograd.grad

  • .backward():会把梯度 存储到 tensor.grad
  • torch.autograd.grad()返回梯度值,不会自动累积到 .grad

例子:

x = torch.tensor(2.0, requires_grad=True)
y = x**3# 用 backward
y.backward()
print(x.grad)  # 12# 用 autograd.grad
dy_dx = torch.autograd.grad(y, x)[0]
print(dy_dx)  # 12
print(x.grad) # 12 (还是累积的)

6️、.backward() 的典型使用场景

  1. 训练神经网络

    for data, target in dataloader:optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()
    
  2. 自定义梯度计算(比如物理约束 loss)

    x = torch.randn(3, requires_grad=True)
    y = (x**2).sum()
    y.backward()
    print(x.grad)  # 2x
    
  3. 高阶导数

    x = torch.tensor(1.0, requires_grad=True)
    y = x**3
    dy_dx = torch.autograd.grad(y, x, create_graph=True)[0]
    d2y_dx2 = torch.autograd.grad(dy_dx, x)[0]
    print(d2y_dx2)  # 6
    

7️、常见坑

  1. loss 不是标量 → 必须传 gradient 参数。
  2. 忘记清零梯度 → 梯度会累积,导致训练异常。
  3. 重复 backward → 必须 retain_graph=True
  4. requires_grad=False 的 tensor → 不会求导。
  5. in-place 操作(如 x += 1)可能破坏计算图,导致错误。

8、综合示例

示例 1:简单函数 + backward 验证

函数:

y=(x1⋅x2+x3)2 y = (x_1 \cdot x_2 + x_3)^2 y=(x1x2+x3)2

import torch# 定义输入
x1 = torch.tensor(2.0, requires_grad=True)
x2 = torch.tensor(3.0, requires_grad=True)
x3 = torch.tensor(1.0, requires_grad=True)# 前向计算
y = (x1 * x2 + x3) ** 2
print("y =", y.item())# 反向传播
y.backward()# 查看梯度
print("dy/dx1 =", x1.grad.item())  # 2 * (x1*x2 + x3) * x2
print("dy/dx2 =", x2.grad.item())  # 2 * (x1*x2 + x3) * x1
print("dy/dx3 =", x3.grad.item())  # 2 * (x1*x2 + x3)

结果

y = 49
dy/dx1 = 42
dy/dx2 = 28
dy/dx3 = 14

对应梯度推导:

  • dy/dx1=2(x1x2+x3)⋅x2=2(2∗3+1)∗3=42dy/dx_1 = 2(x_1x_2 + x_3) \cdot x_2 = 2(2*3+1)*3 = 42dy/dx1=2(x1x2+x3)x2=2(23+1)3=42
  • dy/dx2=2(x1x2+x3)⋅x1=2(6+1)∗2=28dy/dx_2 = 2(x_1x_2 + x_3) \cdot x_1 = 2(6+1)*2 = 28dy/dx2=2(x1x2+x3)x1=2(6+1)2=28
  • dy/dx3=2(x1x2+x3)=14dy/dx_3 = 2(x_1x_2 + x_3) = 14dy/dx3=2(x1x2+x3)=14

.backward() 自动完成了这些链式法则计算。


示例 2:神经网络训练一个回归任务

任务:用一个 简单的全连接网络 拟合函数 y=2x+3y = 2x + 3y=2x+3

import torch
import torch.nn as nn
import torch.optim as optim# 构造数据
x = torch.linspace(-5, 5, 100).unsqueeze(1)  # shape [100,1]
y = 2 * x + 3 + 0.1 * torch.randn_like(x)    # 加点噪声**定义网络**
model = nn.Sequential(nn.Linear(1, 10),nn.ReLU(),nn.Linear(10, 1)
)**损失函数 + 优化器**
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)**训练**
for epoch in range(200):optimizer.zero_grad()          # 清空梯度y_pred = model(x)              # 前向loss = criterion(y_pred, y)    # 计算lossloss.backward()                # 🔹 反向传播,计算梯度optimizer.step()                # 更新参数if epoch % 50 == 0:print(f"Epoch {epoch}, Loss = {loss.item():.4f}")

训练过程中:

  • .backward() 会自动计算网络所有参数的梯度(通过计算图 + 链式法则)。
  • optimizer.step() 利用梯度更新参数。

最终模型能学到近似 y=2x+3y = 2x + 3y=2x+3


对比两例

  • 示例 1:展示了 .backward()简单函数 中的梯度传播过程。
  • 示例 2:展示了 .backward()实际深度学习训练 中的应用,自动处理整个网络的梯度。

总结

.backward() 是 PyTorch 自动求导的核心函数:

  • 默认从 标量 loss 反向传播,计算所有叶子节点的梯度,存到 .grad

  • 关键参数:

    • gradient:非标量情况需要指定权重。
    • retain_graph:是否保留计算图,支持多次 backward。
    • create_graph:是否构建高阶导数的计算图。
  • 梯度会 累积,必须在训练循环里手动清零。

  • .backward()torch.autograd.grad 可以互补使用。



文章转载自:

http://kKdRwrLV.rynqh.cn
http://dtf658hG.rynqh.cn
http://I0deRV3s.rynqh.cn
http://70NzIiEl.rynqh.cn
http://Vs5hbGmH.rynqh.cn
http://fIAr9enk.rynqh.cn
http://DnZwqBq0.rynqh.cn
http://vnLTOFP1.rynqh.cn
http://fEBbXxzd.rynqh.cn
http://VDJjA7k8.rynqh.cn
http://KV0l8Eb7.rynqh.cn
http://a8OKhWmi.rynqh.cn
http://g5uIbTyR.rynqh.cn
http://K34S1b9x.rynqh.cn
http://efiqkOWG.rynqh.cn
http://rBpNi4dE.rynqh.cn
http://Tbxk6pb5.rynqh.cn
http://0IN85ovk.rynqh.cn
http://oiCJg1CC.rynqh.cn
http://tATEpTEp.rynqh.cn
http://41RTkD28.rynqh.cn
http://Ra5y4zlV.rynqh.cn
http://nY04Q2sq.rynqh.cn
http://ZFKe2RXE.rynqh.cn
http://EniiK3MY.rynqh.cn
http://oz3MZm4V.rynqh.cn
http://ncYQTJdf.rynqh.cn
http://eOjsqeGc.rynqh.cn
http://HJ6ZuLRe.rynqh.cn
http://ji0y4Y0J.rynqh.cn
http://www.dtcms.com/a/368175.html

相关文章:

  • conda配置pytorch虚拟环境
  • Conda环境隔离和PyCharm配置,完美同时运行PaddlePaddle和PyTorch
  • PyTorch训练循环详解:深入理解forward()、backward()和optimizer.step()
  • PyTorch 训练显存越跑越涨:隐式保留计算图导致 OOM
  • PyTorch图像数据转换为张量(Tensor)并进行归一化的标准操作
  • 图像去雾:从暗通道先验到可学习融合——一份可跑的 PyTorch 教程
  • EN-DC和CA的联系与区别
  • python + Flask模块学习 1 基础用法
  • 【Flask】测试平台中,记一次在vue2中集成编辑器组件tinymce
  • 【分享】基于百度脑图,并使用Vue二次开发的用例脑图编辑器组件
  • 【Python】QT(PySide2、PyQt5):点击不同按钮显示不同页面
  • flask的使用
  • Qt添加图标资源
  • 配置WSL2的Ubuntu接受外部设备访问
  • 产线相机问题分析思路
  • VisionPro联合编程相机拍照 九点标定实战
  • c++工程如何提供http服务接口
  • Linux查看相机支持帧率和格式
  • 必知!机器人的分类与应用:RPA、人形与工业机器人
  • 相机刮除拜尔阵列
  • 关于Homebrew:Mac快速安装Homebrew
  • 微信小程序一个页面同时存在input和textarea,bindkeyboardheightchange相互影响
  • mac怎么安装uv工具
  • python库 Py2app 的详细使用(将 Python 脚本变为 MacOS 独立软件包)
  • AmbiSSL
  • 【高分论文密码】大尺度空间模拟与不确定性分析及数字制图技术应用
  • MacOS 通过Homebrew 安装nvm
  • 【NotePad++设置自定义宏】
  • baml:为提示工程注入工程化能力的Rust类型安全AI框架详解
  • 【详细指导】多文档界面(MDI)的应用程序-图像处理