当前位置: 首页 > news >正文

强化学习系列--从数值出发,解读 DPO 训练背后的偏好优化逻辑

DPO 概要

  1. DPO(Direct Preference Optimization,直接偏好优化)是由斯坦福大学等研究团队于2023年提出的一种偏好优化算法,可用于LLM、VLM与MLLM的对齐训练。

  2. 算法基于PPO的RLHF基础上进行了大幅简化。DPO算法跳过了训练奖励模型这一中间过程,直接(Direct)优化策略模型 ——这正是DPO命名中“D(Direct)”的含义所在。

主要流程

  1. 数据收集: 基于SFT训练的模型作为推理模型,用户输入prompt,模型多次推理,找到好的答案和不好的答案。如果都是不好(rejected)的答案,则人工修改把不好的答案变为好的答案。

    标数据收集
  2. 主要包含两个基础模型,策略模型&参考模型(不需要Reward模型)。 在trl强化学习框架中,只需要传入策略模型,参考模型会复制一份策略模型。

    1. 策略模型是DPO需要训练的模型,后用在项目中的模型。策略模型的权重直接复制SFT阶段微调模型的权重

    2. 参考模型是策略模型的帮衬,其权重参数冻结不变。主要两个作用,其一协助其计算reward loss,其二计算kl正则项,防止其训练偏移初始SFT模型太远,由一个β参数控制。

  3. β参数控制含义

    1. 较大 beta(如 1.0):放大 reward 或 logp 的差异,使模型更“自信”地倾向于较优样本,但容易过拟合或 reward 震荡。

    2. 较小 beta(如 0.1):差异被压缩,模型训练更稳定,但收敛较慢、辨别力较弱。

    3. 极小 beta(趋近于 0):差异几乎无效,模型无法区分好坏样本,退化为随机训练

  4.  整体流程如下:

  5. 具体流程

    DPO训练流程细节

九个损失函数解析

