当前位置: 首页 > news >正文

【Agentic RL 专题】三、深入浅出强化学习算法 TRPO 和PPO

在这里插入图片描述

🧔 这里是九年义务漏网鲨鱼,研究生在读,主要研究方向是人脸伪造检测,长期致力于研究多模态大模型技术;国家奖学金获得者,国家级大创项目一项,发明专利一篇,多篇论文在投,蓝桥杯国家级奖项、妈妈杯一等奖。
✍ 博客主要内容为大模型技术的学习以及相关面经,本人已得到B站、百度、唯品会等多段多模态大模型的实习offer,为了能够紧跟前沿知识,决定写一个“从零学习 RL”主题的专栏。这个专栏将记录我个人的主观学习过程,因此会存在错误,若有出错,欢迎大家在评论区帮助我指出。除此之外,博客内容也会分享一些我在本科期间的一些知识以及项目经验。
🌎 Github仓库地址:Baby Awesome Reinforcement Learning for LLMs and Agentic AI
📩 有兴趣合作的研究者可以联系我:yirongzzz@163.com

深入浅出强化学习算法 TRPO 和PPO

文章目录

  • 深入浅出强化学习算法 TRPO 和PPO
    • 🦈 前言
    • 一、 Policy Optimization
    • 二、信任区域的尝试:TRPO
    • 三、PPO:平衡性能与简易性的王者
    • 四、TRPO vs PPO
    • 五、PPO代码实现
    • 六、专栏面试问题角
    • 七、总结与展望
    • ✍ 参考链接

🦈 前言

在了解完Memory专题后,我们可以了解到智能体中的记忆功能,是需要通过强化学习技术进行动态的环境交互,通过奖励信号学习何时写入、更新、删除或保留记忆,从而更好地支持下游决策和任务完成。例如,在 Memory-R1 框架中,Memory Manager 使用PPO或GRPO来动态调整记忆操作,而 Answer Agent 则通过 Memory Distillation 对 RAG 检索到的信息进行推理和回答。为了更进一步的深入,本专题将从经典的TRPO, PPO强化学习算法开始学习。

一、 Policy Optimization

在强化学习中,主要可以分为两种学习方法,一种是 Value-based 的方法 (Q-learning),此类方法是通过学习一个值函数来间接决定动作的。 Value-based 的方法在许多简单的问题上都失败了,并且很难理解,并不是Agentic RL中所采用的强化学习方式。 另一种是Policy-based方法 (TRPO, PPO),是直接学习一个策略函数 π(a∣s;θ)π(a|s; θ)π(as;θ),表示在参数为θ\thetaθ, 状态为s的情况下选择行动a的概率。此类方法时通过优化参数来学习。但由于\theta涉及到了环境的随机性,梯度无法直接计算,因此还需要找到一个estimator来计算这个梯度:
g=Et​[∇θlogπθ​(at∣st)At]g = E^t​[∇_θ logπ_θ​(a_t∣s_t)A_t] g=Et[θlogπθ(atst)At]
在这里插入图片描述

🧠 AtA_tAt 是什么,为什么要估计梯度时要乘上他?

因为如果是简单的求梯度,只是优化了增加当前状态选择该动作的概率。动作是有好有坏的,需要告诉他参数应该往哪边偏,因此需要有一个优势函数AtA_tAt来辅助参数的优化,告诉模型当前的动作是一个好动作还是一个坏的动作,好动作我们希望奖励,坏动作则对应惩罚:
At=Q(st,at)A_t=Q(s_t,a_t)At=Q(st,at),−V(st)≈Rt−V(st), −V(s_t)≈R_t−V(s_t),V(st)RtV(st)

🧠 为什么不能直接用 RtR_tRt (未来的总回报) 替换 AtA_tAt

