当前位置: 首页 > news >正文

关键词解释:梯度下降法(Gradient Descent)

梯度下降法(Gradient Descent)是优化中最基础且核心的算法之一。无论是一维还是多维问题,其核心思想一致:沿着目标函数梯度的反方向更新参数,以逐步逼近极小值点。下面我们从一维多维两个角度详细解析,并辅以数学公式与代码示例。


一、基本思想回顾

给定目标函数 $ f(\mathbf{x}) $,其中 $\mathbf{x} \in \mathbb{R}^n$ 是参数向量,梯度下降的更新规则为:

$ \mathbf{x}_{t+1} = \mathbf{x}_t - \eta \nabla f(\mathbf{x}_t) $

  • $\eta > 0$:学习率(步长)
  • $\nabla f(\mathbf{x}_t)$:函数在 $\mathbf{x}_t$ 处的梯度(一阶导数向量)

关键区别

  • 一维:梯度退化为普通导数(标量)
  • 多维:梯度是偏导数组成的向量

二、一维梯度下降(标量参数)

1. 场景

优化一个单变量函数,例如: $ f(x) = x^2,\quad f(x) = (x - 3)^2 + 5 $

2. 数学形式

  • 梯度:$ f'(x) = \frac{df}{dx} $
  • 更新规则: $ x_{t+1} = x_t - \eta \cdot f'(x_t) $

3. 示例:最小化$ f(x) = x^2 $

  • 导数:$ f'(x) = 2x $
  • 更新:$ x_{t+1} = x_t - \eta \cdot 2x_t = x_t (1 - 2\eta) $

收敛条件$ |1 - 2\eta| < 1 \Rightarrow 0 < \eta < 1 $

4. Python 代码演示(一维)

import numpy as np
import matplotlib.pyplot as pltdef f(x):return x ** 2def df_dx(x):return 2 * x# 参数
x = 5.0          # 初始点
lr = 0.1         # 学习率
steps = 20history = [x]
for _ in range(steps):grad = df_dx(x)x = x - lr * gradhistory.append(x)# 可视化
x_vals = np.linspace(-5, 5, 100)
plt.plot(x_vals, f(x_vals), 'b-', label='$f(x)=x^2$')
plt.scatter(history, [f(x) for x in history], c='red', s=30, zorder=5)
plt.plot(history, [f(x) for x in history], 'r--', alpha=0.7)
plt.title('1D Gradient Descent')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.grid(True)
plt.show()


三、多维梯度下降(向量参数)

1. 场景

优化多变量函数,例如:

  • 线性回归:$ f(\mathbf{w}, b) = \frac{1}{m} \sum_{i=1}^m (y_i - (\mathbf{w}^\top \mathbf{x}_i + b))^2 $
  • 神经网络损失函数

2. 数学形式

设参数为向量 $\boldsymbol{\theta} = [\theta_1, \theta_2, \dots, \theta_n]^\top$

  • 梯度:
    $ \nabla f(\boldsymbol{\theta}) = \begin{bmatrix} \frac{\partial f}{\partial \theta_1} \ \frac{\partial f}{\partial \theta_2} \ \vdots \ \frac{\partial f}{\partial \theta_n} \end{bmatrix} \in \mathbb{R}^n $

  • 更新规则(逐元素): $ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \cdot \nabla f(\boldsymbol{\theta}_t) $

3. 示例:最小化 $ f(x, y) = x^2 + y^2 $

  • 梯度:
    $ \nabla f = \begin{bmatrix} 2x \ 2y \end{bmatrix} $
  • 更新: $ x_{t+1} = x_t - \eta \cdot 2x_t \ y_{t+1} = y_t - \eta \cdot 2y_t $

最小值在$(0, 0)$

4. Python 代码演示(二维)

import numpy as np
import matplotlib.pyplot as pltdef f(x, y):return x**2 + y**2def grad_f(x, y):return np.array([2*x, 2*y])# 初始化
theta = np.array([3.0, 2.0])  # [x, y]
lr = 0.1
steps = 20history = [theta.copy()]
for _ in range(steps):grad = grad_f(theta[0], theta[1])theta = theta - lr * gradhistory.append(theta.copy())history = np.array(history)# 绘制等高线 + 路径
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)plt.contour(X, Y, Z, levels=20, cmap='viridis')
plt.plot(history[:, 0], history[:, 1], 'ro-', label='GD Path')
plt.scatter(0, 0, c='red', marker='*', s=200, label='Minimum')
plt.title('2D Gradient Descent on $f(x,y)=x^2+y^2$')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True)
plt.axis('equal')
plt.show()

