当前位置: 首页 > news >正文

大模型对齐算法(二): TDPO(Token-level Direct Preference Optimization)

TDPO(Token-level Direct Preference Optimization)


1. 研究背景

痛点说明
DPO 句子级 KL只在完整回答上算 KL,无法细粒度控制逐 token 的偏离。
KL 增长失衡图1 显示,DPO 在 dis-preferred 回答上的 SeqKL 增长更快,导致分布差异越拉越大。
多样性下降反向 KL 的“mode-seeking”特性限制了生成多样性。

2. TDPO 核心思想

  1. 把 RLHF 任务拆成 token-level MDP
  • 状态:prompt + 已生成的 token
  • 动作:下一个 token
  • 奖励:rt=R([x,y<t],yt)r_t = R([x, y_{<t}], y_t)rt=R([x,y<t],yt)
  1. 用 Bellman 方程把句子奖励
    r(x,y)=∑t=1Tγt−1rtr(x, y)=\sum_{t=1}^{T}\gamma^{t-1} r_tr(x,y)=t=1Tγt1rt

  2. 在 token 上同时引入

  • 反向 KL(防止整体偏离)
  • 正向 SeqKL(抑制 dis-preferred 回答的 KL 暴涨)

3. 关键公式速览

名称LaTeX说明
token 级目标max⁡πθEx,y<t,z[Aπref−βDKL(πθ∥πref)]\max_{\pi_\theta} \mathbb{E}_{x,y_{<t},z} \Big[A^{\pi_{\text{ref}}} - \beta D_{\text{KL}}\bigl(\pi_\theta\|\pi_{\text{ref}}\bigr)\Big]maxπθEx,y<t,z[AπrefβDKL(πθπref)]TRPO 风格
最优策略π∗(z,y<t)∝πref(z,y<t)exp⁡(1βQπref(y<t,z))\pi^*(z, y_{<t}) \propto \pi_{\text{ref}}(z, y_{<t})\exp\Bigl(\tfrac{1}{\beta}Q^{\pi_{\text{ref}}}(y_{<t},z)\Bigr)π(z,y<t)πref(z,y<t)exp(β1Qπref(y<t,z))
BT-token 模型P(y1≻y2,x)=σ ⁣(u−δ)P(y_1\succ y_2, x)=\sigma\!\bigl(u-\delta\bigr)P(y1y2,x)=σ(uδ)
uuuδ\deltaδu=βlog⁡πθ(y1)πref(y1)−βlog⁡πθ(y2)πref(y2)u=\beta\log\frac{\pi_\theta(y_1)}{\pi_{\text{ref}}(y_1)}-\beta\log\frac{\pi_\theta(y_2)}{\pi_{\text{ref}}(y_2)}u=βlogπref(y1)πθ(y1)βlogπref(y2)πθ(y2)
δ=βDSeqKL(y2)−βDSeqKL(y1)\delta=\beta D_{\text{SeqKL}}(y_2)-\beta D_{\text{SeqKL}}(y_1)δ=βDSeqKL(y2)βDSeqKL(y1)
奖励差+KL差

4. 损失函数

论文给出 两个版本

版本公式特色
TDPO1−Elog⁡σ ⁣(u−δ)-\mathbb{E}\log\sigma\!\bigl(u-\delta\bigr)Elogσ(uδ)双向 KL 同时约束
TDPO2−Elog⁡σ ⁣(u−αδ2)-\mathbb{E}\log\sigma\!\bigl(u-\alpha\delta_2\bigr)Elogσ(uαδ2)
δ2\delta_2δ2stop-gradient 保护 preferred KL
防止 preferred 回答 KL 被拉高

5. 实验结果

数据集指标结论
IMDbReward vs SeqKL FrontierTDPO2 优于 DPO、f-DPO,更高奖励 + 更低 KL
Anthropic-HH对齐准确率 & 熵TDPO2 同时提升 准确率 67.3%熵 4.915
MT-BenchGPT-4 打分TDPO2 vs DPO:60.4% 胜 28.8% 平 10.8% 负

6. 代码片段(PyTorch)

def tdpo_loss(pi_logits, ref_logits, yw_idxs, yl_idxs,labels, beta=0.1, alpha=0.5, if_tdpo2=True):pi_logp = pi_logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)ref_logp = ref_logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)# per-token KLkl = (ref_logits.softmax(-1) * (ref_logits.log_softmax(-1) - pi_logits.log_softmax(-1))).sum(-1)yw_kl, yl_kl = kl[yw_idxs], kl[yl_idxs]u = beta * (pi_logp[yw_idxs] - ref_logp[yw_idxs]) \- beta * (pi_logp[yl_idxs] - ref_logp[yl_idxs])if if_tdpo2:delta = beta * yl_kl - beta * yw_kl.detach()else:delta = beta * yl_kl - beta * yw_klloss = -F.logsigmoid(u - alpha * delta)return loss

7. 总结

TDPO = 把 DPO 的“句子级 KL”拆成“token 级 KL”,再叠一个正向 SeqKL 差分,既对齐人类偏好,又压住 KL 暴涨,实验全面优于 DPO & PPO。

8. 附录

在这里插入图片描述
公式1是怎么推导出公式2的呢?
在这里插入图片描述

✅ 1 写出带 KL 的目标(公式 1)

max⁡πθ  Ex∼D  Ey∼πθ(⋅∣x)  [r(x,y)−βDKL(πθ(⋅∣x)∥πref(⋅∣x))]\max_{\pi_\theta}\; \mathbb{E}_{x\sim\mathcal{D}}\; \mathbb{E}_{y\sim\pi_\theta(\cdot|x)}\; \Bigl[r(x,y)-\beta D_{\mathrm{KL}}\bigl(\pi_\theta(\cdot|x)\|\pi_{\mathrm{ref}}(\cdot|x)\bigr)\Bigr]maxπθExDEyπθ(x)[r(x,y)βDKL(πθ(x)πref(x))]