可以,但效果不好。log⁡πθ(at∣st)\log\pi_\theta(a_t|s_t)logπθ(atst) 项的意义是“增加(或减少)在状态 sts_tst 时选择 ata_tat 的概率”。

  1. 如果我们直接乘以 RtR_tRt(总回报),RtR_tRt 几乎总是正的(比如游戏得分)。这意味着模型会“增加所有动作的概率”,只是“好动作”增加得多,“坏动作”增加得少,收敛很慢。
  2. 我们需要的是一个“基线 (Baseline)”。动作的好坏是相对的。

🧠 RL的采样成本高,为什么不能对同一批数据进行反复更新?

这是由于在强化学习中,我们首先采用一个旧策略去得到一批数据,这批数据反应的是旧策略的状态-动作数据分布,如果反复更新,会导致新策略已经和旧策略差别很大了,用它去更新新策略,这个估计可能严重偏离真实梯度,从而导致会导致训练崩塌。

二、信任区域的尝试:TRPO

为了解决 On-Policy 采样效率低的问题,TRPO (Trust Region Policy Optimization) 在 2015 年被提出。TRPO属于早期的Policy-based方法,TRPO 的核心思想是:我们可以用旧数据(πold\pi_{old}πold)做多步更新,但前提是必须给新策略(πθ\pi_\thetaπθ)一个“信任区域 (Trust Region)”。
TRPO 通过 trust region 明确约束 KL 散度,同样保证梯度估计在安全范围内,通过约束新旧策略的KL散度来避免策略崩塌,通过优化目标:
在这里插入图片描述

🧠 但是该方法无法适用于复杂的架构,为什么不适用?

Dropout层会使得梯度带有高方差,从而破坏了KL散度约束的精确计算。TRPO 的 KL 散度约束是基于“单一策略参数”的,而共享参数会让 KL 计算更复杂、不再严格符合理论前提。

在这里插入图片描述

三、PPO:平衡性能与简易性的王者

PPO 在 2017 出来后,成为最常用的 policy optimization 算法,因为它又稳又简单,并且在instructGPT中的RLHF微调,也是采用了PPO来优化模型的。PPO与TRPO 类似,PPO提出了Clipped Surrogate Objective限制新旧策略差距, 允许在同一批数据上做多步更新。相较于KL散度约束,PPO采用了硬约束的方式:

在这里插入图片描述

我们来拆解这个公式:

  • rt(θ)r_t(\theta)rt(θ):这是新旧策略的概率比 (Probability Ratio),其中:rt(θ)=πθ(at∣st)πθold(at∣st)r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}rt(θ)=πθold(atst)πθ(atst)
    • rt(θ)>1r_t(\theta) > 1rt(θ)>1:新策略更倾向于做 a_t。
    • rt(θ)r_t(\theta)rt(θ) < 1:新策略更不倾向于做 a_t。
  • clip(rt,1−ϵ,1+ϵ)\text{clip}(r_t, 1-\epsilon, 1+\epsilon)clip(rt,1ϵ,1+ϵ):这就是 PPO 的灵魂。ϵ\epsilonϵ 是一个超参数(通常为 0.2)。它强行把 rtr_trt 裁剪到 [1−ϵ,1+ϵ][1-\epsilon, 1+\epsilon][1ϵ,1+ϵ](即 [0.8,1.2][0.8, 1.2][0.8,1.2])的区间内。

在这里插入图片描述

为了防止训练不稳定,除了LCLIPL^{CLIP}LCLIP损失外,PPO还引入了value function lossentropy bonus:

  • value function loss : 既要更新策略,又要保证 value function 能准确预测状态价值;
  • entropy bonus: 如果策略损失收敛过快,会导致数据的多样性下降,只能学习到局部最优;

除此之外,PPO中也可以通过KL散度进行约束: maximizeEt[πθπoldAt−βKL[πold,πθ]]\text{maximize} E^t[\frac{π_θ}{π_{old}}A^t−βKL[π_{old},π_θ]]maximizeEt[πoldπθAtβKL[πold,πθ]],但是不同问题需要不同β,同一问题中,随着训练进程,最优 β 也会变化。因此,还需要通过自适应的β来调整。

