当前位置: 首页 > news >正文

从代码学习深度强化学习 - Actor-Critic 算法 PyTorch版

文章目录

  • 前言
  • 算法原理
    • 1. 从策略梯度到Actor-Critic
    • 2. Actor 和 Critic 的角色
    • 3. Critic 的学习方式:时序差分 (TD)
    • 4. Actor 的学习方式:策略梯度
    • 5. 算法流程
  • 代码实现
    • 1. 环境与工具函数
    • 2. 构建Actor-Critic智能体
    • 3. 组织训练流程
    • 4. 主程序:启动训练
    • 5. 实验结果
  • 总结


前言

在深度强化学习(DRL)的广阔天地中,算法可以大致分为两大家族:基于价值(Value-based)的算法和基于策略(Policy-based)的算法。像DQN这样的算法通过学习一个价值函数来间接指导策略,而像REINFORCE这样的算法则直接对策略进行参数化和优化。

然而,这两种方法各有优劣。基于价值的方法通常数据效率更高、更稳定,但难以处理连续动作空间;基于策略的方法可以直接处理各种动作空间,并能学习随机策略,但其学习过程往往伴随着高方差,导致训练不稳定、收敛缓慢。

为了融合两者的优点,Actor-Critic(演员-评论家) 框架应运而生。它构成了现代深度强化学习的基石,许多前沿算法(如A2C, A3C, DDPG, TRPO, PPO等)都属于这个大家族。

本文将从理论出发,结合一个完整的 PyTorch 代码实例,带您深入理解基础的 Actor-Critic 算法。我们将通过经典的 CartPole(车杆)环境,一步步构建、训练并评估一个 Actor-Critic 智能体,直观地感受它是如何工作的。

完整代码:下载链接

算法原理

Actor-Critic 算法本质上是一种基于策略的算法,其目标是优化一个带参数的策略。与REINFORCE算法不同的是,它会额外学习一个价值函数,用这个价值函数来“评论”策略的好坏,从而帮助策略函数更好地学习。

1. 从策略梯度到Actor-Critic

在策略梯度方法中,目标函数的梯度可以写成一个通用的形式:

g = E [ ∑ t = 0 T ψ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] g=\mathbb{E}\left[\sum_{t=0}^T\psi_t\nabla_\theta\log\pi_\theta(a_t|s_t)\right] g=E[t=0Tψtθlogπθ(atst)]

其中,ψt 是一个用于评估在状态 st 下采取动作 at 的优劣的标量。ψt 的选择直接影响了算法的性能:

在这里插入图片描述

  • 形式2ψt 是动作 at 之后的所有回报之和。这是 REINFORCE 算法使用的形式。它使用蒙特卡洛方法来估计动作的价值,虽然是无偏估计,但由于包含了从 t 时刻到回合结束的所有随机性,其方差非常大。
  • 形式6ψt时序差分误差(TD Error)。这是本文 Actor-Critic 算法将采用的核心形式。它只利用了一步的真实奖励 r_t 和对下一状态价值的估计 V(s_t+1),极大地降低了方差。

这个转变正是 Actor-Critic 算法的核心思想:不再使用完整的、高方差的轨迹回报,而是引入一个价值函数来提供更稳定、低方差的指导信号。

2. Actor 和 Critic 的角色

我们将 Actor-Critic 算法拆分为两个核心部分:

  • Actor (演员):即策略网络。它的任务是与环境进行交互,并根据 Critic 的“评价”来学习一个更好的策略。它决定了在某个状态下应该采取什么动作。
  • Critic (评论家):即价值网络。它的任务是通过观察 Actor 与环境的交互数据,学习一个价值函数。这个价值函数用于判断在当前状态下,Actor 选择的动作是“好”还是“坏”,从而指导 Actor 的策略更新。

3. Critic 的学习方式:时序差分 (TD)

Critic 的目标是准确地估计状态价值函数 V(s)。它采用**时序差分(Temporal-Difference, TD)**学习方法。具体来说,是TD(0)方法。

在TD学习中,我们希望价值网络的预测值 V(s_t) 能够逼近 TD目标 (TD Target),即 r_t + γV(s_t+1)。因此,Critic 的损失函数定义为两者之间的均方误差:

L ( ω ) = 1 2 ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) 2 \mathcal{L}(\omega)=\frac{1}{2}(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))^2 L(ω)=21(r+γVω(st+1)Vω(st))2

当我们对这个损失函数求梯度以更新 Critic 的网络参数 w 时,有一个非常关键的点:

在TD学习中,目标值 r_t + γV(s_t+1) 被视为一个固定的“标签”(Target),不参与反向传播。因此,梯度只对当前状态的值函数 V(s_t) 求导。

Critic 价值网络表示为 V w V_w Vw,参数为 w w w。价值函数的梯度为:
∇ ω L ( ω ) = − ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) ∇ ω V ω ( s t ) \nabla_\omega\mathcal{L}(\omega)=-(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))\nabla_\omega V_\omega(s_t) ωL(ω)=

相关文章:

  • ubuntu24.04+5090显卡驱动安装踩坑
  • Unity2D 街机风太空射击游戏 学习记录 #12QFramework引入
  • Java 中如何判断一个字符串是否代表一个数值(包括整数、浮点数等)?
  • AI工具在学术写作中的伦理边界与诚信规范的平衡
  • webpack+vite前端构建工具 -6从loader本质看各种语言处理 7webpack处理html
  • RN(React Native)技术应用中常出现的错误及解决办法
  • 《HTTP权威指南》 第11-12章 客户端识别与cookie和基本认证机制
  • Spring Boot 整合 Swagger3 如何生成接口文档?
  • 爬虫入门练习(文字数据的爬取)
  • Typecho博客3D彩色标签云插件(Handsome主题优化版)
  • 编译器优化
  • 445场周赛
  • DeepSeek技术解析:开源大模型的创新突围之路
  • 在esp-idf中发现找不到头文件
  • linux编译安装nginx
  • 药房智慧化升级:最优成本条件下开启地市级医院智慧医疗新变革
  • 【weaviate】分布式数据写入之LSM树深度解析:读写放大的权衡
  • 【力扣 中等 C】983. 最低票价
  • (LeetCode 面试经典 150 题 ) 189. 轮转数组(字符串、双指针)
  • [linux] Ubuntu 24软件下载和安装汇总(自用)
  • 猎趣网站/外贸订单怎样去寻找
  • 生产网线需要什么设备/搜索引擎优化叫什么
  • 虚拟主机网站建设步骤?/seo网站优化收藏
  • 哪个软件可以做明星视频网站/如何制作一个公司网站
  • 神马网站可以做兼职/网络推广代理
  • 什么是网络营销网络营销的内容有哪些/长沙seo推广公司