✅ 2 把 KL 写成期望形式

DKL(πθ∥πref)=Ey∼πθ[log⁡πθ(y∣x)πref(y∣x)]D_{\mathrm{KL}}\bigl(\pi_\theta\|\pi_{\mathrm{ref}}\bigr) =\mathbb{E}_{y\sim\pi_\theta}\Bigl[\log\frac{\pi_\theta(y|x)}{\pi_{\mathrm{ref}}(y|x)}\Bigr]DKL(πθπref)=Eyπθ[logπref(yx)πθ(yx)]

代入后得到

max⁡πθ  Ex,y∼πθ[r(x,y)−βlog⁡πθ(y∣x)πref(y∣x)]\max_{\pi_\theta}\; \mathbb{E}_{x,y\sim\pi_\theta}\Bigl[ r(x,y)-\beta\log\frac{\pi_\theta(y|x)}{\pi_{\mathrm{ref}}(y|x)} \Bigr]maxπθEx,yπθ[r(x,y)βlogπref(yx)πθ(yx)]


✅ 3 用变分法求最优策略

令目标函数为

J(π)=Ex,y∼π[r(x,y)−βlog⁡π(y∣x)πref(y∣x)]J(\pi)=\mathbb{E}_{x,y\sim\pi}\Bigl[r(x,y)-\beta\log\frac{\pi(y|x)}{\pi_{\mathrm{ref}}(y|x)}\Bigr]J(π)=Ex,yπ[r(x,y)βlogπref(yx)π(yx)]

在约束

∑yπ(y∣x)=1\sum_y \pi(y|x)=1yπ(yx)=1

下对 (π\piπ) 做 拉格朗日乘子法,得到

π∗(y∣x)∝πref(y∣x)exp⁡ ⁣(1βr(x,y))\pi^*(y|x)\propto\pi_{\mathrm{ref}}(y|x)\exp\!\bigl(\tfrac{1}{\beta}r(x,y)\bigr)π(yx)πref(yx)exp(β1r(x,y))

归一化后

π∗(y∣x)=πref(y∣x)exp⁡ ⁣(1βr(x,y))∑y′πref(y′∣x)exp⁡ ⁣(1βr(x,y′))\pi^*(y|x)=\frac{\pi_{\mathrm{ref}}(y|x)\exp\!\bigl(\tfrac{1}{\beta}r(x,y)\bigr)}{\sum_{y'}\pi_{\mathrm{ref}}(y'|x)\exp\!\bigl(\tfrac{1}{\beta}r(x,y')\bigr)}π(yx)=yπref(yx)exp(β1r(x,y))πref(yx)exp(β1r(x,y))


✅ 4 反解奖励函数(得到公式 2)

把上式两边取对数并乘 (β\betaβ):

βlog⁡π∗(y∣x)πref(y∣x)=βlog⁡1Z(x)+r(x,y)\beta\log\frac{\pi^*(y|x)}{\pi_{\mathrm{ref}}(y|x)} =\beta\log\frac{1}{Z(x)}+r(x,y)βlogπref(yx)π(yx)=βlogZ(x)1+r(x,y)

其中

Z(x)=∑y′πref(y′∣x)exp⁡ ⁣(1βr(x,y′))Z(x)=\sum_{y'}\pi_{\mathrm{ref}}(y'|x)\exp\!\bigl(\tfrac{1}{\beta}r(x,y')\bigr)Z(x)=yπref(yx)exp(β1r(x,y))

因此

r(x,y)=βlog⁡π∗(y∣x)πref(y∣x)+βlog⁡Z(x)r(x,y)=\beta\log\frac{\pi^*(y|x)}{\pi_{\mathrm{ref}}(y|x)}+\beta\log Z(x)r(x,y)=βlogπref(yx)π(yx)+βlogZ(x)

这就是公式 (2)。

http://www.dtcms.com/a/336443.html

相关文章:

  • Android中使用Compose实现各种样式Dialog
  • tcp会无限次重传吗
  • Eclipse Tomcat Configuration
  • Portkey-AI gateway 的一次“假压缩头”翻车的完整排障记:由 httpx 解压异常引发的根因分析
  • 学习日志36 python
  • 力扣经典算法篇-52-零钱兑换(动态规划)
  • Java语法进阶之常用类
  • 【C2000】德州仪器C2000产品整体介绍
  • http工作流程
  • LangChain 多任务应用开发
  • matlab tlc的文件、字符串操作
  • Python @staticmethod 装饰器与 staticmethod() 函数
  • Tomcat Session Replication Cluster:实现高可用性和可扩展性的关键
  • 机试备考笔记 14/31
  • Ugit使用记录
  • Next.js跟React关系(Next.js是基于React库的全栈框架)(文件系统路由、服务端渲染SSR、静态生成SSG、增量静态再生ISR、API路由)
  • 提升 LLM 推理效率的秘密武器:LM Cache 架构与实践
  • Pandas初学者入门
  • C语言中回调函数的作用
  • 2025.8.11-2025.8.17第33周:完成第一次头马备稿演讲
  • 北京JAVA基础面试30天打卡12
  • 【URP】[法线贴图]为什么主要是蓝色的?
  • ZipList优缺点总结
  • leetcode_438 找到字符串中的所有异位词
  • 代码随想录刷题Day34
  • 上位机知识篇---静态库
  • 计算机网络 TCP 延迟确认机制
  • SpringCloud 01 分布式系统
  • 自由学习记录(85)
  • 【k8s、docker】Headless Service(无头服务)