当前位置: 首页 > news >正文

huggingface TRL中的对齐算法: KTO

KTO

1 背景

• 传统 RLHF / DPO / IPO 都需要「成对偏好」数据 (x, y_w, y_l)。
• 真实业务日志往往只有单条样本加二值标签(good / bad),收集成对数据成本高且周期长。
• KTO 受前景理论(Prospect Theory)启发,用「收益-损失」框架刻画人类对单条结果的主观效用,从而支持仅二元标签即可离线对齐大模型 。

──────────────────

2 算法流程(离线单轨迹)

输入
• 数据集 D = {(x, y, b)} ,其中 b = 1 表示“好”,b = 0 表示“坏”。
• 参考策略 π_ref(通常由 SFT 模型初始化)。

步骤

  1. 对每个样本计算
    r_θ(x, y) = log π_θ(y|x) – log π_ref(y|x) (对数似然比)

  2. 估计参考点(baseline)
    z₀ = KL(π_θ||π_ref) 的滑动窗口近似,实践中常用 mini-batch 平均 r_θ 代替 。

  3. 根据 y 的类别计算前景价值 v(x, y):

    v(x, y) =
    ┌ λ_D · σ(β (r_θ – z₀)) 若 b = 1 (好样本)
    └ λ_U · σ(β (z₀ – r_θ)) 若 b = 0 (坏样本)

    σ 为 Logistic 函数,β > 0 控制敏感度,λ_D、λ_U 为正负样本权重(通常 λ_U > λ_D 以反映损失厌恶)。

  4. 最小化人类对齐损失目标(HALO):

    L_KTO = – 𝔼_D [ v(x, y) ]

  5. 反向传播更新 π_θ 参数即可,无需显式奖励模型,也无需在线采样。

──────────────────

3 关键公式汇总

符号说明来源
r_θ(x,y) = log π_θ(yx) – log π_ref(yx)
z₀ = 𝔼[ r_θ(x, y′) ] ≈ KL(π_θπ_ref)
v(x,y)=λ_D σ(β(r_θ−z₀)) (好)前景价值(收益段)(6)
v(x,y)=λ_U σ(β(z₀−r_θ)) (坏)前景价值(损失段)(6)
L_KTO = − 𝔼_D[ v(x,y) ]最终损失(2)

──────────────────

4 与 DPO、PPO 的对比

维度KTODPOPPO
数据需求单条+二元标签成对偏好在线轨迹
奖励模型无(隐式)无(隐式)需显式 RM
训练方式离线离线在线 RL
核心思想前景理论效用Bradley-Terry 偏好排序最大化奖励+KL 约束

──────────────────

5 工程实现要点

  1. β 默认 0.1;λ_U 设成 λ_D 的 2–3 倍以匹配人类损失厌恶系数 。
  2. z₀ 每 N 步更新一次即可,避免训练不稳定。
  3. 若正负样本不平衡,可把 λ_D、λ_U 改成基于类别频率的逆比例权重。

至此,KTO 算法的完整背景、推导和公式已全部列出,可直接用于代码实现或论文写作。

6. TRL里怎么实现的

下面逐行拆解这个 get_batch_loss_metrics 函数,说明它在做什么、为什么要这么做,以及可能的注意点。


1. 函数签名

def get_batch_loss_metrics(self,model,batch: dict[str, Union[list, torch.LongTensor]],):
  • 目的:给定一个 batch 的输入,计算 KTO(Kahneman-Tversky Optimization)损失 及其附属指标。
  • 输入
    • model:当前策略网络(policy model)。
    • batch:字典,字段通常包括:
      • input_ids, attention_mask, label, 可能还有 reference_logps, reference_KL_logps 等。

2. 设备搬运

batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v)for k, v in batch.items()}
  • 把所有张量搬到 accelerator 所在的设备(GPU/CPU)。

3. 统计 chosen / rejected 样本数