"loss": 1.8678"rewards/chosen": 42.519317626953125"rewards/rejected": -33.865535736083984"rewards/accuracies": 0.865429699420929"rewards/margins": 76.38734436035156"logps/chosen": -948.4149780273438"logps/rejected": -1285.1175537109375"logits/chosen": 5.363300800323486"logits/rejected": 4.879658222198486
  1. logps/chosen和logps/rejected: logps 是模型生成 token 概率,在归一化后(softmax)取 log 后的值(log prob)。

    #1 把 prompt 和 response 拼接起来作为输入
    input = prompt + response
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch# 加载 tokenizer 和模型
    tokenizer = AutoTokenizer.from_pretrained("your-model-name")
    model = AutoModelForCausalLM.from_pretrained("your-model-name").cuda()# 设置 prompt 和 response
    prompt = "你今天心情怎么样?"
    response = "我今天很开心,太阳出来了,我们一起去玩吧!"# 拼接输入
    full_input = prompt + response
    encodings = tokenizer(full_input, return_tensors="pt").to("cuda")
    input_ids = encodings["input_ids"]# 找到 response 的起始位置
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
    response_start = prompt_ids.shape[-1]# 前向推理,获取 logits
    with torch.no_grad():outputs = model(**encodings)logits = outputs.logits# 计算 log probabilities
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)# 获取 response 部分 token 的 log probability
    response_token_ids = input_ids[:, response_start:]
    response_logits = log_probs[:, response_start - 1:-1, :]  # 对应 shift
    response_logp = torch.gather(response_logits, 2, response_token_ids.unsqueeze(-1)).squeeze(-1)# 平均 log probability(整个 response)
    logp_response = response_logp.mean()logps_chosen = compute_logp(prompt, chosen, actor_model)
    logps_rejected = compute_logp(prompt, rejected, actor_model)
    logps_ref_chosen = compute_logp(prompt, chosen, ref_model)
    logps_ref_rejected = compute_logp(prompt, rejected, ref_model)
  2. logits/chosen和logits/rejected: 模型输出的raw score(未进行归一化)求平均

    # 模型输出:logits = [batch_size, seq_len, vocab_size]
    # 获取 chosen 的最后一个 token 的 logit:
    logit_chosen = logits[:, -1, :]  # 通常是这个位置
    logits/chosen = logit_chosen.mean().item()
    # 拿出 chosen response 部分的 token 对应的 logit 向量
    logits_response = logits[:, prompt_len:, :]  # mask 掉 prompt 部分
    logits/chosen = logits_response.mean().item()
  3. reward 计算方法

    chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
    rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
    reward_accuracies = (chosen_rewards > rejected_rewards).float()
    metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
    metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
    metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
    metrics[f"{prefix}rewards/margins"] = (
    self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
  4. Loss 计算方法

    本次默认使用sigmoidlogratios = chosen_logps - rejected_logpsref_logratios = ref_chosen_logps - ref_rejected_logps                logratios = logratios.to(self.accelerator.device)ref_logratios = ref_logratios.to(self.accelerator.device)logits = logratios - ref_logratios losses = (-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)- F.logsigmoid(-self.beta * logits) * self.label_smoothing )
    其他计算方法如下(后续介绍):"hinge","ipo",
    "exo_pair","nca_pair","robust","bco_pair",
    "sppo_hard","aot","apo_down""aot_pair","apo_zero","discopop",
  5. 关系理解

    指标

    含义

    关系

    logits

    每个 token 的原始输出分数(未归一化)

    模型输出的raw score(未进行归一化)求平均

    logps

    所有 token 的 log 概率之和(对 logit softmax 后求 log,token-wise 累加)

    来自 logits → softmax → log(prob) → sum over tokens

    rewards

    在 logp-based reward 情况下,reward 就是 sum(logps)/len(tokens)

    eval_rewards/chosen == eval_logps/chosen/len(tokens)

  6. 主要关注指标

    指标名

    含义

    影响

    loss

    当前 batch 的 DPO/IPO 损失值

    反映训练是否有效收敛,是否有发散/震荡

    rewards/margins

    reward_chosen - reward_rejected 的平均值

    反映模型区分正负样本的能力是否提升

    rewards/accuracies

    reward_chosen > reward_rejected 的比例

    反映偏好判断正确率是否提高

    logs/chosen& logs/rejected

    每个 sample 的对数似然总和

    趋势变化判断 token-level 拟合趋势

  7. 总结

    1. log 概率指标(logps)
    指标含义用途正常范围是否用于 loss
    "logps/chosen"chosen 回复中每个 token 的 log(prob),总和衡量模型对 chosen 回复的生成概率负值,长度越长越负✅ 核心用于 DPO loss:log p(chosen) - log p(rejected)
    "logps/rejected"同上,对 rejected 回复衡量模型对差回复的置信度负值✅ 同上

     "logps/chosen": -948.41
    "logps/rejected": -1285.11

    说明:模型对 chosen 回复的概率更高(log 概率更接近 0,越好),这说明模型学会了偏好优选答案。

            2. 核心奖励相关指标 (rewards)

    指标含义用途正常范围是否用于 loss
    "rewards/chosen"reward 模型给 chosen 回复打的分数衡量优选回答质量通常 > 0,越大越好✅ 是 loss 的一部分(间接参与)
    "rewards/rejected"reward 模型给 rejected 回复打的分数衡量差回复的质量通常 < 0,越小越好✅ 是 loss 的一部分
    "rewards/accuracies"chosen > rejected 的比例(准确率)衡量 reward 模型偏好正确性0~1,越高越好❌ 不用于 loss,训练监控用
    "rewards/margins"rewards/chosen - rewards/rejected 的 margin衡量 chosen 比 rejected 更优多少> 0,越大越好❌ 不直接用于 loss,但用于评估 reward 区分度

    "rewards/chosen": 42.52
    "rewards/rejected": -33.87
    "rewards/margins": 76.39
    "rewards/accuracies": 0.865
    说明:reward 模型倾向于更高地打分给 chosen(准确率 86.5%,平均 margin 达 76),这是训练良好的表现

    3. logits 指标

    指标含义用途是否用于 loss
    "logits/chosen"模型在输出 chosen 回复时的 token-level logit 平均值(未 softmax 前)可作为置信度估计参考❌ 仅用于辅助评估,不参与 loss
    "logits/rejected"同上,对 rejected同上

    "logits/chosen": 5.36
    "logits/rejected": 4.88
    说明:模型在生成 chosen 时激活程度略高,但不是主要指标,通常 logps 更重要

    4.   损失指标 (loss)    

    指标含义范围趋势
    "loss"当前 batch 的 DPO / IPO / KTO loss正值,越小越好随训练逐渐下降

             "loss": 1.8678

    表示当前 batch 的训练损失,值适中。如果长时间不下降或上升,需要检查 learning rate 或 reward model 的稳定性

    其他思考

    1.  logps/chosen是负的合理吗

    logps(y_{chosen}|x})logps(y_{chosen}|x}) 是模型对生成chosen回复时,每个token的概率取对数后加总, 由于每一个token的概率 ,所以。p(yt,y<t)∈(0,1),所以logp(yt)<0。 所以累加一段文本后,整个logp通常是一个比较大的负值。

    2. reward为负值

    因为是 rchosen=logπθ(ychosen|x) ,如果没有额外reward打分模型,则 r=sum(logps)/len(logps)

    3. 基于loss 计算ppl

    •     在语言模型中,**PPL(Perplexity,困惑度)**是衡量模型预测能力的重要指标,通常用于评估语言建模任务中模型输出的流畅性和合理性。
    • 如果你已经得到了训练或评估中的 loss 值(特别是 cross-entropy losslog-likelihood loss),可以直接通过下面的公式计算 PPL
    • 注意事项: 
      • 在 DPO / IPO 等偏好优化任务中,loss 并非标准语言建模 loss,此时 PPL 没有经典意义
      • 如果你计算了某模型在验证集上的 语言建模 loss(例如对 gold 回复计算),此时用来算 PPL 是合理的;
    • 举例

    "loss": 1.8678

    import math
    loss = 1.8678
    ppl = math.exp(loss)
    print(ppl)  # 输出约为 6.47

    说明:当前模型平均每个 token 的预测有大约 6.47 倍的不确定性。

    http://www.dtcms.com/a/264199.html

    相关文章:

  8. Navicat Premium x TiDB 社区体验活动 | 赢 Navicat 正版授权+限量周边+TiDB 社区积分
  9. 第8章路由协议,RIP、OSPF、BGP、IS-IS
  10. RabbitMQ简单消息监听
  11. 基于开源AI大模型AI智能名片S2B2C商城小程序的流量转化与价值沉淀研究
  12. linux魔术字定位踩内存总结
  13. 振荡电路Multisim电路仿真实验汇总——硬件工程师笔记
  14. MySQL 常用命令大全
  15. 0.96寸OLED显示屏 江协科技学习笔记(36个知识点)
  16. swing音频输入
  17. sqlmap学习ing(2.[第一章 web入门]SQL注入-2(报错,时间,布尔))
  18. jQuery 安装使用教程
  19. MySQL数据一键同步至ClickHouse数据库
  20. 前端第二节(Vue)
  21. 橙心同步助手2.0.1版本更新
  22. Instruct-GPT中强化学习(RL)训练部分详解
  23. Android实现仿iOS风格滚动时间选择器
  24. 零信任安全管理系统介绍
  25. 新版本 Spring Data Jpa + QueryDSL 使用教程
  26. Java基础 集合框架 抽象类 AbstractList
  27. Bootstrap 安装使用教程
  28. 三极管是NPN还是PNP
  29. CppCon 2018 学习:EMULATING THE NINTENDO 3DS
  30. 以下是 Kafka 不同认证方式的配置示例,结合前面的单表设计方案,展示如何为每种认证方式填充配置表
  31. Docker进阶命令与参数——AI教你学Docker
  32. 第八十六篇 大数据排序算法:从厨房整理到分布式排序的智慧
  33. MS1826+LT8644 4K@30Hz HD8×8/16×16高清矩阵
  34. 数据结构复习5
  35. 数字ic后端设计从入门到精通10(含fusion compiler, tcl教学)静态时序分析
  36. 使用Ansible的playbook安装HTTP
  37. 8.4 Jmter实践不同线程组之间的全局变量的传递和使用