当前位置: 首页 > news >正文

大模型RLHF中PPO强化学习代码学习笔记(二)

大模型RLHF中PPO强化学习代码学习笔记(二)

      • 一、符号约定
      • 二、逐行公式解析
        • **1. 生成回复:模拟策略与环境交互**
        • **2. 拼接数据:构建状态-动作对**
        • **3. 奖励模型打分:初步评价动作质量**
        • **4. 策略与价值预测(当前模型)**
        • **5. 策略与价值预测(参考模型)**
        • **6. 计算KL散度:惩罚策略突变**
        • **7. 最终奖励:结合KL惩罚**
        • **8. 优势函数:衡量策略改进空间**
        • **9. 回报计算:价值网络的学习目标**
        • **10. 多轮优化循环:稳定更新策略与价值**
        • **11. 重新计算策略与价值(更新后模型)**
        • **12. 价值损失:训练价值网络更准确**
        • **13. 策略比例:衡量新旧策略差异**
        • **14. PPO策略损失:稳定更新策略**
        • **15. 总损失:平衡策略与价值优化**
        • **16-18. 梯度更新:优化模型参数**
      • 三、核心逻辑总结

PPO伪代码

为了更详细的描述PPO伪代码中详细的数学过程,又进行了如下整理,图片来自 PPO算法讲解 。以下是PPO算法伪代码中 每一步的具体公式 、符号定义及作用解析,结合强化学习和大模型训练的场景:

一、符号约定

  • s s s:状态(如batch_prompt,即模型的输入prompt)
  • a a a:动作(如batch_response,即模型生成的回复)
  • π θ \pi_\theta πθ策略网络(active_model),输出动作分布和价值预测
  • π old \pi_{\text{old}} πold参考策略(ref_model,通常是更新前的 π θ \pi_\theta πθ,用于稳定训练)
  • V θ V_\theta Vθ价值网络(与 π θ \pi_\theta πθ共享参数,预测状态价值)
  • r ( s , a ) r(s,a) r(s,a):奖励模型打分(如batch_scores,衡量 s → a s \to a sa的质量)
  • KL ( p ∥ q ) \text{KL}(p\|q) KL(pq):KL散度,衡量两个分布 p p p q q q的差异

二、逐行公式解析

1. 生成回复:模拟策略与环境交互
batch_response = active_model.generate(batch_prompt)  
  • 公式 a ∼ π θ ( ⋅ ∣ s ) a \sim \pi_\theta(\cdot|s) aπθ(s) s = batch_prompt s=\text{batch\_prompt} s=batch_prompt,模型根据状态 s s s采样生成动作 a a a
  • 作用:获取模型的“行为样本”(prompt→response对),模拟强化学习中“智能体与环境交互”的过程。
2. 拼接数据:构建状态-动作对
batch_data = concat(batch_prompt, batch_response)  
  • 公式 batch_data = ( s , a ) \text{batch\_data} = (s, a) batch_data=(s,a)(拼接状态 s s s和动作 a a a,作为后续模型输入)
  • 作用:将“输入-输出”封装为训练单元,供策略、价值、奖励模型处理。
3. 奖励模型打分:初步评价动作质量
batch_scores = reward_model(batch_data)  
  • 公式 r = r ( s , a ) r = r(s, a) r=r(s,a)(奖励模型对状态-动作对 ( s , a ) (s,a) (s,a)打分,如内容质量、用户偏好)
  • 作用:量化动作 a a a的“内在价值”,是强化学习奖励的核心来源。
4. 策略与价值预测(当前模型)
batch_all_probs, batch_probs, batch_all_values = active_model.forward_pass(batch_data)  
  • 公式
    • 策略分布: π θ ( ⋅ ∣ s ) = batch_all_probs \pi_\theta(\cdot|s) = \text{batch\_all\_probs} πθ(s)=batch_all_probs(状态 s s s下,所有可能动作的概率分布)
    • 动作概率: π θ ( a ∣ s ) = batch_probs \pi_\theta(a|s) = \text{batch\_probs} πθ(as)=batch_probs(生成当前动作 a a a的概率)
    • 价值预测: V θ ( s ) = batch_all_values V_\theta(s) = \text{batch\_all\_values} Vθ(s)=batch_all_values(预测状态 s s s的期望回报)
  • 作用:获取策略的概率分布(用于后续KL计算和策略更新)和价值预测(用于优势计算)。
