huggingface TRL中的对齐算法: KTO
KTO
1 背景
• 传统 RLHF / DPO / IPO 都需要「成对偏好」数据 (x, y_w, y_l)。
• 真实业务日志往往只有单条样本加二值标签(good / bad),收集成对数据成本高且周期长。
• KTO 受前景理论(Prospect Theory)启发,用「收益-损失」框架刻画人类对单条结果的主观效用,从而支持仅二元标签即可离线对齐大模型 。
──────────────────
2 算法流程(离线单轨迹)
输入
• 数据集 D = {(x, y, b)} ,其中 b = 1 表示“好”,b = 0 表示“坏”。
• 参考策略 π_ref(通常由 SFT 模型初始化)。
步骤
-
对每个样本计算
r_θ(x, y) = log π_θ(y|x) – log π_ref(y|x) (对数似然比) -
估计参考点(baseline)
z₀ = KL(π_θ||π_ref) 的滑动窗口近似,实践中常用 mini-batch 平均 r_θ 代替 。 -
根据 y 的类别计算前景价值 v(x, y):
v(x, y) =
┌ λ_D · σ(β (r_θ – z₀)) 若 b = 1 (好样本)
└ λ_U · σ(β (z₀ – r_θ)) 若 b = 0 (坏样本)σ 为 Logistic 函数,β > 0 控制敏感度,λ_D、λ_U 为正负样本权重(通常 λ_U > λ_D 以反映损失厌恶)。
-
最小化人类对齐损失目标(HALO):
L_KTO = – 𝔼_D [ v(x, y) ]
-
反向传播更新 π_θ 参数即可,无需显式奖励模型,也无需在线采样。
──────────────────
3 关键公式汇总
符号 | 说明 | 来源 |
---|---|---|
r_θ(x,y) = log π_θ(y | x) – log π_ref(y | x) |
z₀ = 𝔼[ r_θ(x, y′) ] ≈ KL(π_θ | π_ref) | |
v(x,y)=λ_D σ(β(r_θ−z₀)) (好) | 前景价值(收益段) | (6) |
v(x,y)=λ_U σ(β(z₀−r_θ)) (坏) | 前景价值(损失段) | (6) |
L_KTO = − 𝔼_D[ v(x,y) ] | 最终损失 | (2) |
──────────────────
4 与 DPO、PPO 的对比
维度 | KTO | DPO | PPO |
---|---|---|---|
数据需求 | 单条+二元标签 | 成对偏好 | 在线轨迹 |
奖励模型 | 无(隐式) | 无(隐式) | 需显式 RM |
训练方式 | 离线 | 离线 | 在线 RL |
核心思想 | 前景理论效用 | Bradley-Terry 偏好排序 | 最大化奖励+KL 约束 |
──────────────────
5 工程实现要点
- β 默认 0.1;λ_U 设成 λ_D 的 2–3 倍以匹配人类损失厌恶系数 。
- z₀ 每 N 步更新一次即可,避免训练不稳定。
- 若正负样本不平衡,可把 λ_D、λ_U 改成基于类别频率的逆比例权重。
至此,KTO 算法的完整背景、推导和公式已全部列出,可直接用于代码实现或论文写作。
6. TRL里怎么实现的
下面逐行拆解这个 get_batch_loss_metrics
函数,说明它在做什么、为什么要这么做,以及可能的注意点。
1. 函数签名
def get_batch_loss_metrics(self,model,batch: dict[str, Union[list, torch.LongTensor]],):
- 目的:给定一个 batch 的输入,计算 KTO(Kahneman-Tversky Optimization)损失 及其附属指标。
- 输入:
model
:当前策略网络(policy model)。batch
:字典,字段通常包括:input_ids
,attention_mask
,label
, 可能还有reference_logps
,reference_KL_logps
等。
2. 设备搬运
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v)for k, v in batch.items()}
- 把所有张量搬到 accelerator 所在的设备(GPU/CPU)。
3. 统计 chosen / rejected 样本数
labels = torch.tensor(batch["label"])
num_chosen = labels.sum()
num_rejected = len(labels) - num_chosen
label==1
→ chosen(人类偏好较高的回答)。label==0
→ rejected。
4. 损失计算分支:use_liger_loss
if self.args.use_liger_loss:...
else:...
liger
分支:
使用 Liger-Kernel 的高度融合 CUDA kernel,一次前向即可拿到所有所需张量(logps
、logits
、rewards
、KL
)。else
分支:
传统做法:先跑 forward,再跑 reference 模型,再算损失。
5. Reference Logps 的两种来源
if "reference_logps" in batch:...
else:with torch.no_grad():...
- 离线缓存(推荐):训练前就预先把参考模型的
log p(y|x)
算好,避免每个 step 都跑一次参考模型,速度大幅提升。 - 实时推理:没有离线缓存时,用
ref_model
(或model
开启null_ref_context
)再跑一次 forward。
6. 计算 KTO Loss
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(...)
- 核心函数,实现 KTO 论文里的 prospect-theoretic utility 公式:
chosen_rewards = policy_chosen_logps - reference_chosen_logps
rejected_rewards = policy_rejected_logps - reference_rejected_logps
- KL 项控制与参考策略的偏离度。
7. 分布式指标收集
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
...
- 在多卡/多节点环境下,把所有卡的 chosen、rejected 计数、reward、logp 累加,保证指标 全局一致。
8. 最终损失 + Aux Loss
loss = losses.nanmean()
if self.aux_loss_enabled:loss += self.aux_loss_coef * aux_loss
- nanmean:忽略 NaN 的样本(可能出现空 chosen 或空 rejected)。
- 如有 辅助损失(比如 MoE load balancing loss),加到总损失。
9. 返回值
return loss, metrics
loss
:用于反向传播的张量。metrics
:字典,记录:kl
:平均 KL。rewards/chosen_sum
,logps/chosen_sum
,count/chosen
(chosen 样本)。rewards/rejected_sum
,logps/rejected_sum
,count/rejected
(rejected 样本)。
✅ 一句话总结
这个函数在 Deepspeed / Accelerate 分布式环境下,统一完成了:
- 数据搬运 →
- 策略模型/参考模型前向 →
- 计算 KTO loss + 指标 →
- 跨进程聚合 →
- 返回可反向传播的 loss 和日志。
逻辑清晰,兼顾了 性能(liger kernel / 离线缓存) 与 分布式正确性。
这个 kto_loss
函数是 KTO(Kahneman-Tversky Optimization)损失的核心实现,也顺带兼容了 APO(Anchored Preference Optimization) 的 unpaired 变体。下面逐段拆解:
✅ 1. KL 估计(可选)
if self.calculate_KL:kl = (policy_KL_logps - reference_KL_logps).mean().detach()kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
- 用 整批样本的平均 log-ratio 作为策略与参考模型之间的 KL 散度估计。
.clamp(min=0)
防止负 KL(数值误差)。- 如果
calculate_KL=False
,则kl=0
。
✅ 2. Chosen 样本的损失
chosen_logratios = policy_chosen_logps - reference_chosen_logps
- 对数优势比(log odds):>0 说明策略模型给 chosen 样本更高的概率。
🔹 KTO 模式(论文公式 7)
chosen_losses = 1 - sigmoid(β * (chosen_logratios - kl))
- 把 Prospect Theory 中的价值函数 换成 sigmoid 形式。
kl
作为 锚点,防止策略过度偏离参考模型。
🔹 APO unpaired 模式
chosen_losses = 1 - sigmoid(β * chosen_logratios)
- 去掉
kl
锚点,直接让 chosen 样本尽量优于参考模型。
✅ 3. Rejected 样本的损失
rejected_logratios = policy_rejected_logps - reference_rejected_logps
🔹 KTO 模式
rejected_losses = 1 - sigmoid(β * (kl - rejected_logratios))
- 同样使用 sigmoid 形式,但把
kl
放在前面,鼓励策略模型给 rejected 更低概率。
🔹 APO unpaired 模式
rejected_losses = sigmoid(β * rejected_logratios)
- 直接让 rejected 样本的概率 低于 参考模型。
✅ 4. 奖励计算
chosen_rewards = β * chosen_logratios.detach()
rejected_rewards = β * rejected_logratios.detach()
- 把 log-ratio 乘上温度
β
,作为 奖励信号 供日志或 RLHF 使用。
✅ 5. 空样本保护
if policy_chosen_logps.shape[0] == 0:chosen_losses = torch.Tensor([]).to(device)
- 防止 空张量,避免
accelerator.gather
在分布式环境下挂起。
✅ 6. 最终损失拼接
losses = torch.cat((self.desirable_weight * chosen_losses,self.undesirable_weight * rejected_losses),0,
)
- 给 chosen 和 rejected 样本分别加权(
desirable_weight
,undesirable_weight
)。 - 返回值:
losses
:每个样本的 KTO/APO 损失(后续.nanmean()
)。chosen_rewards
,rejected_rewards
:用于日志。kl
:策略偏离参考模型的程度。
✅ 一句话总结
这个函数把 Prospect Theory 的思想(损失厌恶、锚定效应)编码进 sigmoid 形式,通过 log-ratio 计算 chosen/rejected 样本的 损失与奖励,并兼容 KTO 原始公式 和 APO 的无对比版本。
论文参考:KTO: Model Alignment as Prospect Theoretic Optimization