当前位置: 首页 > news >正文

【PyTorch】PyTorch 自动微分与完整手动实现对比

PyTorch 自动微分与完整手动实现对比

  • 一、手动实现前向传播
    • 1. 问题设定
    • 2.代码实现
    • 3. 手动计算梯度(反向传播)
    • 4. 参数更新(手动实现 optimizer.step ())
  • 二、 PyTorch 自动微分
  • 三、 总结

一、手动实现前向传播

1. 问题设定

在这里插入图片描述

2.代码实现

# 输入数据和标签
x = [1.0, 2.0]  # 输入特征
y = 3.0         # 真实标签# 模型参数(随机初始化)
w = [0.5, 0.5]  # 权重
b = 0.1         # 偏置# 前向传播计算预测值
y_pred = w[0] * x[0] + w[1] * x[1] + b# 计算损失(MSE)
loss = 0.5 * (y_pred - y) ** 2print(f"预测值: {y_pred}, 损失: {loss}")

3. 手动计算梯度(反向传播)

在这里插入图片描述

# 计算中间梯度
dL_dy_pred = y_pred - y  # ∂L/∂ŷ# 计算参数梯度
dL_dw1 = dL_dy_pred * x[0]  # ∂L/∂w1 = ∂L/∂ŷ * ∂ŷ/∂w1
dL_dw2 = dL_dy_pred * x[1]  # ∂L/∂w2 = ∂L/∂ŷ * ∂ŷ/∂w2
dL_db = dL_dy_pred         # ∂L/∂b = ∂L/∂ŷ * ∂ŷ/∂bprint(f"梯度: ∂L/∂w1 = {dL_dw1}, ∂L/∂w2 = {dL_dw2}, ∂L/∂b = {dL_db}")

4. 参数更新(手动实现 optimizer.step ())

使用梯度下降更新参数:

learning_rate = 0.01# 更新参数
w[0] = w[0] - learning_rate * dL_dw1
w[1] = w[1] - learning_rate * dL_dw2
b = b - learning_rate * dL_dbprint(f"更新后的参数: w = {w}, b = {b}")

二、 PyTorch 自动微分

import torch
import torch.nn as nn# ===== PyTorch 自动微分实现 =====
# 创建模型和数据
net = nn.Sequential(nn.Linear(2, 1))
net[0].weight.data = torch.tensor([[0.5, 0.5]])  # 初始权重
net[0].bias.data = torch.tensor([0.1])          # 初始偏置x_tensor = torch.tensor([[1.0, 2.0]], requires_grad=False)
y_tensor = torch.tensor([[3.0]])# 前向传播
y_pred_tensor = net(x_tensor)
loss_tensor = nn.MSELoss()(y_pred_tensor, y_tensor)# 反向传播
loss_tensor.backward()print("\nPyTorch 自动微分结果:")
print(f"预测值: {y_pred_tensor.item()}, 损失: {loss_tensor.item()}")
print(f"梯度: ∂L/∂w1 = {net[0].weight.grad[0, 0].item()}, "f"∂L/∂w2 = {net[0].weight.grad[0, 1].item()}, "f"∂L/∂b = {net[0].bias.grad.item()}")# ===== 手动实现结果 =====
print("\n手动实现结果:")
print(f"预测值: {y_pred}, 损失: {loss}")
print(f"梯度: ∂L/∂w1 = {dL_dw1}, ∂L/∂w2 = {dL_dw2}, ∂L/∂b = {dL_db}")

输出对比(假设学习率为 0.01)

PyTorch 自动微分结果:
预测值: 1.6, 损失: 0.98
梯度: ∂L/∂w1 = 1.4, ∂L/∂w2 = 2.8, ∂L/∂b = 1.4手动实现结果:
预测值: 1.6, 损失: 0.98
梯度: ∂L/∂w1 = 1.4, ∂L/∂w2 = 2.8, ∂L/∂b = 1.4

三、 总结

计算图的本质
无论使用 PyTorch 还是手动计算,梯度的流动路径都是由数学公式决定的。PyTorch 只是自动帮你构建了这个路径。
链式法则的核心作用
梯度从损失函数开始,通过链式法则逐层传递到每个参数。例如:

∂L/∂w1 = (∂L/∂ŷ) * (∂ŷ/∂w1) = (ŷ - y) * x1

PyTorch 的自动微分
当你调用 loss.backward() 时,PyTorch 会:

  • 自动追踪从 loss 到 weight 和 bias 的所有操作路径。
  • 对每个操作应用链式法则计算局部梯度。
  • 将最终梯度存储在 weight.grad 和 bias.grad 中。
http://www.dtcms.com/a/273341.html

相关文章:

  • vue3 element plus table 使用固定列,滑动滚动会错位、固定列层级异常、滑动后固定列的内容看不到了
  • Java多线程 V1
  • AIStarter 3.2.0正式上线!高速下载+离线导入+一键卸载新功能详解【附完整使用教程】✅ 帖子正文(字数:约 400 字)
  • 静态路由综合实验
  • WiFi技术深度研究报告:从基础原理到组网应用与未来演进
  • python+django/flask基于微信小程序的农产品管理与销售APP系统
  • CTFshow-PWN-栈溢出(pwn62-pwn64)
  • JAVA面试宝典 -《新潮技术:协程与响应式编程实践》
  • 【Ubuntu】编译sentencepiece库
  • next.js打包后的前端资源如何进行部署和访问,为什么没有index.html
  • Vue响应式原理六:Vue3响应式原理
  • Java 17 新特性解析:密封类与模式匹配的完美协作
  • 01背包问题总结
  • 三维旋转沿轴分解
  • AWS ECS任务角色一致性检查与自动修复工具完全指南
  • LVGL学习笔记-----进度条控件(lv_bar)
  • Java结构型模式---桥接模式
  • 什么?不知道 MyBatisPlus 多数据源(动态数据源)干什么的,怎么使用,看这篇文章就够了。
  • AI探索 | 豆包智能助手跟扣子空间(AI办公助手)有什么区别
  • Ranger框架的发展历程
  • Windows系统DLL、运行库、DirectX等DLL丢失等异常状态
  • 数组的应用示例
  • 【Python进阶篇 面向对象程序设计(7) Python操作数据库】
  • 《测试开发:从技术角度提升测试效率与质量》
  • 《Revisiting Generative Replay for Class Incremental Object Detection》阅读笔记
  • 3D lidar目标跟踪
  • PyTorch自动微分:从基础到实战
  • Linux C 文件基本操作
  • 【Java并发编程】AQS(AbstractQueuedSynchronizer)抽象同步器核心原理
  • 飞算科技:以原创技术赋能电商企业数字化转型