四、TRPO vs PPO

特性TRPOPPO
控制策略更新hard constraint:KL divergence ≤ δsoft constraint:clipping 或 KL penalty
数学形式非线性约束,需要二阶信息(Fisher矩阵)一阶可直接在 surrogate objective 上优化
优化方法共轭梯度 + 二阶近似标准 SGD / Adam,一阶优化即可

在这里插入图片描述

五、PPO代码实现

理论部分掌握了,就需要开始实战演练了。这部分代码参考了 nikhilbarhate99/PPO-PyTorch (GitHub),这是一个非常清晰的实现。我们将实现一个 Actor-Critic (A2C) 架构的 PPO。

  • 定义 Actor 和 critic 网络,其中 act()函数主要用于记录旧决策的action以及对应的对数概率, evaluate()函数主要用于计算新决策下旧action的对数概率。因此首先是需要通过旧决策前馈得到旧决策下的action,action_logprob以及state_val
class ActorCritic(nn.Module):def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):super(ActorCritic, self).__init__()self.has_continuous_action_space = has_continuous_action_spaceif has_continuous_action_space:self.action_dim = action_dimself.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)# actorif has_continuous_action_space :self.actor = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, action_dim),nn.Tanh())else:self.actor = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, action_dim),nn.Softmax(dim=-1))# criticself.critic = nn.Sequential(nn.Linear(state_dim, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, 1))def set_action_std(self, new_action_std):if self.has_continuous_action_space:self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)else:print("--------------------------------------------------------------------------------------------")print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")print("--------------------------------------------------------------------------------------------")def forward(self):raise NotImplementedErrordef act(self, state):if self.has_continuous_action_space:action_mean = self.actor(state)cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)dist = MultivariateNormal(action_mean, cov_mat)else:action_probs = self.actor(state)dist = Categorical(action_probs)action = dist.sample()action_logprob = dist.log_prob(action)state_val = self.critic(state)return action.detach(), action_logprob.detach(), state_val.detach()def evaluate(self, state, action):if self.has_continuous_action_space:action_mean = self.actor(state)action_var = self.action_var.expand_as(action_mean)cov_mat = torch.diag_embed(action_var).to(device)dist = MultivariateNormal(action_mean, cov_mat)# For Single Action Environments.if self.action_dim == 1:action = action.reshape(-1, self.action_dim)else:action_probs = self.actor(state)dist = Categorical(action_probs)action_logprobs = dist.log_prob(action)dist_entropy = dist.entropy()state_values = self.critic(state)return action_logprobs, state_values, dist_entropy
  • PPO代码