labels = torch.tensor(batch["label"])
num_chosen = labels.sum()
num_rejected = len(labels) - num_chosen
  • label==1 → chosen(人类偏好较高的回答)。
  • label==0 → rejected。

4. 损失计算分支:use_liger_loss

if self.args.use_liger_loss:...
else:...
  • liger 分支:
    使用 Liger-Kernel 的高度融合 CUDA kernel,一次前向即可拿到所有所需张量(logpslogitsrewardsKL)。
  • else 分支:
    传统做法:先跑 forward,再跑 reference 模型,再算损失。

5. Reference Logps 的两种来源

if "reference_logps" in batch:...
else:with torch.no_grad():...
  • 离线缓存(推荐):训练前就预先把参考模型的 log p(y|x) 算好,避免每个 step 都跑一次参考模型,速度大幅提升。
  • 实时推理:没有离线缓存时,用 ref_model(或 model 开启 null_ref_context)再跑一次 forward。

6. 计算 KTO Loss

losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(...)
  • 核心函数,实现 KTO 论文里的 prospect-theoretic utility 公式:
    • chosen_rewards = policy_chosen_logps - reference_chosen_logps
    • rejected_rewards = policy_rejected_logps - reference_rejected_logps
    • KL 项控制与参考策略的偏离度。

7. 分布式指标收集

all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
...
  • 在多卡/多节点环境下,把所有卡的 chosen、rejected 计数、reward、logp 累加,保证指标 全局一致

8. 最终损失 + Aux Loss

loss = losses.nanmean()
if self.aux_loss_enabled:loss += self.aux_loss_coef * aux_loss
  • nanmean:忽略 NaN 的样本(可能出现空 chosen 或空 rejected)。
  • 如有 辅助损失(比如 MoE load balancing loss),加到总损失。

9. 返回值

return loss, metrics
  • loss:用于反向传播的张量。
  • metrics:字典,记录:
    • kl:平均 KL。
    • rewards/chosen_sum, logps/chosen_sum, count/chosen(chosen 样本)。
    • rewards/rejected_sum, logps/rejected_sum, count/rejected(rejected 样本)。

✅ 一句话总结

这个函数在 Deepspeed / Accelerate 分布式环境下,统一完成了:

  1. 数据搬运 →
  2. 策略模型/参考模型前向 →
  3. 计算 KTO loss + 指标 →
  4. 跨进程聚合 →
  5. 返回可反向传播的 loss 和日志。

逻辑清晰,兼顾了 性能(liger kernel / 离线缓存)分布式正确性

这个 kto_loss 函数是 KTO(Kahneman-Tversky Optimization)损失的核心实现,也顺带兼容了 APO(Anchored Preference Optimization) 的 unpaired 变体。下面逐段拆解:


✅ 1. KL 估计(可选)

if self.calculate_KL:kl = (policy_KL_logps - reference_KL_logps).mean().detach()kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
  • 整批样本的平均 log-ratio 作为策略与参考模型之间的 KL 散度估计。
  • .clamp(min=0) 防止负 KL(数值误差)。
  • 如果 calculate_KL=False,则 kl=0

✅ 2. Chosen 样本的损失

chosen_logratios = policy_chosen_logps - reference_chosen_logps
  • 对数优势比(log odds):>0 说明策略模型给 chosen 样本更高的概率。
🔹 KTO 模式(论文公式 7)
chosen_losses = 1 - sigmoid(β * (chosen_logratios - kl))
  • Prospect Theory 中的价值函数 换成 sigmoid 形式
  • kl 作为 锚点,防止策略过度偏离参考模型。
🔹 APO unpaired 模式
chosen_losses = 1 - sigmoid(β * chosen_logratios)
  • 去掉 kl 锚点,直接让 chosen 样本尽量优于参考模型

✅ 3. Rejected 样本的损失

rejected_logratios = policy_rejected_logps - reference_rejected_logps
🔹 KTO 模式
rejected_losses = 1 - sigmoid(β * (kl - rejected_logratios))
  • 同样使用 sigmoid 形式,但把 kl 放在前面,鼓励策略模型给 rejected 更低概率
