当前位置: 首页 > news >正文

传统策略梯度方法的弊端与PPO的改进:稳定性与样本效率的提升

为什么传统策略梯度方法(如REINFORCE算法)在训练过程中存在不稳定性和样本效率低下的问题

1. 传统策略梯度方法的基本公式

传统策略梯度方法的目标是最大化累积奖励的期望值。具体来说,优化目标可以表示为:
max ⁡ θ J ( θ ) = E π [ ∑ t = 0 ∞ γ t R t + 1 ] \max_\theta J(\theta) = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t R_{t+1}\right] θmaxJ(θ)=Eπ[t=0γtRt+1]
其中:

  • J ( θ ) J(\theta) J(θ) 是策略性能,即累积奖励的期望值。
  • π θ ( a t ∣ s t ) \pi_\theta(a_t|s_t) πθ(atst) 是在策略 π \pi π 下,状态 s t s_t st 下选择动作 a t a_t at 的概率。
  • R t + 1 R_{t+1} Rt+1 是在时间步 t + 1 t+1 t+1 获得的奖励。
  • γ \gamma γ 是折扣因子,用于衡量未来奖励的当前价值。

为了实现这个目标,策略梯度定理提供了策略性能的梯度的解析表达式:
∇ θ J ( θ ) = E π [ ∑ t = 0 ∞ γ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ⋅ G t ] \nabla_\theta J(\theta) = \mathbb{E}_\pi\left[\sum_{t=0}^{\infty} \gamma^t \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot G_t\right] θJ(θ)=Eπ[t=0γtθlogπθ(atst)Gt]
其中:

  • G t G_t Gt 是从时间步 t t t 开始的累积奖励:

G t = ∑ k = t ∞ γ k − t R k + 1 G_t = \sum_{k=t}^{\infty} \gamma^{k-t} R_{k+1} Gt=k=tγktRk+1

2. 不稳定性问题

(1)梯度估计的高方差

传统策略梯度方法(如REINFORCE算法)直接使用采样轨迹来估计策略梯度。具体更新规则为:
θ ← θ + α ∑ t = 0 T ∇ θ log ⁡ π θ ( a t ∣ s t ) ⋅ G t \theta \leftarrow \theta + \alpha \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot G_t θθ+αt=0Tθlogπθ(atst)Gt
其中:

  • α \alpha α 是学习率。
  • T T T 是轨迹的长度。

问题分析:

  • 高方差:累积奖励 G t G_t Gt 是一个随机变量,其值取决于具体的采样轨迹。由于环境的随机性和策略的随机性,不同轨迹的累积奖励 G t G_t Gt 可能差异很大,导致梯度估计的方差很高。高方差的梯度估计使得训练过程不稳定,容易出现剧烈波动。
  • 更新过大:由于梯度估计的方差很高,每次更新可能会导致策略参数 θ \theta θ 发生较大变化。这种过大的更新可能会使策略偏离最优策略,导致训练过程不稳定。

3. 样本效率低下的问题

(1)单次更新

传统策略梯度方法通常在每个数据批次上只进行一次更新。具体来说,每采样一条轨迹,就计算一次梯度并更新策略参数。这种单次更新的方式导致样本的利用效率较低。

问题分析:

  • 样本利用率低:每个数据批次只使用一次,更新后就丢弃。这意味着每个样本只对策略更新贡献一次,没有充分利用样本的信息。
  • 数据冗余:在复杂环境中,采样到的轨迹可能包含大量重复或相似的状态和动作,这些冗余数据没有被充分利用,导致样本效率低下。

4. PPO如何解决这些问题

PPO(Proximal Policy Optimization)通过以下两种主要机制解决了传统策略梯度方法的不稳定性和样本效率低下的问题:

(1)剪切机制(Clipping Mechanism)

PPO引入了一个剪切的目标函数,限制新策略与旧策略之间的概率比率。具体来说,PPO的目标函数为:
L C L I P ( θ ) = E t [ min ⁡ ( r t ( θ ) A t , clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A t ) ] L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t\right)\right] LCLIP(θ)=Et[min(rt(θ)At,clip(rt(θ),1ϵ,1+ϵ)At)]
其中:

  • r t ( θ ) = π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} rt(θ)=πθold(atst)πθ(atst) 是新策略与旧策略的概率比率。
  • A t A_t At 是优势函数,表示在状态 s t s_t st 下采取动作 a t a_t at 的相对优势。
  • ϵ \epsilon ϵ 是一个超参数,通常取值为0.1或0.2。