class PPO:def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std_init=0.6):self.has_continuous_action_space = has_continuous_action_spaceif has_continuous_action_space:self.action_std = action_std_initself.gamma = gammaself.eps_clip = eps_clipself.K_epochs = K_epochsself.buffer = RolloutBuffer()self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)self.optimizer = torch.optim.Adam([{'params': self.policy.actor.parameters(), 'lr': lr_actor},{'params': self.policy.critic.parameters(), 'lr': lr_critic}])self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)self.policy_old.load_state_dict(self.policy.state_dict())self.MseLoss = nn.MSELoss()def set_action_std(self, new_action_std):if self.has_continuous_action_space:self.action_std = new_action_stdself.policy.set_action_std(new_action_std)self.policy_old.set_action_std(new_action_std)else:print("--------------------------------------------------------------------------------------------")print("WARNING : Calling PPO::set_action_std() on discrete action space policy")print("--------------------------------------------------------------------------------------------")def decay_action_std(self, action_std_decay_rate, min_action_std):print("--------------------------------------------------------------------------------------------")if self.has_continuous_action_space:self.action_std = self.action_std - action_std_decay_rateself.action_std = round(self.action_std, 4)if (self.action_std <= min_action_std):self.action_std = min_action_stdprint("setting actor output action_std to min_action_std : ", self.action_std)else:print("setting actor output action_std to : ", self.action_std)self.set_action_std(self.action_std)else:print("WARNING : Calling PPO::decay_action_std() on discrete action space policy")print("--------------------------------------------------------------------------------------------")def select_action(self, state):if self.has_continuous_action_space:with torch.no_grad():state = torch.FloatTensor(state).to(device)action, action_logprob, state_val = self.policy_old.act(state)self.buffer.states.append(state)self.buffer.actions.append(action)self.buffer.logprobs.append(action_logprob)self.buffer.state_values.append(state_val)return action.detach().cpu().numpy().flatten()else:with torch.no_grad():state = torch.FloatTensor(state).to(device)action, action_logprob, state_val = self.policy_old.act(state)self.buffer.states.append(state)self.buffer.actions.append(action)self.buffer.logprobs.append(action_logprob)self.buffer.state_values.append(state_val)return action.item()def update(self):# Monte Carlo estimate of returnsrewards = []discounted_reward = 0for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):if is_terminal:discounted_reward = 0discounted_reward = reward + (self.gamma * discounted_reward)rewards.insert(0, discounted_reward)# Normalizing the rewardsrewards = torch.tensor(rewards, dtype=torch.float32).to(device)rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)# convert list to tensorold_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)# calculate advantagesadvantages = rewards.detach() - old_state_values.detach()# Optimize policy for K epochsfor _ in range(self.K_epochs):# Evaluating old actions and valueslogprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)# match state_values tensor dimensions with rewards tensorstate_values = torch.squeeze(state_values)# Finding the ratio (pi_theta / pi_theta__old)ratios = torch.exp(logprobs - old_logprobs.detach())# Finding Surrogate Loss  surr1 = ratios * advantagessurr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages# final loss of clipped objective PPOloss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy# take gradient stepself.optimizer.zero_grad()loss.mean().backward()self.optimizer.step()# Copy new weights into old policyself.policy_old.load_state_dict(self.policy.state_dict())# clear bufferself.buffer.clear()def save(self, checkpoint_path):torch.save(self.policy_old.state_dict(), checkpoint_path)def load(self, checkpoint_path):self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))

六、专栏面试问题角

在学习 Agentic RL 的过程中,PPO 是面试中必考的基础。

❓ PPO 是 On-Policy 还是 Off-Policy 算法?
PPO 本质上是 On-Policy 算法,因为它使用当前策略(或上一个旧策略 \pi_{old})采样数据。但是,它通过“概率比”和“裁剪”技术,引入了 Off-Policy 的思想,允许使用同一批数据进行多步(K-Epochs)更新,极大地提高了采样效率,使其成为 On-Policy 算法中最高效的代表。

❓ PPO 中的 Actor-Critic 架构,Actor 和 Critic 必须共享参数吗?
不是必须的,但推荐共享。

  • 共享参数(如代码实现中,共享底层的 nn.Linear(64, 64)):Critic 可以“偷听”到 Actor 的特征表示,使得 V(s_t) 的估算更准确,训练更稳定。

  • 不共享参数(两个完全独立的网络):实现更简单,但 Critic 可能需要更多数据才能学好 V(s_t)。TRPO 算法就(因为其二阶计算的限制)难以处理参数共享。

❓代码中 act() 和 evaluate() 两个函数的根本区别是什么?

  • act(state):由 policy_old 调用。它的目的是**“与环境交互并采样”。它输出一个具体的动作 action** 和这个动作的 logprob,用于存入 Buffer。

  • evaluate(state, action):由 policy (新策略) 调用。它的目的是**“计算损失”。它接收 Buffer 中“旧的” stateaction,然后计算“新策略”下,做出这个“旧动作”**的对数概率 logprobs 和熵 dist_entropy,用于计算概率比 ratios