🔹 APO unpaired 模式
rejected_losses = sigmoid(β * rejected_logratios)
  • 直接让 rejected 样本的概率 低于 参考模型。

✅ 4. 奖励计算

chosen_rewards = β * chosen_logratios.detach()
rejected_rewards = β * rejected_logratios.detach()
  • 把 log-ratio 乘上温度 β,作为 奖励信号 供日志或 RLHF 使用。

✅ 5. 空样本保护

if policy_chosen_logps.shape[0] == 0:chosen_losses = torch.Tensor([]).to(device)
  • 防止 空张量,避免 accelerator.gather 在分布式环境下挂起。

✅ 6. 最终损失拼接

losses = torch.cat((self.desirable_weight * chosen_losses,self.undesirable_weight * rejected_losses),0,
)
  • chosenrejected 样本分别加权(desirable_weight, undesirable_weight)。
  • 返回值:
    • losses:每个样本的 KTO/APO 损失(后续 .nanmean())。
    • chosen_rewards, rejected_rewards:用于日志。
    • kl:策略偏离参考模型的程度。

✅ 一句话总结

这个函数把 Prospect Theory 的思想(损失厌恶、锚定效应)编码进 sigmoid 形式,通过 log-ratio 计算 chosen/rejected 样本的 损失与奖励,并兼容 KTO 原始公式APO 的无对比版本

论文参考:KTO: Model Alignment as Prospect Theoretic Optimization

http://www.dtcms.com/a/333948.html

相关文章:

  • PMP-项目管理-十大知识领域:成本管理-估算预算、控制成本、避免超支
  • 免费下载 Landsat 系列遥感影像——地理空间数据云
  • 《吃透 C++ 类和对象(中):const 成员函数与取地址运算符重载解析》
  • ALBEF/BLIP/BLIP2/Instruct BLIP/X Instruct BLIP
  • 从废弃到珍宝——旧物二手回收小程序系统的价值发现之旅
  • 曲面/线 拟合gnuplot
  • 新手向:Python列表、元组、集合和字典的用法对比
  • 谷歌手机刷机和面具ROOT保姆级别教程
  • 基于 LoRA的广义知识蒸馏(GKD)训练
  • 软考 系统架构设计师系列知识点之杂项集萃(125)
  • 给纯小白的 Python 操作 Excel 笔记
  • STM32 延时函数详解
  • HackMyVM-Uvalde
  • 第七十五章:AI的“思维操控师”:Prompt变动对潜在空间(Latent Space)的影响可视化——看懂AI的“微言大义”!
  • 整体设计 符号学与诠释学融合的整体设计框架(本篇暂时命名)--PromptPilot (助手)答问之1
  • 第四章:大模型(LLM)】06.langchain原理-(5)LangChain Prompt 用法
  • PowerPoint和WPS演示放映PPT时如何禁止鼠标翻页
  • [1Prompt1Story] 注意力机制增强 IPCA | 去噪神经网络 UNet | U型架构分步去噪
  • 国产之光时空克隆:功能对标谷歌地球,旅游规划还能加载倾斜摄影模型,三维视频融合 免费使用
  • GaussDB 数据库架构师修炼(十三)安全管理(3)-行级访问控制
  • 【C++】C++11
  • implement copy file content to clipboard on Windows
  • spring-ai-alibaba 学习(二十六)——graph总结
  • 超越“调参”:从系统架构师视角,重构 AI 智能体的设计范式
  • 玩转云原生,使用k9s管理k8s集群和k3s集群
  • 基本电子元件:金属氧化膜电阻器
  • PostgreSQL 时间函数及格式类型
  • 【机器学习深度学习】OpenCompass:支持的开源评估数据集及使用差异
  • [CSP-J2020] 方格取数
  • [1Prompt1Story] 生成行为控制器 | 语义向量重加权(SVR)