当前位置: 首页 > news >正文

【模型训练篇】VeRL核心思想 - 论文HybridFlow

继续学习字节家的VeRL,今天介绍的是VeRL的核心思想,论文 【HybridFlow: A Flexible and Efficient RLHF Framework】,是VeRL的第二篇文章:

  • 底层分布式能力基础Ray(点击查看):VeRL分布式能力的基础,框架Ray
  • VeRL的原理:HybridFlow
  • VeRL的使用,普通RL(单轮RL)
  • VeRL的使用,Agentic RL(多轮RL)
  • VeRL的魔改

RL

先来回顾一下RLHF,当使用 PPO 的时候,需要 actorreferencerewardcritic 四个模型,其中:

  • actorreference,来自 sft 后的模型
  • rewardcritic 则是使用 偏好数据 进行训练好的LLM,只不过输出头替换成了 标量输出

在这里插入图片描述

整个 PPO 过程,可以分成三部分:

  1. Rollout,在 VeRL 中叫 Generation:是用 batch prompts 通过 actor 生成的 response 轨迹数据
  2. 将上述轨迹数据,用 referencerewardcritic 这仨 make experience,在 VeRL 中叫 Preparation
  3. 用上述信息迭代 actorcritic

以上流程很简单本身上是种 DAG workflow,但当下分布式训练中,难点在于 DAG 中的每一个节点,都可以是分布式的,并且还需要使用不同DP策略;

现在市面上,大部分RL框架在执行训练任务的时候:

  • 要么单独使用 single-croller:中心化的进程,管理整个workflow,控制数据、指令,worker可以 MPMD
  • 要么单独使用 multi-controller:去中心化的,每个节点单独处理,属于 自驱,worker是 SPMD

同时大部分RL框架并没有很好的处理一下case:

  1. 存储差异:对于 rewardreference 这俩模型来说,只需要存储它们的权重参数到 GPU 即可,因为它们只做 FWD,而不像 actorcritic 需要存储参数、激活值、梯度、优化器等信息,所以它们对存储的需求差异很大

  2. 阶段差异:即使对于actor这一个模型来说,在rollouttrain这两个阶段的表现也不同,在train阶段属于compute-bound计算密集型,此时可以上大MP,但如果大MP也应用在了rollout阶段就很亏,所以即使同一个模型也有不同阶段不同策略

  3. 部署差异:PPO中有四个模型,它们的部署策略都不同,如下图所示,3台机器6张卡,其中actor部署到01这两张卡上,critic部署到23这两张卡上,refrm部署到45这两卡上,按顺序执行,如果没有数据依赖的可以并行执行,有数据依赖的就等着;

    在这里插入图片描述
    同时论文中也对比了一下,当下不同RL框架的部署策略:

    在这里插入图片描述

HybridFlow

在了解了以上RL的各种问题,以及当下RL框架的劣势后,HybridFlow给出了如下解法:

整体架构上:

要注意的是:

  1. 3D-HybridEngine 专门负责 actorrollouttrain 两个stage,使得actor可以实现不同的3D并行策略,并且高效的在这两个stage间切换
  2. auto mapping 算法是套自研算法可以高效的优化模型部署

整体流程:

  1. 首先需要准备的东西有:
  • PPO的4个模型
  • 用auto-mapping算法结合GPU集群资源,跑出来的模型部署策略
  • 这4个模型在各个阶段中的并行策略
  1. single-controller:拿着上边的信息,去初始化模型、资源池、根据部署策略部署模型、和各个节点的multi-controller交互执行调用
  2. multi-contoller:其实就是各个节点worker,叫做ParallelWorker,它根据每个模型的并行策略,组建parallel group,调用3D-HybridEngine去执行actor的两个stage,也可以整合现有的其他LLM框架进行训练和推理

整体使用上:在这里插入图片描述

  1. 3DParallelWroker是负责实现并行策略的,它根据不同模型的分布式权重初始化策略,组建3DParallelGroup
  2. 3DParallelGroup是一组GPU,负责特定的并行策略:group是一组spmd进程的抽象,像单个程序一个被调用
  3. ActorWroker实现了3DParallelWorker,其他类似的还有实现了PyTorch FSDP的FSDPWorker、Zero的ZeroWroker
  4. @register的作用是:在不同模型跑在不同节点上的不同并行策略间,进行数据传输,所有的数据传输protocol都包含:
  • 1个collect方法用于收集数据:注意dispatch的是input数据,把input数据dispatch到各个worker上
  • 1个distribute方法用于distribute数据,各个worker处理完在通过aggregation进行整合,所以整合的是output数据

上图中Actorupadte_actor使用3D并行策略,所以它使用@register(3D_PROTO)

同时可以看到Actor的并行策略是(1,2,3),而critic的并行策略是(2,1,2)使用了不同的3D并行策略:

  • single-controller通过3D_PROTOactor收集数据的futures对象①②③,因为都是异步的(这里的异步futures估计都是通过Ray实现的)
  • 然后single-controller再把收集到的futures对象,通过critic3D_PROTO把数据发送到critic的每一个DP组④⑤
  • 真实的数据传输只发生在Actor和Critic对应的各GPU rank之间⑥
  1. 将GPU资源初始化成resource_pool后,结合并行策略config,可以实现模型的初始化
  2. 3D-HybridEngine专门负责actorrollouttrain两stage

