当前位置: 首页 > news >正文

常用优化器的原理及工作机制详解

1. 梯度下降(Gradient Descent, GD)

原理
通过计算整个数据集的平均梯度来更新参数,公式为:
θ t + 1 = θ t − η ⋅ ∇ J ( θ t ) \theta_{t+1} = \theta_t - \eta \cdot \nabla J(\theta_t) θt+1=θtηJ(θt)
其中 η \eta η 是学习率。优点是更新方向准确,但计算成本高,适合小数据集。

特点

  • 计算成本高(需遍历全部数据)
  • 可能陷入局部极小值
  • 对于凸函数保证收敛

Python代码

def gradient_descent(x, y, learning_rate=0.01, epochs=1000):
    """
    使用梯度下降法(Gradient Descent)优化线性回归模型的参数。

    参数:
    x (numpy.ndarray): 输入特征数据,形状为 (m,),其中 m 是样本数量。
    y (numpy.ndarray): 目标值数据,形状为 (m,),其中 m 是样本数量。
    learning_rate (float): 学习率,控制参数更新的步长。
    epochs (int): 迭代次数,即训练轮数。

    返回:
    theta (float): 优化后的模型参数。
    """
    m = len(y)  # 样本数量
    theta = np.random.randn()  # 随机初始化参数 theta

    for epoch in range(epochs):  # 迭代训练
        y_pred = theta * x  # 使用当前参数的预测值
        # 计算预测值与真实值之间的误差,并求梯度
        gradient = -2 * np.mean(x * (y - y_pred))
        # 更新参数 theta
        theta -= learning_rate * gradient

    return theta

#使用PyTorch框架
for inputs, labels in entire_dataset:
    # 清零梯度,防止梯度累积
    optimizer.zero_grad()
    # 前向传播,获取模型的输出
    outputs = model(inputs)
    # 计算损失
    loss = criterion(outputs, labels)
    # 反向传播,计算梯度
    loss.backward()
    
    # 手动实现全批量梯度累积并更新参数
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is not None:  # 检查梯度是否存在
                param -= lr * param.grad


2. 随机梯度下降(Stochastic Gradient Descent, SGD)

原理
每次随机选择一个样本计算梯度,按照学习率乘以梯度的方向更新参数,公式为:
θ t + 1 = θ t − η ∇ θ J ( θ t ; x ( i ) , y ( i ) ) \theta_{t+1} = \theta_t - \eta \nabla_\theta J(\theta_t; x^{(i)}, y^{(i)}) θt+1=θtηθJ(θt;x(i),y(i))
特点

  • 引入噪声,可能跳出局部极小
  • 更新波动大
  • 适合在线学习

Python代码

def sgd(x, y, learning_rate=0.01, epochs=1000):
    m = len(y)
    theta = np.random.randn()
    for _ in range(epochs):
        for i in range(m):
            y_pred = theta * x[i]
            gradient = -2 * x[i] * (y[i] - y_pred)
            theta -= learning_rate * gradient
    return theta
# DataLoader设置batch_size=1
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

3. 小批量随机梯度下降(Mini-Batch SGD, MB-SGD)

原理
折中方案,每次用n个样本(典型n=32~512),公式为:
θ t + 1 = θ t − η ⋅ 1 n ∑ i = 1 n ∇ θ J ( θ t ; x ( i ) , y ( i ) ) \theta_{t+1} = \theta_t - \eta \cdot \frac{1}{n}\sum_{i=1}^n \nabla_\theta J(\theta_t; x^{(i)}, y^{(i)}) θt+1=θtηn1i=1nθJ(θt;x(i),y(i))

特点

  • 降低方差,提高计算效率
  • 现代深度学习默认选择

Python代码

def mini_batch_sgd(x, y, batch_size=32, learning_rate=0.01, epochs=1000):
    m = len(y)
    theta = np.random.randn()
    for _ in range(epochs):
        indices = np.random.permutation(m)
        x_shuffled, y_shuffled = x[indices]

相关文章:

  • C++实用技巧之 --- 观察者模式详解
  • 【蓝耘平台与DeepSeek强强联手】:深度探索AI应用实践
  • PDF另存为图片的一个方法
  • 深入解析 PCIe 的 iATU(Internal Address Translation Unit)及其工作原理
  • 从二维到三维3D工业相机如何改变机器视觉检测
  • 《open3d+pyqt》第二章——体素采样-open3d自带显示
  • 微信小程序~电器维修系统小程序
  • 【Pico】使用Pico进行无线串流搜索不到电脑
  • DeepSeekApi对接流式输出异步聊天功能:基于Spring Boot和OkHttp的SSE应用实现
  • Python函数参数参数逐步进阶250214
  • 基于Multi-Runtime的云原生多态微服务:解耦基础设施与业务逻辑的革命性实践
  • 【工业安全】-CVE-2022-35555- Tenda W6路由器 命令注入漏洞
  • PHP防伪溯源查询系统小程序
  • Mysql之主从复制
  • 对接 PayPal 支付平台流程详解
  • 单调栈及相关题解
  • Unity3D 可视化脚本框架设计详解
  • 线程池处理异常
  • 应对DeepSeek总是服务器繁忙的解决方法
  • 服务器linux操作系统安全加固
  • 在哪个网站去租地方做收废站/中国企业500强排行榜
  • 自己做的网站 能收索么/广州网站到首页排名
  • web网站建设报价/广州商务网站建设
  • 做网站单页/指数基金有哪些
  • 网站建设推广页/seo搜索引擎优化
  • 怎么做传奇私服网站/网站建设是什么工作