当前位置: 首页 > news >正文

从TRPO到GRPO

Deepseek最近在 AI 社区引起了不小的轰动,这要归功于它以相对较低的成本实现了令人印象深刻的性能。我认为这是深入了解大型语言模型 (LLM) 训练方式的绝佳机会。在本文中,我们将重点介绍强化学习 (RL) 方面的内容:我们将介绍 TRPO、PPO 以及最近的 GRPO(别担心,我很快就会解释所有这些术语!) 

我的目标是通过尽量减少数学知识来让这篇文章相对容易阅读和理解,这样你不需要有深厚的强化学习背景就可以理解。不过,我假设你对机器学习、深度学习有一定的了解,并且对 LLM 的工作原理有基本的了解。

LLM 训练的 3 个步骤(更多了解参考上一篇博客)

在深入研究 RL 细节之前,让我们简要回顾一下训练大型语言模型的三个主要阶段:

  • 预训练:在海量数据集上训练模型,根据前面的标记预测序列中的下一个标记。
  • 监督微调 (SFT):然后根据更有针对性的数据对模型进行微调并与特定指令保持一致。
  • 强化学习(通常称为RLHF,即基于人类反馈的强化学习):这是本文的重点。主要目标是通过让模型直接从反馈中学习,进一步完善响应与人类偏好的一致性。因此,我的个人理解,强化学习只是提升模型回答准确性的一种微调手段,不要把它想象的太厉害。

强化学习基础

在深入探讨之前,让我们简单回顾一下强化学习背后的核心思想。

从高层次上理解强化学习非常简单:代理与环境交互。代理处于环境中的特定状态,并可以采取行动过渡到其他状态。每个动作都会从环境中获得奖励:这就是环境提供反馈以指导代理未来行动的方式。 

考虑以下示例:机器人(代理)导航(并尝试退出)迷宫(环境)。

  • 状态是环境的当前情况(机器人在迷宫中的位置)。
  • 机器人可以采取不同的动作:例如,它可以向前移动、左转或右转。
  • 成功导航至出口会带来积极的奖励,而撞到墙壁或被困在迷宫中则会带来负面的奖励。

简单!现在,让我们类比一下 RL 在 LLM 中的使用方式。

LLM背景下的 RL

在 LLM 培训期间使用时,RL 由以下组件定义:

  • LLM 本身就是代理
  • 环境:LLM 外部的一切,包括用户提示、反馈系统和其他上下文信息。这基本上是 LLM 在训练期间与之交互的框架。
  • 动作:这些是对模型查询的响应。更具体地说:这些是LLM 决定为响应查询而生成的标记。
  • 状态:正在回答的当前查询以及 LLM 迄今为止生成的标记(即部分响应)。
  • 奖励:这有点棘手:与上面的迷宫示例不同,通常没有二元奖励。在 LLM 的背景下,奖励通常来自单独的奖励模型,该模型为每个(查询,响应)对输出一个分数。该模型是根据人工注释的数据(因此称为“RLHF”)训练的,其中注释者对不同的响应进行排名。目标是让更高质量的响应获得更高的奖励。

注意:在某些情况下,奖励实际上可以变得更简单。例如,在 DeepSeekMath 中,可以使用基于规则的方法,因为数学响应往往更具确定性(正确或错误答案)

策略是我们现在需要的最后一个概念。在 RL 术语中,策略只是决定采取什么行动的策略。在 LLM 的情况下,策略会在每个步骤中输出可能标记的概率分布:简而言之,这就是模型用来采样下一个要生成的标记的内容。具体来说,策略由模型的参数(权重)决定。在 RL 训练期间,我们会调整这些参数,以便 LLM 更有可能生成“更好”的标记 - 即产生更高奖励分数的标记。

我们经常把策略写成:

其中a是动作(要生成的标记),s 是状态(迄今为止生成的查询和标记),θ(模型的参数)。即智能体在状态s下,选择动作a是不确定性的,而是服从某种概率分布。在参数θ下,状态s时选择动作a的条件概率。

找到最佳策略的想法是 RL 的重点!由于我们没有标记数据(就像我们在监督学习中那样),我们使用奖励来调整我们的策略以采取更好的行动。 (用 LLM 术语来说:我们调整 LLM 的参数以生成更好的标记。)

TRPO(信任区域策略优化)

与监督学习的类比

让我们快速回顾一下监督学习通常是如何工作的。您已经标记了数据并使用损失函数(如交叉熵)来衡量模型的预测与真实标签的接近程度。

然后,我们可以使用反向传播和梯度下降等算法来最小化我们的损失函数并更新模型的权重θ 。

回想一下,我们的策略也输出概率!从这个意义上讲,它类似于监督学习中模型的预测……我们倾向于写类似这样的内容

其中s是当前状态,a是可能的动作。

A(s, a)称为优势函数,用于衡量当前状态下所选动作与基线相比有多好。这非常类似于监督学习中的标签概念,但它源自奖励而不是显式标记。为了简化,我们可以将优势写为:

在实践中,基线是使用价值函数来计算的。这是强化学习中的一个常用术语,我稍后会解释。你现在需要知道的是,它衡量了如果我们继续遵循当前策略,我们将从状态s开始获得的预期奖励。

什么是 TRPO?

TRPO(信赖区域策略优化)建立在使用优势函数的思想之上,但增加了一个稳定性的关键因素:它限制了新策略在每个更新步骤中偏离旧策略的程度(类似于我们对批量梯度下降所做的那样)。

  • 它在当前策略和旧策略之间引入了一个 KL 散度项(将其视为相似性的度量):

  • 它还用新策略除以旧策略。这个比率乘以优势函数,让我们了解每次更新相对于旧策略的益处。

