从代码学习深度强化学习 - Actor-Critic 算法 PyTorch版
文章目录
- 前言
- 算法原理
- 1. 从策略梯度到Actor-Critic
- 2. Actor 和 Critic 的角色
- 3. Critic 的学习方式:时序差分 (TD)
- 4. Actor 的学习方式:策略梯度
- 5. 算法流程
- 代码实现
- 1. 环境与工具函数
- 2. 构建Actor-Critic智能体
- 3. 组织训练流程
- 4. 主程序:启动训练
- 5. 实验结果
- 总结
前言
在深度强化学习(DRL)的广阔天地中,算法可以大致分为两大家族:基于价值(Value-based)的算法和基于策略(Policy-based)的算法。像DQN这样的算法通过学习一个价值函数来间接指导策略,而像REINFORCE这样的算法则直接对策略进行参数化和优化。
然而,这两种方法各有优劣。基于价值的方法通常数据效率更高、更稳定,但难以处理连续动作空间;基于策略的方法可以直接处理各种动作空间,并能学习随机策略,但其学习过程往往伴随着高方差,导致训练不稳定、收敛缓慢。
为了融合两者的优点,Actor-Critic(演员-评论家) 框架应运而生。它构成了现代深度强化学习的基石,许多前沿算法(如A2C, A3C, DDPG, TRPO, PPO等)都属于这个大家族。
本文将从理论出发,结合一个完整的 PyTorch 代码实例,带您深入理解基础的 Actor-Critic 算法。我们将通过经典的 CartPole(车杆)环境,一步步构建、训练并评估一个 Actor-Critic 智能体,直观地感受它是如何工作的。
完整代码:下载链接
算法原理
Actor-Critic 算法本质上是一种基于策略的算法,其目标是优化一个带参数的策略。与REINFORCE算法不同的是,它会额外学习一个价值函数,用这个价值函数来“评论”策略的好坏,从而帮助策略函数更好地学习。
1. 从策略梯度到Actor-Critic
在策略梯度方法中,目标函数的梯度可以写成一个通用的形式:
g = E [ ∑ t = 0 T ψ t ∇ θ log π θ ( a t ∣ s t ) ] g=\mathbb{E}\left[\sum_{t=0}^T\psi_t\nabla_\theta\log\pi_\theta(a_t|s_t)\right] g=E[t=0∑Tψt∇θlogπθ(at∣st)]
其中,ψt
是一个用于评估在状态 st
下采取动作 at
的优劣的标量。ψt
的选择直接影响了算法的性能:
- 形式2:
ψt
是动作at
之后的所有回报之和。这是 REINFORCE 算法使用的形式。它使用蒙特卡洛方法来估计动作的价值,虽然是无偏估计,但由于包含了从t
时刻到回合结束的所有随机性,其方差非常大。 - 形式6:
ψt
是 时序差分误差(TD Error)。这是本文 Actor-Critic 算法将采用的核心形式。它只利用了一步的真实奖励r_t
和对下一状态价值的估计V(s_t+1)
,极大地降低了方差。
这个转变正是 Actor-Critic 算法的核心思想:不再使用完整的、高方差的轨迹回报,而是引入一个价值函数来提供更稳定、低方差的指导信号。
2. Actor 和 Critic 的角色
我们将 Actor-Critic 算法拆分为两个核心部分:
- Actor (演员):即策略网络。它的任务是与环境进行交互,并根据 Critic 的“评价”来学习一个更好的策略。它决定了在某个状态下应该采取什么动作。
- Critic (评论家):即价值网络。它的任务是通过观察 Actor 与环境的交互数据,学习一个价值函数。这个价值函数用于判断在当前状态下,Actor 选择的动作是“好”还是“坏”,从而指导 Actor 的策略更新。
3. Critic 的学习方式:时序差分 (TD)
Critic 的目标是准确地估计状态价值函数 V(s)
。它采用**时序差分(Temporal-Difference, TD)**学习方法。具体来说,是TD(0)方法。
在TD学习中,我们希望价值网络的预测值 V(s_t)
能够逼近 TD目标 (TD Target),即 r_t + γV(s_t+1)
。因此,Critic 的损失函数定义为两者之间的均方误差:
L ( ω ) = 1 2 ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) 2 \mathcal{L}(\omega)=\frac{1}{2}(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))^2 L(ω)=21(r+γVω(st+1)−Vω(st))2
当我们对这个损失函数求梯度以更新 Critic 的网络参数 w
时,有一个非常关键的点:
在TD学习中,目标值
r_t + γV(s_t+1)
被视为一个固定的“标签”(Target),不参与反向传播。因此,梯度只对当前状态的值函数V(s_t)
求导。
Critic 价值网络表示为 V w V_w Vw,参数为 w w w。价值函数的梯度为:
∇ ω L ( ω ) = − ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) ∇ ω V ω ( s t ) \nabla_\omega\mathcal{L}(\omega)=-(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))\nabla_\omega V_\omega(s_t) ∇ωL(ω)=