❓ 如果我把 PPO 的 K_epochs 设得非常大(比如 100),会发生什么?

这会破坏 PPO 的假设。PPO 的 LCLIPL^{CLIP}LCLIP 约束是基于 πθ和πold\pi_\theta 和\pi_{old}πθπold 差别不大的前提下才近似有效的。
如果 K 太大,πθ\pi_\thetaπθ 在这 100 个 epoch 中会和 πold\pi_{old}πold 差别非常大LCLIPL^{CLIP}LCLIP 提供的“信任区域”会失效,梯度估计将不再准确,数据分布会“漂移 (Drift)”,训练很可能会崩溃

七、总结与展望

在这一章中,我们从 Agentic RL 的需求出发,深入了 Policy-based 方法。我们理解了 PPO 是为了解决传统 Policy Gradient 采样效率低和 TRPO 计算复杂度高而诞生的。

  • PPO 原理:通过概率比 rt(θ)r_t(\theta)rt(θ) 和裁剪 clip\text{clip}clip(…),在一阶优化中实现了“信任区域”,允许 On-Policy 算法进行多步更新。
  • PPO 实战:我们通过 PyTorch 代码,彻底厘清了 LCLIP、LVF和S^{CLIP}、 L^{VF} 和 SCLIPLVFS(熵)这“三位一体”的损失函数是如何在 update() 函数中实现的。

下一步展望:我们有了引擎 (PPO),但这个引擎需要“燃料”(Reward)。在 LLM 中,Reward 从哪里来?这就是我们专栏下一篇要深入的主题:RLHF 与 RM (Reward Model),即 InstructGPT 如何使用 PPO 来“对齐”人类偏好。

✍ 参考链接

  1. PPO 原始论文 (必读):Schulman, J., et al. (2017). Proximal Policy Optimization Algorithms. arXiv:1707.06347
  2. TRPO 原始论文 (选读):Schulman, J., et al. (2015). Trust Region Policy Optimization. arXiv:1502.05477
  3. 代码实现参考 (GitHub):nikhilbarhate99/PPO-PyTorch
http://www.dtcms.com/a/561035.html

相关文章:

  • 中国最好的建站公司毕业设计模板
  • 《算法通关指南:数据结构和算法篇 --- 栈相关算法题》--- 1. 【模板】栈,2.有效的括号
  • 高效管理搜索历史:Vue持久化实践
  • html网站架设目录和文章wordpress
  • Rust 编程语言基础知识全面介绍
  • 洛龙区网站制作建设费用做网站一般用什么语言
  • 计算机网络---基础诊断ping
  • 13.2.2.Nginx
  • java后端学习经验分享(大三进大厂版)
  • 好用的镜像源
  • 做网站的经验有什么好的加盟店项目
  • linux-shell-基础与变量和运算符-1
  • 论文解读:Sleeping with One Eye Open: Fast, Sustainable Storage with Sandman
  • 手机客户端网站建设腾讯云服务器免费领取
  • Gorm(十三)主从表的判断
  • 从零开始的云原生之旅(十):HPA 完全指南:从原理到实践
  • 注册网站费属于什么费用模板公司
  • MYSQL-多种方法安装部署
  • 做网站要学哪些代码上海资本公司排名
  • 认识多线程:单例模式
  • 深入解析 HarmonyOS 中 NavDestination 导航目标页的生命周期
  • 3、webgl 基本概念 + 绘制线段 + 绘制三角形
  • 【LeetCode热题100(58/100)】单词搜索
  • 旅行社网站模版网页设计六安模板
  • 求解器驱动智能决策新纪元
  • 简单网站制作成品广东省广州市佛山市
  • 使用 TransGPTex 将 LaTeX 英文论文翻译成中文:完整实战教程
  • APIJSON:用JSON自动生成API,告别手写CRUD!【.NET 8 集成案例,也支持JAVA】
  • 网络版本计算器
  • 视频容器(如 MP4)的作用:组织与同步