5. 策略与价值预测(参考模型)
ref_all_probs, ref_probs, ref_all_values = ref_model.forward_pass(batch_data)  
  • 公式
    • 参考策略分布: π old ( ⋅ ∣ s ) = ref_all_probs \pi_{\text{old}}(\cdot|s) = \text{ref\_all\_probs} πold(s)=ref_all_probs
    • 参考动作概率: π old ( a ∣ s ) = ref_probs \pi_{\text{old}}(a|s) = \text{ref\_probs} πold(as)=ref_probs
  • 作用:提供“基准分布”,通过KL散度惩罚策略的突变(如生成风格突然变化)。
6. 计算KL散度:惩罚策略突变
kls = compute_KL(batch_all_probs, ref_all_probs)  
  • 公式
    KL ( π θ ∥ π old ) = ∑ x π θ ( x ∣ s ) log ⁡ π θ ( x ∣ s ) π old ( x ∣ s ) \text{KL}(\pi_\theta \parallel \pi_{\text{old}}) = \sum_{x} \pi_\theta(x|s) \log\frac{\pi_\theta(x|s)}{\pi_{\text{old}}(x|s)} KL(πθπold)=xπθ(xs)logπold(xs)πθ(xs)
    (对所有可能的动作 x x x,计算分布差异)
  • 作用:避免策略更新过于剧烈(如生成结果突然从“正式”变“口语化”),保证训练稳定。
7. 最终奖励:结合KL惩罚
rewards = compute_rewards(batch_scores, kls)  
  • 公式 r final = r − α ⋅ kls r_{\text{final}} = r - \alpha \cdot \text{kls} rfinal=rαkls α \alpha α为KL惩罚系数,平衡“高奖励”和“策略稳定”)
  • 作用:修正原始奖励,让模型既追求高得分,又不“乱变风格”。
8. 优势函数:衡量策略改进空间
advantages = compute_advantages(batch_all_values, rewards)  
  • 公式 Adv = r final − V θ ( s ) \text{Adv} = r_{\text{final}} - V_\theta(s) Adv=rfinalVθ(s)(实际回报与价值预测的差,回答“当前策略比预期好多少?”)
  • 作用:指导策略更新——若 Adv > 0 \text{Adv}>0 Adv>0,强化当前动作;若 Adv < 0 \text{Adv}<0 Adv<0,弱化当前动作。
9. 回报计算:价值网络的学习目标
returns = advantages + batch_all_values  
  • 公式 Returns = V θ ( s ) + Adv = r final \text{Returns} = V_\theta(s) + \text{Adv} = r_{\text{final}} Returns=Vθ(s)+Adv=rfinal(等价于实际奖励,是价值网络的“理想预测值”)
  • 作用:训练价值网络更准确预测状态 s s s的期望回报。
10. 多轮优化循环:稳定更新策略与价值
for i in range(epoch):  
  • 逻辑:在同一批数据上迭代训练,让模型充分学习,同时通过**策略比例(ratio)**限制更新幅度(避免过拟合)。
11. 重新计算策略与价值(更新后模型)
active_all_probs, active_probs, active_all_values = active_model.forward_pass(batch_data)  
  • 公式
    • 新策略分布: π θ ′ ( ⋅ ∣ s ) = active_all_probs \pi_\theta'(\cdot|s) = \text{active\_all\_probs} πθ(s)=active_all_probs(参数更新后,策略的动作分布)
    • 新动作概率: π θ ′ ( a ∣ s ) = active_probs \pi_\theta'(a|s) = \text{active\_probs} πθ(as)=active_probs(更新后生成 a a a的概率)
    • 新价值预测: V θ ′ ( s ) = active_all_values V_\theta'(s) = \text{active\_all\_values} Vθ(s)=active_all_values(更新后价值网络的预测)
  • 作用:获取更新后的策略和价值预测,用于计算损失。
