当前位置: 首页 > news >正文

【RL第三篇】REINFORCE Leave-One-Out(RLOO)算法(基于留一法的REINFORCE策略梯度算法)

一、前言

Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs

paper: https://arxiv.org/pdf/2402.14740

提出了基于REINFORCE的RLOO强化学习算法(REINFORCE Leave-One-Out)

二、RLOO

2.1 REINFORCE Baseline

from:https://people.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf

对于基线,无参数选择是利用过去价值的移动平均值作为baseline, 即Moving Average Baseline,可以是有窗口的滑动,也可以是无窗口的滑动,比如无窗口的滑动,会把训练历史中所有的回报做平均:

bMA=1S∑sR(τ)b_{\text{MA}} = \frac{1}{S} \sum_{s} R(\tau)bMA=S1sR(τ)

其中S为训练步骤,R为s步骤下的回报。

也可以是指数滑动平均baseline(指数滑动平均blog:https://zhuanlan.zhihu.com/p/670490330),即Exponential Moving Average Baseline:

bEMA=α⋅bEMA+(1−α)⋅Rˉ(τ) b_{\text{EMA}} = \alpha \cdot {b_{\text{EMA}}} + (1 - \alpha) \cdot \bar R(\tau) bEMA=αbEMA+(1α)Rˉ(τ)

其中,α\alphaα 是滑动平均的衰减因子,通常设为 0.9 或 0.99,Rˉ(τ)\bar R(\tau)Rˉ(τ) 是当前训练回合的平均回报(通过当前步骤的多次轨迹采样计算回报的平均值)。

2.2 REINFORCE Leave-One-Out

然而类似于取移动平均作为baseline,依旧方差很大。相比于REINFORCE,RLOO核心实现细节在于,它采用批次中其他样本的平均奖励来计算基线,而不是对批次中的所有奖励取平均值。

  • 针对response level的reward,action=轨迹,对于llm任务来讲,prompt为state,response为action,一个response一个reward,此时

R(τ)=r(s,a)R(\tau) = r(s, a)R(τ)=r(s,a)

  • 需要对于一个prompt采样生成多个相应(state->多个action/多个轨迹)

对于RLOO基线,给定 KKK 个采样轨迹或动作 a1,…,aKa_1, \ldots, a_Ka1,,aK,对于给定的提示 sss,每个提示的基线为:

b(s,ak)=1K−1∑i=1,i≠kKr(s,ai) b(s, a_k) = \frac{1}{K - 1} \sum_{\substack{i=1, i \neq k}}^{K} r(s, a_i) b(s,ak)=K11i=1,i=kKr(s,ai)

从而带来每个提示的优势:

A(s,ak)=r(s,ak)−b(s,ak) A(s, a_k) = r(s, a_k) - b(s, a_k) A(s,ak)=r(s,ak)b(s,ak)

等效地,这可以表示为:

A(s,ak)=KK−1(r(s,ak)−1K∑i=1Kr(s,ai)).(21) A(s, a_k) = \frac{K}{K - 1} \left( r(s, a_k) - \frac{1}{K} \sum_{i=1}^{K} r(s, a_i) \right). \tag{21} A(s,ak)=K1K(r(s,ak)K1i=1Kr(s,ai)).(21)

三、代码理解

import torch
local_batch_size = 3
rloo_k = 4rlhf_reward = torch.tensor([1, 2, 3, # first rlhf reward for three prompts2, 3, 4, # second rlhf reward for three prompts5, 6, 7, # third rlhf reward for three prompts8, 9, 10, # fourth rlhf reward for three prompts
]).float() # here we have 3 prompts which have 4 completions each# slow impl
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = torch.zeros_like(rlhf_reward)
for i in range(0, len(advantages), local_batch_size):other_response_rlhf_rewards = []for j in range(0, len(advantages), local_batch_size):if i != j:other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(other_response_rlhf_rewards).mean(0)
assert (1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6
assert (6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6# vectorized impl
rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
vec_advantages = rlhf_reward - baseline
torch.testing.assert_close(vec_advantages.flatten(), advantages)

batch_size = 3, 三个不一样的prompt,每个prompt生成4个response

刚开始baseline为所有reward的mean值,可以具体Debug看step1的各个相关值:

rlhf_reward.sum(0)为,每个prompt维度reward之和

tensor([16., 20., 24.])

baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1) 其中的rlhf_reward.sum(0) - rlhf_reward为相当于每个response单独减去本身的reward(公式为i≠ki \neq ki=k), 最终除以(rloo_k - 1),每个prompt生成四个response,则rloo_k应该设置为4。

这样是不是对RLOO的优化更理解了。

Ref

  • https://rlhfbook.com/c/11-policy-gradients.html#reinforce
  • https://huggingface.co/blog/zh/putting_rl_back_in_rlhf_with_rloo
http://www.dtcms.com/a/301564.html

相关文章:

  • RK3568基于mpp实现硬解码(一):mpp库的编译使用
  • [每周一更]-(第151期):Go语言中的Map、Slice、Array和Hash原理详解
  • 博士招生 | 香港大学 招收人工智能和网络安全方向 博士生
  • 7.27 状态机dp|质数线性筛|序列化树
  • Linux网络-------2.应⽤层⾃定义协议与序列化
  • SpringBoot实现Serverless:手撸一个本地函数计算引擎
  • mcu trace工具调研
  • elasticsearch 倒排索引原理详解
  • SpringBoot3整合Redis
  • 零基础学习性能测试第五章:性能瓶颈分析与调优-网络资源瓶颈分析与优化建议
  • Python调用大模型api并部署到前端的主流技术栈以及具体框架对比
  • 【牛客网C语言刷题合集】(四)
  • Java类加载器与双亲委派模型
  • n8n “Run Once for All Items“和“Run Once for Each Item“区别
  • 深度学习中的计算图与自动微分原理:静态图与动态图的实现差异
  • sd Function 学习笔记
  • BeautifulSoup 使用详解与实战示例
  • WAIC 2025 热点解读:如何构建 AI 时代的“视频神经中枢”?
  • WordPress 网站中的“mu-plugins”隐藏后门
  • [每周一更]-(第152期):Go中的CAS(Compare-And-Swap)锁原理详解
  • Java面试宝典:MySQL性能优化
  • ES6模块详解:核心语法与最佳实践
  • 编码器和解码器风格的Transformer架构
  • 使用vue2和 element-ui 做一个点餐收银台系统前端静态项目
  • 数据江湖的“三国演义”:数据仓库、数据湖与湖仓一体的全景对比
  • Gradio全解8——ChatInterfaceChatbot:聊天界面类与聊天机器人(4)——返回复杂响应与直接修改Chatbot值
  • Java Ai(day03)
  • 【秋招笔试】7月26日科大讯飞秋招第一题
  • 【最新最完整】SpringAI-1.0.0开发MCP Server,搭建MCP Client 实战笔记(进阶+详细+完整代码)
  • AI Agent学习