解决不稳定性的机制:

  • 限制更新幅度:通过剪切操作 clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) clip(rt(θ),1ϵ,1+ϵ),PPO限制了新策略与旧策略之间的概率比率,防止策略更新过大。这使得每次更新更加平滑,减少了训练过程中的波动。
  • 降低方差:剪切机制通过限制概率比率的范围,减少了梯度估计的方差,使得训练过程更加稳定。

(2)多次更新

PPO允许在一个数据批次上进行多次更新。具体来说,PPO在采样一批数据后,可以多次使用这些数据进行更新,直到策略收敛。

解决样本效率低下的机制:

  • 提高样本利用率:通过多次更新,每个样本可以多次贡献于策略的优化,充分利用了样本的信息。
  • 减少冗余:多次更新可以更好地利用数据中的信息,减少冗余数据对训练的影响,提高样本的利用效率。

5. 总结

传统策略梯度方法(如REINFORCE算法)在训练过程中存在不稳定性和样本效率低下的问题,主要原因是:

  1. 高方差的梯度估计:累积奖励 G t G_t Gt 的随机性导致梯度估计的方差很高,使得训练过程不稳定。
  2. 单次更新:每个数据批次只使用一次,更新后就丢弃,导致样本的利用效率较低。

PPO通过引入剪切机制和多次更新,解决了这些问题:

  1. 剪切机制:限制新策略与旧策略之间的概率比率,防止策略更新过大,降低梯度估计的方差,提高训练的稳定性。
  2. 多次更新:在一个数据批次上进行多次更新,充分利用样本信息,提高样本的利用效率。

这些改进使得PPO在训练过程中更加稳定,样本效率更高,成为强化学习领域中一种常用的基准算法。


文章转载自:
http://cajan.wsgyq.cn
http://ablegate.wsgyq.cn
http://adventive.wsgyq.cn
http://amniocentesis.wsgyq.cn
http://anaerobiosis.wsgyq.cn
http://appulsively.wsgyq.cn
http://capulet.wsgyq.cn
http://bari.wsgyq.cn
http://aiee.wsgyq.cn
http://buns.wsgyq.cn
http://barilla.wsgyq.cn
http://approved.wsgyq.cn
http://autofining.wsgyq.cn
http://bibelot.wsgyq.cn
http://basseterre.wsgyq.cn
http://chromophotograph.wsgyq.cn
http://camise.wsgyq.cn
http://bushmanoid.wsgyq.cn
http://angiosarcoma.wsgyq.cn
http://annunciator.wsgyq.cn
http://anta.wsgyq.cn
http://append.wsgyq.cn
http://cementite.wsgyq.cn
http://amphipathic.wsgyq.cn
http://alexandrite.wsgyq.cn
http://bootblack.wsgyq.cn
http://chanukah.wsgyq.cn
http://autorotate.wsgyq.cn
http://camerist.wsgyq.cn
http://backlining.wsgyq.cn
http://www.dtcms.com/a/100275.html

相关文章:

  • 【干货】前端实现文件保存总结
  • rce操作
  • 唤起“堆”的回忆
  • 基于自定义注解+反射+AOP+Redis的通用开关设计:在投行交易与风控系统的落地实践
  • golang 的reflect包的常用方法
  • 低速通信之王:LIN总线工作原理入门
  • 创作领域“<em >彩</em><em>票</em><em>导</em><em>师</em><em>带</em><em>玩</em><em>群
  • SvelteKit 最新中文文档教程(15)—— 链接选项
  • C语言的sprintf函数使用
  • Rust 为什么不适合开发 GUI
  • Java后端开发: 如何安装搭建Java开发环境《安装JDK》和 检测JDK版本
  • 【Tauri2】008——简单说说配置文件
  • QtWebApp使用
  • .Net framework 3.5怎样离线安装
  • Redis-09.Redis常用命令-通用命令
  • Python练习
  • QXmpp入门
  • 前端学习日记--JavaScript
  • 大模型生成吉卜力风格艺术:技术与魔法的完美结合
  • 【附JS、Python、C++题解】Leetcode面试150题(12)多数问题
  • Nginx — nginx.pid打开失败及失效的解决方案
  • css基础之浮动相关学习
  • 实现一个简易版的前端监控 SDK
  • ​AI训练中的专有名词大白话版
  • Linux《进程概念(上)》
  • PGD对抗样本生成算法实现(pytorch版)
  • React编程模型:React Streams规范详解
  • 阿里:多模态大模型预训练数据治理
  • VBA第三十四期 VBA中怎么用OnKey事件
  • Java与代码审计-Java基础语法