12. 价值损失:训练价值网络更准确
loss_state_value = torch.mean((returns - active_all_values) ** 2)  
  • 公式
    L value = 1 N ∑ i = 1 N ( Returns i − V θ ′ ( s i ) ) 2 \mathcal{L}_{\text{value}} = \frac{1}{N} \sum_{i=1}^N \left( \text{Returns}_i - V_\theta'(s_i) \right)^2 Lvalue=N1i=1N(ReturnsiVθ(si))2
    (均方误差,最小化预测值与回报的差异)
  • 作用:让价值网络更精准地预估“做这个动作能得多少分”。
13. 策略比例:衡量新旧策略差异
ratio = active_probs / batch_probs  
  • 公式 ρ = π θ ′ ( a ∣ s ) π θ ( a ∣ s ) \rho = \frac{\pi_\theta'(a|s)}{\pi_\theta(a|s)} ρ=πθ(as)πθ(as)(新策略生成 a a a的概率 ÷ 旧策略生成 a a a的概率)
  • 作用:判断策略变化幅度—— ρ > 1 \rho>1 ρ>1表示新策略更倾向于 a a a ρ < 1 \rho<1 ρ<1则相反。
14. PPO策略损失:稳定更新策略
loss_ppo = torch.mean(-advantages * ratio)  
  • 公式(简化版,实际含clip):
    L ppo = − 1 N ∑ i = 1 N Adv i ⋅ ρ i \mathcal{L}_{\text{ppo}} = -\frac{1}{N} \sum_{i=1}^N \text{Adv}_i \cdot \rho_i Lppo=N1i=1NAdviρi
    (取负号,将“最大化优势×比例”转为梯度下降的“最小化问题”;实际工程会用clip(\rho, 1-\epsilon, 1+\epsilon)限制 ρ \rho ρ,避免更新过猛
  • 作用
    • Adv > 0 \text{Adv}>0 Adv>0(动作比预期好):鼓励 ρ \rho ρ大(新策略更倾向于 a a a)。
    • Adv < 0 \text{Adv}<0 Adv<0(动作比预期差):鼓励 ρ \rho ρ小(新策略更规避 a a a)。
15. 总损失:平衡策略与价值优化
loss = loss_ppo + value_loss_rate * loss_state_value  
  • 公式 L total = L ppo + λ ⋅ L value \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{ppo}} + \lambda \cdot \mathcal{L}_{\text{value}} Ltotal=Lppo+λLvalue λ \lambda λvalue_loss_rate,平衡两者权重)
  • 作用:同时优化策略网络(生成更优回复)和价值网络(预测更准)。
16-18. 梯度更新:优化模型参数
loss.backward()    # 计算梯度:∇θℒ_total  
optimizer_step()   # 更新参数:θ ← θ - η·∇θℒ_total(η为学习率)  
optimizer_zero_grad()  # 清空梯度:∇θ ← 0  
  • 作用:通过反向传播和梯度下降,逐步优化模型参数,让策略更优、价值预测更准。

三、核心逻辑总结

PPO的训练像一场 “策略微调游戏”

  1. 采样试错:模型生成回复,拿奖励(加KL惩罚)。
  2. 自我评估:预测自己能得多少分(价值网络),对比实际得分算“优势”。
  3. 稳定改进:通过“新旧策略比例”限制更新幅度,避免“画风突变”,同时训练价值网络更准。
  4. 反复练习:同一批数据练多轮,逐步优化,最终让模型“又会写、又会估分”。

(实际工程中,还会用clip进一步限制策略变化,让训练更稳定!)

http://www.dtcms.com/a/268335.html

相关文章:

  • 回环检测 Scan Contex
  • DolphinScheduler 3.2.0 后端开发环境搭建指南
  • XML 笔记
  • 极简的神经网络反向传播例子
  • 用户中心Vue3项目开发2.0
  • Docker 容器编排原理与使用详解
  • 125.【C语言】数据结构之归并排序递归解法
  • FileZilla二次开发实战指南:C++架构解析与界面功能扩展
  • 操作系统王道考研习题
  • 76、覆盖最小子串
  • 【STM32】通用定时器PWM
  • 漫漫数学之旅046
  • ThreadLocal的挑战与未来:在响应式编程与虚拟线程中的演变
  • ARMv8 创建3级页表示例
  • 【嵌入式电机控制#11】PID控制入门:对比例算法应用的深度理解
  • Python数据容器-str
  • ch03 部分题目思路
  • 数据驱动实时市场动态监测:让商业决策跑赢时间
  • 端到端矢量化地图构建与规划
  • Solidity——什么是selfdestruct
  • Java线程池知识点
  • RAG技术新格局:知识图谱赋能智能检索与生成
  • 【机器学习笔记Ⅰ】2 线性回归模型
  • 图灵完备之路(数电学习三分钟)----逻辑与计算架构
  • 在phpstudy环境下配置搭建XDEBUG配合PHPSTORM的调试环境
  • ESMFold 安装教程
  • 手动使用 Docker 启动 MinIO 分布式集群(推荐生产环境)
  • list和list中的注意事项
  • 三位一体:Ovis-U1如何以30亿参数重构多模态AI格局?
  • K8s系列之:Kubernetes 的 RBAC (Role-Based Access Control)