首先为了不部署重复的actor,VeRL提倡将Actorrollouttrain部署到同一套设备上 colocate/共置,这样在同一台设备上可以直接使用上一轮更新后的权重,尽量省掉权重同步,不然就要存两份;

在这里插入图片描述

p-t-d表示3D策略,ppp 表示模型被拆分成了ppp 个stage,ttt 表示模型tensor被拆分成了 ttt 份,ddd 表示数据被拆分成了 ddd 份,每一份都拥有一个完整的模型,也就是完整的 p∗dp*dpd

  • iii 轮的rollout收集第 i−1i-1i1 轮train actor更新后的模型权重①,这是一个all-gather 操作
  • 然后load prompt后,DP进行rollout拿到轨迹数据,再发送给train阶段进行训练④⑤继续下一轮的迭代

既然actor的rollout和train两个stage可以使用不同的并行策略,但却在同一套set上只部署1份,那么如何处理模型在设备上的冗余?

冗余指的就是图中灰色overlap部分:在这里插入图片描述
就是在两个阶段进行转换的时候,G2/G3/G6/G7灰色那部分,这部分冗余的模型权重就需要占着显存;

所以VeRL优化了这类场景,消除了这部分的冗余,就是上图中 (b),使得两个阶段可以reuse模型权重,尽最大可能overlap,降低显存占用;
7. auto mapping 算法:map的是device和模型的并行部署策略,具体算法就不看了(毕竟他们喜欢出 变种困难leetcode题…)


以上分析了在VeRL出现前RL框架的困境,以及VeRL的解题思路,下篇文章就重点看下使用与代码细节。


文章转载自:

http://sie0F9xH.snrhg.cn
http://UVLEn91Z.snrhg.cn
http://Lwq0pgy0.snrhg.cn
http://UxtKUyEs.snrhg.cn
http://j09nt2Iy.snrhg.cn
http://m3dXL9DX.snrhg.cn
http://5aKkCKcc.snrhg.cn
http://G1C9ekVk.snrhg.cn
http://YCyl6X5b.snrhg.cn
http://QlfzH8hV.snrhg.cn
http://xRksgGms.snrhg.cn
http://AjgGCRPa.snrhg.cn
http://SldEc4zJ.snrhg.cn
http://VDZoJnEf.snrhg.cn
http://YwKPlKQF.snrhg.cn
http://9hAgGhko.snrhg.cn
http://neGn4OXK.snrhg.cn
http://KxEBOmT1.snrhg.cn
http://wWm3nWfv.snrhg.cn
http://zvH3rifB.snrhg.cn
http://57BJMPny.snrhg.cn
http://XBzW0eQN.snrhg.cn
http://y31pw8fk.snrhg.cn
http://BDBd3q8a.snrhg.cn
http://DZk9pimY.snrhg.cn
http://dEW1EwiH.snrhg.cn
http://aaUSgI8d.snrhg.cn
http://RgIb3AHR.snrhg.cn
http://LVFciNPV.snrhg.cn
http://3LqKvl8a.snrhg.cn
http://www.dtcms.com/a/372100.html

相关文章:

  • pycharm设置编辑区字体大小
  • 鸿蒙NEXT跨设备数据同步实战:分布式应用开发指南
  • C++ 中栈 (Stack) 详解和常见面试示例汇总实现
  • [光学原理与应用-461]:波动光学 - 波片实现偏振态的转换或调整
  • 苍穹外卖Day12 | Apache POI、导出Excel报表、HttpServletResponse、工作台
  • 《Go小技巧易错点100例》第三十八篇
  • Conda 包管理器与环境管理使用指南
  • 笔记本、平板如何成为电脑拓展屏?向日葵16成为副屏功能一键实现
  • OpenHarmony 显示能效管理组件:掌控屏幕亮灭与亮度的核心利器
  • SQLite的基本操作
  • 第五课 C#语言基本元素概览,初始类型,变量与方法,算法简介
  • 【系统分析师】第12章-关键技术:软件架构设计(核心总结)
  • Lightdash:一个免费开源的自助式BI平台
  • Claude Code 使用教程
  • UML(统一建模语言)
  • Android开发-常用布局
  • Spring Cloud Gateway 进行集群化部署
  • EmbodiedOneVision——类似π0.5集成了离散自回归解码与连续流匹配去噪:单个模型中完成具身推理、动作生成
  • Paper reading - 03. Speech sequencing in the human precentral gyrus
  • Spring事务失效的常见陷阱与解决方案
  • 现代C++:现代C++?
  • ZSet
  • Linux初级篇
  • MySQL集群高可用架构——组复制 (MGR)
  • MySQL Cluster核心优缺点
  • RestTemplate使用 | RestTemplate设置http连接池参数
  • 01OpenCV简介
  • 美股市场股票数据API对接文档
  • Coze源码分析-资源库-删除插件-前端源码-核心接口与工具
  • 【深度学习】重采样(Resampling)