四、一维 vs 多维对比总结

特性一维梯度下降多维梯度下降
参数标量 $x$向量 $\boldsymbol{\theta} \in \mathbb{R}^n$
梯度导数$f'(x)$(标量)梯度向量 $\nabla f$($n$ 维)
更新$x \leftarrow x - \eta f'(x)$$\boldsymbol{\theta} \leftarrow \boldsymbol{\theta} - \eta \nabla f$
几何意义在曲线上左右移动在曲面/超曲面上沿最陡下降方向移动
应用教学示例、简单优化机器学习、深度学习、工程优化

五、注意事项(多维特有)

  1. 特征尺度差异
    若各维度量纲不同(如 $x_1$ 是身高(cm),$x_2$是收入(万元)),会导致梯度方向扭曲 → 需标准化(Standardization)

  2. 病态条件(Ill-conditioning)
    Hessian 矩阵条件数大 → 梯度下降呈“锯齿状”震荡 → 可用 动量(Momentum)二阶方法(如牛顿法)

  3. 局部极小值与鞍点
    多维非凸函数可能存在多个极小值或鞍点(梯度为零但非极值)→ Adam 等优化器更鲁棒


六、扩展:批量梯度下降(机器学习中的多维应用)

在线性回归中,损失函数为:

$ J(\mathbf{w}) = \frac{1}{2m} |\mathbf{Xw} - \mathbf{y}|^2 $

梯度为:

$ \nabla J = \frac{1}{m} \mathbf{X}^\top (\mathbf{Xw} - \mathbf{y}) $

更新:

$ \mathbf{w} \leftarrow \mathbf{w} - \eta \cdot \nabla J $

这就是多维梯度下降在机器学习中的典型应用


总结

  • 一维:理解原理的“玩具模型”,直观展示学习率影响
  • 多维:真实世界的优化问题,需处理向量、矩阵、尺度、收敛性等复杂性
  • 核心不变:始终沿负梯度方向更新参数

💡 记住:无论维度多少,梯度下降的本质都是——“往最陡的下坡方向走一步”

http://www.dtcms.com/a/561270.html

相关文章:

  • 做外贸的网站哪个好湖南人文科技学院
  • deadbeef播放器歌词插件
  • 网站推广有什么好处咨询公司招聘条件
  • 网站定位授权开启权限怎么做精准营销模式
  • Flutter 开发环境配置教程
  • Go Gorm 深度解析:从内部原理到实战避坑指南
  • 保定企业建网站房产网站运营方案
  • 机械动力的能力
  • 山西省旅游网站建设分析廊坊网站制作网站
  • 【YashanDB认证】之二:Docker部署一体YashanDB(YDC,YCM)
  • C语言刷题(一)
  • 电子电气架构(EEA)最新调研-5
  • 【软考架构】案例分析-对比MySQL查询缓存与Memcached
  • 「经典图形题」集合 | C/C++
  • IT4IT是由The Open Group提出的面向数字化转型的IT管理参考架构框架
  • 学校网站怎么做的好南翔做网站公司
  • 解决 CentOS 8 报错:Failed to download metadata for repo ‘BaseOS‘
  • VS Code集成googletest-C/C++单元测试Windows
  • Vue 图片性能优化双剑客:懒加载与自动压缩实战指南
  • 网站之家查询qq空间网站是多少
  • Elasticsearch 与 Faiss 联合驱动自动驾驶场景检索:高效语义匹配 PB 级视频数据
  • 短租网站开发学术ppt模板免费
  • 设计模式——单例模式(singleton)
  • 【计算机软件资格考试】软考综合知识题高频考题及答案解析1
  • 计算机网络自顶向下方法25——运输层 TCP流量控制 连接管理 “四次挥手”的优化
  • LeetCode【高频SQL基础50题】
  • 清远做网站的有哪些wordpress判断浏览器
  • 自己做的网站怎样才有网址浏览找人做网站域名怎么过户
  • JavaScript中的闭包:原理与实战
  • 怎么看一个网站是否被k怎么找项目