总而言之,TRPO 试图在KL 散度约束下最大化替代目标(涉及优势和策略比率)。

PPO(近端策略优化)

虽然 TRPO 是一项重大进步,但由于其计算密集型的梯度计算,它在实践中已不再被广泛使用,尤其是在训练 LLM 方面。

相反,PPO 现在是大多数 LLM 架构的首选方法,包括 ChatGPT、Gemini 等。

它实际上与 TRPO 非常相似,但PPO不是对 KL 散度施加硬约束,而是引入了“修剪替代目标”,隐式限制策略更新,并大大简化了优化过程。

以下是我们为调整模型参数而最大化的 PPO 目标函数的细分。

GRPO(组相对策略优化)

价值函数通常如何获得?

我们首先来详细谈谈我之前介绍的优势价值功能。

在典型设置(如 PPO)中,价值模型与策略一起训练。其目标是使用我们获得的奖励来预测我们采取的每个动作(模型生成的每个标记)的价值(请记住,该价值应代表预期的累积奖励)。

它在实践中的运作方式如下。以查询“2+2 等于多少?”为例。我们的模型输出“2+2 等于 4”,并针对该响应获得 0.8 的奖励。然后我们回溯并将折扣奖励归因于每个前缀:

  • “2+2 等于 4” 的值是 0.8
  • “2+2 is”(向后 1 个 token)的值为0.8γ
  • “2+2”(向后 2 个 token)的值为 0.8 γ²
  • ETC。

其中γ是折扣因子(例如 0.9)。然后我们使用这些前缀和相关值来训练价值模型。

重要提示:价值模型和奖励模型是两个不同的东西。奖励模型在强化学习过程之前进行训练,并使用(查询、响应)和人工排名对。价值模型与策略同时进行训练,旨在预测生成过程每一步的未来预期奖励。

GRPO 中的新功能

即使在实践中,奖励模型通常源自策略(仅训练“头部”),但我们最终仍需要维护许多模型并处理多个训练程序(策略、奖励、价值模型)。GRPO通过引入更有效的方法来简化这一过程。

还记得我之前说过的话吗?

在 PPO 中,我们决定使用我们的价值函数作为基线。GRPO 选择了其他东西:GRPO 的作用如下:具体来说,对于每个查询,GRPO 生成一组响应(大小为 G 的组)并使用它们的奖励来计算每个响应的优势作为z 分数

其中rᵢ是第 i个响应的奖励,μσ是该组奖励的平均值和标准差。

这自然就消除了对单独价值模型的需求。仔细想想,这个想法很有意义!它与我们之前介绍的价值函数一致,并且在某种意义上衡量了我们可以获得的“预期”奖励。此外,这种新方法非常适合我们的问题,因为 LLM 可以通过使用低温控制 token 生成的随机性)轻松生成多个非确定性输出。

这是 GRPO 背后的主要思想:摆脱价值模型。

最后,GRPO 将KL 散度项(确切地说,GRPO 使用 KL 散度的简单近似来进一步改进算法)直接添加到其目标中,将当前策略与参考策略(通常是后 SFT 模型)进行比较。

最终公式如下:

这就是 GRPO 的大部分内容!我希望这能让您清楚地了解该过程:它仍然依赖于与 TRPO 和 PPO 相同的基础思想,但引入了其他改进,使训练更高效、更快、更便宜——这是 DeepSeek成功的关键因素。

结论

强化学习已成为当今训练大型语言模型的基石,特别是通过 PPO 以及最近的 GRPO。每种方法都基于相同的 RL 基本原理(状态、动作、奖励和策略),但增加了自己的变化以平衡稳定性、效率和人机协调性:

• TRPO通过 KL 散度引入了严格的策略约束

• PPO通过简洁的目标缓解了这些限制

• GRPO采取了额外的措施,取消了价值模型要求,并使用了基于组的奖励规范化。当然,DeepSeek 还受益于其他创新,例如高质量数据和其他训练策略,但那是另一回事了!

我希望本文能让您更清楚地了解这些方法是如何联系和发展的。我相信强化学习将成为训练 LLM提高其性能 的主要焦点,超越预训练和 SFT 来推动未来的创新。

下一篇将介绍强化学习在多模态领域的应用。敬请期待!

相关文章:

  • scikit-surprise 智能推荐模块使用说明
  • 简单视图函数
  • (BFS)题解:P9425 [蓝桥杯 2023 国 B] AB 路线
  • 智能打印预约系统:微信小程序+SSM框架实战项目
  • 机器学习的一百个概念(6)最小最大缩放
  • Codeforces Round #1014 (Div. 2)
  • 三路排序算法
  • 本科lw指导
  • 鸿蒙NEXT开发Base64工具类(ArkTs)
  • 消息队列--RocketMQ
  • DeepSeek 助力 Vue3 开发:打造丝滑的表格(Table)之添加行拖拽排序功能示例13,TableView16_13 键盘辅助拖拽示例
  • 【算法】快速幂
  • 6内存泄露问题的讨论
  • MySQL其他客户端程序
  • 边缘计算:工业自动化的智能新引擎
  • 低成本文件共享解决方案:Go File本地Docker部署与外网访问全记录
  • 小米平板 4 Plus 玩机日志
  • Xvfb和VNC Server是什么
  • 使用自定义的RTTI属性对对象进行流操作
  • 7对象树(1)
  • 常州微信网站建设/360优化大师最新版下载
  • 工商注册服务平台/重庆百度seo
  • 唐山网络运营推广/简述seo
  • 网站开发毕业指导手册/网页设计期末作业模板
  • 湖北网站建设费用/网络销售就是忽悠人
  • 网站做打鱼游戏挣钱吗/seo网络推广机构