知识蒸馏 Knowledge Distillation 0. 基础:自回归分解与逐 token散度
知识蒸馏 Knowledge Distillation 0. 基础:自回归分解与逐 token散度
代码实践
论文 Generalized Knowledge Distillation (GKD)
On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes
序列级散度的逐 token 分解公式 或 token-level distillation loss。
在论文语境下,它是 教师分布与学生分布在序列 y 上的散度定义。
联合概率、条件概率、边缘概率
乘法法则、全概率公式、贝叶斯定理
概率链式法则(Probability Chain Rule)
序列的联合概率 分解成 基于历史的条件概率的连乘序列
公式
D(pT∥pSθ)(y∣x):=1Ly∑n=1LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T \parallel p_S^{\theta}\right)(y|x) \;:=\; \frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right) D(pT∥pSθ)(y∣x):=Ly1n=1∑LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))
符号 :=
在数学、逻辑和计算机科学中,符号 “:=” 表示 “定义为”(defined as),它的作用是给左边的新符号/表达式赋予一个明确的含义,强调“左边的内容是由右边的内容来定义的”。
为什么不用“=”(等于号)?
等于号“=”表示的是**“等价关系”**——即两边的表达式在数值、逻辑或含义上是相等的,且这种相等是基于已有的定义推导出来的。例如:
- “2+3=52+3=52+3=5” 表示“2+32+32+3的结果等价于555”(基于加法的已有定义);
- “x=2yx=2yx=2y” 表示“xxx和2y2y2y在数值上相等”(xxx和yyy的含义已明确)。
而“:=”则专门用于引入新符号并规定其含义。它强调:
左边的符号是第一次出现的新符号,需要通过右边的内容来“创造”它的含义。
公式中:
D(pT∥pSθ)(y∣x):=1Ly∑n=1LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T \parallel p_S^{\theta}\right)(y|x) \;:=\; \frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right) D(pT∥pSθ)(y∣x):=Ly1n=1∑LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))
这里左边的 D(pT∥pSθ)(y∣x)D(p_T \parallel p_S^\theta)(y|x)D(pT∥pSθ)(y∣x) 是一个新引入的符号(表示“给定xxx和yyy时,教师与学生分布的序列级散度”),之前没有被定义过。因此用“:=”明确:“把左边这个新符号定义为右边的平均散度表达式”。
如果换成“=”,会产生歧义:读者可能会误以为 D(pT∥pSθ)(y∣x)D(p_T \parallel p_S^\theta)(y|x)D(pT∥pSθ)(y∣x) 是一个已经被定义过的符号,现在只是在陈述它和右边表达式“相等”——但实际上这个符号是第一次出现,需要明确它的定义。
符号 ∥\parallel∥
公式中的 ∥\parallel∥ 是散度(divergence)表示中的专用符号,用于分隔两个需要比较的概率分布,表示“衡量前者与后者之间的差异”。
在概率统计和信息论中,当讨论两个概率分布 ppp 和 qqq 的散度(比如KL散度、JS散度等)时,通常写作 D(p∥q)D(p \parallel q)D(p∥q),这里的∥\parallel∥可以理解为“相对于”或“与…比较”,整个表达式的含义是“分布 ppp 相对于分布 qqq 的散度”。
在 D(pT∥pSθ)D\!\left(p_T \parallel p_S^{\theta}\right)D(pT∥pSθ) 中:
- pTp_TpT 是教师模型的概率分布,
- pSθp_S^\thetapSθ 是学生模型的概率分布,
- ∥\parallel∥ 分隔了这两个分布,表明散度 DDD 正在衡量“教师分布 pTp_TpT 与学生分布 pSθp_S^\thetapSθ 之间的差异”。
这个符号和几何中“平行”(比如直线 a∥ba \parallel ba∥b)的含义完全无关,它是概率散度中的专用符号,仅用于明确“被比较的两个分布”的顺序(因为散度通常不对称,D(p∥q)≠D(q∥p)D(p \parallel q) \neq D(q \parallel p)D(p∥q)=D(q∥p),所以顺序很重要)。
完整公式
将论文中序列级散度的逐token分解公式,从外层到内层,结合"输入x=x=x=‘用5个词描述苹果:’、输出y=y=y=‘苹果是红色的水果’"的场景来说
按"外层→中层→内层"拆解:
D(pT⏞内层1∥pSθ⏞内层2)(y∣x)⏞外层条件⏟外层:序列级散度:=⏟定义符号1Ly⏟中层:归一化∑n=1Ly⏟中层:求和D(pT(⋅∣y<n⏞内层4,x⏞内层5)⏞内层3∥pSθ(⋅∣y<n,x)⏞内层6)⏟内层:逐token散度\underbrace{D\!\left(\overbrace{p_T}^{\text{内层1}} \parallel \overbrace{p_S^{\theta}}^{\text{内层2}}\right)\overbrace{(y|x)}^{\text{外层条件}}}_{\text{外层:序列级散度}} \;\underbrace{:=\;}_{\text{定义符号}} \;\underbrace{\frac{1}{L_y}}_{\text{中层:归一化}} \underbrace{\sum_{n=1}^{L_y}}_{\text{中层:求和}} \underbrace{D\!\left(\overbrace{p_T(\cdot \mid \overbrace{y_{<n}}^{\text{内层4}}, \overbrace{x}^{\text{内层5}})}^{\text{内层3}} \parallel \overbrace{p_S^{\theta}(\cdot \mid y_{<n},x)}^{\text{内层6}}\right)}_{\text{内层:逐token散度}} 外层:序列级散度DpT内层1∥pSθ内层2(y∣x)外层条件定义符号:=中层:归一化Ly1中层:求和n=1∑Ly内层:逐token散度DpT(⋅∣y<n内层4,x内层5)内层3∥pSθ(⋅∣y<n,x)内层6
层级 | 符号 | 类型 | 核心含义 | 苹果场景对应示例 |
---|---|---|---|---|
外层 | D(⋅∥⋅)D(\cdot \parallel \cdot)D(⋅∥⋅) | 散度运算符 | 衡量两个概率分布的差异(非对称) | D(p教师∥p学生)D(p_{\text{教师}} \parallel p_{\text{学生}})D(p教师∥p学生) |
外层 | (y∣x)(y|x)(y∣x) | 条件限定符 | 限定散度计算在"给定输入xxx和序列yyy"的场景下 | 给定x=x=x="描述苹果:"和y=y=y=“苹果是红色的水果” |
外层 | :=:=:= | 定义符号 | 左边由右边定义 | 序列级散度 :=:=:= 逐token散度的平均 |
中层 | 1Ly\frac{1}{L_y}Ly1 | 归一化因子 | 对散度总和进行平均,消除序列长度偏差 | Ly=5L_y=5Ly=5,故15\frac{1}{5}51 |
中层 | ∑n=1Ly\sum_{n=1}^{L_y}∑n=1Ly | 求和运算符 | 遍历所有token位置,累加逐token散度 | ∑n=15\sum_{n=1}^5∑n=15(累加5个token的散度) |
中层 | nnn | 位置索引 | 表示当前处理的token序号(1到LyL_yLy) | n=3n=3n=3对应第3个token"红色" |
中层 | LyL_yLy | 序列长度 | 输出序列yyy包含的token数量 | yyy有5个token,故Ly=5L_y=5Ly=5 |
内层 | pTp_TpT | 教师分布 | 教师模型的固定概率分布(模仿目标) | 教师对"红色"的概率是0.7 |
内层 | pSθp_S^\thetapSθ | 学生分布 | 学生模型的可训练概率分布(θ\thetaθ是模型参数) | 初始时学生对"红色"的概率是0.4 |
内层 | ⋅\cdot⋅ | 全集占位符 | 代表词表中所有可能的下一个token | “红色”“绿色”"蓝色"等所有描述性token |
内层 | y<ny_{<n}y<n | 序列前缀 | 第nnn个token之前的所有token组成的子序列 | n=3n=3n=3时,y<3=y_{<3}=y<3=“苹果是” |
内层 | xxx | 输入变量 | 任务指令/prompt(预测的前提) | x=x=x=“用5个词描述苹果:” |
内层 | ∣\mid∣ | 条件分隔符 | 表示"在…条件下"(竖线右边是前提) | pT(⋅∣"苹果是",x)p_T(\cdot \mid "苹果是",x)pT(⋅∣"苹果是",x)即"给定前缀和输入" |
逐符号拆解(从外到内)
(一)外层符号:序列级散度的整体框架
-
D(⋅∥⋅)D(\cdot \parallel \cdot)D(⋅∥⋅)
衡量两个概率分布的差异程度,差异越大值越大,分布完全相同时为0。括号内"⋅∥⋅\cdot \parallel \cdot⋅∥⋅"是固定格式,左边为基准分布、右边为待比较分布,且不满足对称性(即D(A∥B)≠D(B∥A)D(A \parallel B) \neq D(B \parallel A)D(A∥B)=D(B∥A))。常用类型有KL散度(衡量学生近似教师的信息损失)、JSD散度(对称且有界),比如在苹果场景中,就是衡量教师对"描述苹果"的预测分布和学生预测分布的差距。 -
D(pT∥pSθ)D\left(p_T \parallel p_S^{\theta}\right)D(pT∥pSθ)
表示教师分布pTp_TpT与学生分布pSθp_S^\thetapSθ之间的散度,是未限定具体序列和输入时的泛化定义。比如泛化地描述"教师对‘描述苹果’的预测,和学生对‘描述苹果’的预测之间的差距",此时还未指定具体生成哪句话。 -
(y∣x)(y|x)(y∣x)
将散度计算限定在"给定输入xxx和输出序列yyy"的场景下,类似条件概率p(y∣x)p(y|x)p(y∣x)(给定xxx生成yyy的概率),这里是"给定xxx和yyy的散度",而非对所有xxx和yyy的期望散度。在苹果场景中,就是"给定输入‘用5个词描述苹果:’、输出序列‘苹果是红色的水果’时",限定了散度的计算范围。 -
:=:=:=
代表"定义",表示左边的符号由右边的公式定义,即"序列级散度"这个概念,被定义为右边"逐token散度求和再平均"的结果。区别于普通等于号===(===表示左右数值相等,:=:=:=表示左边是右边的缩写或定义),比如用"序列级散度 :=:=:= 逐token散度的平均",明确两者的定义关系。
(二)中层符号:序列级到逐token级的拆解
-
1Ly\frac{1}{L_y}Ly1
对"逐token散度的总和"进行归一化,避免长序列散度总和更大、短序列更小导致的训练偏差。其中LyL_yLy是序列yyy的token数量(必须为正整数),在苹果场景中,yyy有5个token,Ly=5L_y=5Ly=5,1Ly\frac{1}{L_y}Ly1就是15\frac{1}{5}51,用于将5个token的散度总和平均为"每个token的平均差距"。 -
∑n=1Ly\sum_{n=1}^{L_y}∑n=1Ly
遍历序列yyy的所有token位置,将每个位置的"逐token散度"相加得到总散度。其中∑\sum∑是求和符号,nnn是位置索引(循环变量),n=1n=1n=1是求和起始位置(对应苹果场景中第一个token"苹果"),LyL_yLy是求和终止位置(对应苹果场景中第五个token"水果")。比如在苹果场景中,就是把第1到第5个token的逐token散度全部加起来。
(三)内层符号:逐token散度的核心细节
内层是公式核心,对应"每个token位置上师生分布的具体差异",即D(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right)D(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))。
-
pTp_TpT
TTT是"Teacher(教师)“的首字母,代表教师模型的概率分布,是训练成熟的大模型,预测结果更准确,作为学生的"模仿目标”,且训练过程中分布固定。在苹果场景中,就是大模型对"描述苹果"的预测分布,比如生成"红色"时,教师给"红色"的概率是0.7,给"蓝色"的概率接近0。 -
pSθp_S^{\theta}pSθ
SSS是"Student(学生)"的首字母,代表待蒸馏的小模型的概率分布;θ\thetaθ(theta)是学生模型的可训练参数(如神经网络的权重、偏置),训练中通过调整θ\thetaθ,让pSθp_S^\thetapSθ逐渐接近pTp_TpT。在苹果场景中,初始时θ\thetaθ未优化,学生给"红色"的概率只有0.4,训练后θ\thetaθ更新,这个概率会接近教师的0.7。 -
⋅\cdot⋅(点号)
代表"词表中所有可能的下一个token",明确散度比较的是"完整的概率分布",而非"某个具体token的概率"。比如在苹果场景中,n=3n=3n=3(前缀"苹果是")时,⋅\cdot⋅代表"红色"“绿色”“蓝色”"圆形"等所有可能描述苹果特征的token,pT(⋅∣"苹果是",x)p_T(\cdot \mid "苹果是",x)pT(⋅∣"苹果是",x)就是教师对这些token的完整概率分布("红色"0.7、"绿色"0.2、"蓝色"0.01…)。 -
y<ny_{<n}y<n
表示"序列yyy的前缀",即第nnn个token之前所有token组成的子序列。其中yyy是当前讨论的输出序列,<n<n<n表示"取位置1到n−1n-1n−1的token"(不包含位置nnn)。在苹果场景中,n=1n=1n=1时y<1y_{<1}y<1是空序列,n=3n=3n=3时y<3y_{<3}y<3是"苹果是",n=5n=5n=5时y<5y_{<5}y<5是"苹果是红色的"。 -
xxx
代表模型的"输入信息",通常是任务指令、prompt或上下文,所有token的预测都依赖于xxx——同样的前缀"苹果是",输入x=x=x="描述苹果:"和x=x=x="描述天空:“时,模型预测的下一个token完全不同。在苹果场景中,xxx就是"用5个词描述苹果:”,是师生生成"苹果是红色的水果"的核心依据。 -
∣\mid∣(竖线)
表示"在…的条件下",竖线右边是"预测的前提",左边是"被预测的对象",对应数学中条件概率的表示逻辑(如p(A∣B)p(A \mid B)p(A∣B)表示"给定BBB发生时AAA发生的概率")。在苹果场景中,pT(⋅∣"苹果是","描述苹果:")p_T(\cdot \mid "苹果是", "描述苹果:")pT(⋅∣"苹果是","描述苹果:")就是"给定输入‘描述苹果:’和前缀‘苹果是’时,教师对所有可能下一个token的概率分布"。
计算例子
- 输入 x="描述苹果:"x = \text{"描述苹果:"}x="描述苹果:"
- 输出序列 y=("苹果", "是", "红色", "的", "水果")y = (\text{"苹果", "是", "红色", "的", "水果"})y=("苹果", "是", "红色", "的", "水果")(长度 Ly=5L_y = 5Ly=5)
- 我们使用KL散度计算,公式为:DKL(p∥q)=∑ipilog(piqi)D_{KL}(p \parallel q) = \sum_{i} p_i \log\left(\frac{p_i}{q_i}\right)DKL(p∥q)=∑ipilog(qipi)
逐步计算过程:
-
第一步:预测第1个token “苹果”
- 前缀:y<1=()y_{<1} = ()y<1=()(空序列)
- 教师分布:pT={"苹果":0.9,"香蕉":0.03,"橙子":0.03,"其他":0.04}p_T = \{\text{"苹果"}:0.9, \text{"香蕉"}:0.03, \text{"橙子"}:0.03, \text{"其他"}:0.04\}pT={"苹果":0.9,"香蕉":0.03,"橙子":0.03,"其他":0.04}
- 学生分布:pS={"苹果":0.6,"香蕉":0.15,"橙子":0.15,"其他":0.1}p_S = \{\text{"苹果"}:0.6, \text{"香蕉"}:0.15, \text{"橙子"}:0.15, \text{"其他"}:0.1\}pS={"苹果":0.6,"香蕉":0.15,"橙子":0.15,"其他":0.1}
- 第1个位置的KL散度计算:
D1=0.9log(0.90.6)+0.03log(0.030.15)+0.03log(0.030.15)+0.04log(0.040.1)=0.9×0.405+0.03×(−1.609)+0.03×(−1.609)+0.04×(−0.916)=0.3645−0.0483−0.0483−0.0366≈0.231\begin{align*} D_1 &= 0.9\log\left(\frac{0.9}{0.6}\right) + 0.03\log\left(\frac{0.03}{0.15}\right) + 0.03\log\left(\frac{0.03}{0.15}\right) + 0.04\log\left(\frac{0.04}{0.1}\right) \\ &= 0.9 \times 0.405 + 0.03 \times (-1.609) + 0.03 \times (-1.609) + 0.04 \times (-0.916) \\ &= 0.3645 - 0.0483 - 0.0483 - 0.0366 \\ &\approx 0.231 \end{align*} D1=0.9log(0.60.9)+0.03log(0.150.03)+0.03log(0.150.03)+0.04log(0.10.04)=0.9×0.405+0.03×(−1.609)+0.03×(−1.609)+0.04×(−0.916)=0.3645−0.0483−0.0483−0.0366≈0.231
-
第二步:预测第2个token “是”
- 前缀:y<2=("苹果")y_{<2} = (\text{"苹果"})y<2=("苹果")
- 教师分布:pT={"是":0.8,"有":0.1,"像":0.05,"其他":0.05}p_T = \{\text{"是"}:0.8, \text{"有"}:0.1, \text{"像"}:0.05, \text{"其他"}:0.05\}pT={"是":0.8,"有":0.1,"像":0.05,"其他":0.05}
- 学生分布:pS={"是":0.5,"有":0.25,"像":0.15,"其他":0.1}p_S = \{\text{"是"}:0.5, \text{"有"}:0.25, \text{"像"}:0.15, \text{"其他"}:0.1\}pS={"是":0.5,"有":0.25,"像":0.15,"其他":0.1}
- 第2个位置的KL散度计算:
D2=0.8log(0.80.5)+0.1log(0.10.25)+0.05log(0.050.15)+0.05log(0.050.1)=0.8×0.470+0.1×(−0.916)+0.05×(−1.098)+0.05×(−0.693)=0.376−0.0916−0.0549−0.0347≈0.195\begin{align*} D_2 &= 0.8\log\left(\frac{0.8}{0.5}\right) + 0.1\log\left(\frac{0.1}{0.25}\right) + 0.05\log\left(\frac{0.05}{0.15}\right) + 0.05\log\left(\frac{0.05}{0.1}\right) \\ &= 0.8 \times 0.470 + 0.1 \times (-0.916) + 0.05 \times (-1.098) + 0.05 \times (-0.693) \\ &= 0.376 - 0.0916 - 0.0549 - 0.0347 \\ &\approx 0.195 \end{align*} D2=0.8log(0.50.8)+0.1log(0.250.1)+0.05log(0.150.05)+0.05log(0.10.05)=0.8×0.470+0.1×(−0.916)+0.05×(−1.098)+0.05×(−0.693)=0.376−0.0916−0.0549−0.0347≈0.195
-
第三步:预测第3个token “红色”
- 前缀:y<3=("苹果", "是")y_{<3} = (\text{"苹果", "是"})y<3=("苹果", "是")
- 教师分布:pT={"红色":0.7,"绿色":0.2,"圆形":0.05,"其他":0.05}p_T = \{\text{"红色"}:0.7, \text{"绿色"}:0.2, \text{"圆形"}:0.05, \text{"其他"}:0.05\}pT={"红色":0.7,"绿色":0.2,"圆形":0.05,"其他":0.05}
- 学生分布:pS={"红色":0.4,"绿色":0.1,"蓝色":0.3,"其他":0.2}p_S = \{\text{"红色"}:0.4, \text{"绿色"}:0.1, \text{"蓝色"}:0.3, \text{"其他"}:0.2\}pS={"红色":0.4,"绿色":0.1,"蓝色":0.3,"其他":0.2}
- 第3个位置的KL散度计算:
D3=0.7log(0.70.4)+0.2log(0.20.1)+0.05log(0.050)+0.05log(0.050.2)(注意:学生分布中"圆形"概率为0,实际计算中会用极小值ε代替避免log(∞))≈0.7×0.559+0.2×0.693+0.05×5+0.05×(−1.386)≈0.391+0.139+0.25−0.069≈0.711\begin{align*} D_3 &= 0.7\log\left(\frac{0.7}{0.4}\right) + 0.2\log\left(\frac{0.2}{0.1}\right) + 0.05\log\left(\frac{0.05}{0}\right) + 0.05\log\left(\frac{0.05}{0.2}\right) \\ &\quad \text{(注意:学生分布中"圆形"概率为0,实际计算中会用极小值ε代替避免log(∞))} \\ &\approx 0.7 \times 0.559 + 0.2 \times 0.693 + 0.05 \times 5 + 0.05 \times (-1.386) \\ &\approx 0.391 + 0.139 + 0.25 - 0.069 \\ &\approx 0.711 \end{align*} D3=0.7log(0.40.7)+0.2log(0.10.2)+0.05log(00.05)+0.05log(0.20.05)(注意:学生分布中"圆形"概率为0,实际计算中会用极小值ε代替避免log(∞))≈0.7×0.559+0.2×0.693+0.05×5+0.05×(−1.386)≈0.391+0.139+0.25−0.069≈0.711
-
第四步:预测第4个token “的”
- 前缀:y<4=("苹果", "是", "红色")y_{<4} = (\text{"苹果", "是", "红色"})y<4=("苹果", "是", "红色")
- 教师分布:pT={"的":0.9,"很":0.05,"非常":0.03,"其他":0.02}p_T = \{\text{"的"}:0.9, \text{"很"}:0.05, \text{"非常"}:0.03, \text{"其他"}:0.02\}pT={"的":0.9,"很":0.05,"非常":0.03,"其他":0.02}
- 学生分布:pS={"的":0.7,"很":0.15,"非常":0.1,"其他":0.05}p_S = \{\text{"的"}:0.7, \text{"很"}:0.15, \text{"非常"}:0.1, \text{"其他"}:0.05\}pS={"的":0.7,"很":0.15,"非常":0.1,"其他":0.05}
- 第4个位置的KL散度计算:
D4=0.9log(0.90.7)+0.05log(0.050.15)+0.03log(0.030.1)+0.02log(0.020.05)=0.9×0.251+0.05×(−1.098)+0.03×(−1.203)+0.02×(−0.916)=0.226−0.055−0.036−0.018≈0.117\begin{align*} D_4 &= 0.9\log\left(\frac{0.9}{0.7}\right) + 0.05\log\left(\frac{0.05}{0.15}\right) + 0.03\log\left(\frac{0.03}{0.1}\right) + 0.02\log\left(\frac{0.02}{0.05}\right) \\ &= 0.9 \times 0.251 + 0.05 \times (-1.098) + 0.03 \times (-1.203) + 0.02 \times (-0.916) \\ &= 0.226 - 0.055 - 0.036 - 0.018 \\ &\approx 0.117 \end{align*} D4=0.9log(0.70.9)+0.05log(0.150.05)+0.03log(0.10.03)+0.02log(0.050.02)=0.9×0.251+0.05×(−1.098)+0.03×(−1.203)+0.02×(−0.916)=0.226−0.055−0.036−0.018≈0.117
-
第五步:预测第5个token “水果”
- 前缀:y<5=("苹果", "是", "红色", "的")y_{<5} = (\text{"苹果", "是", "红色", "的"})y<5=("苹果", "是", "红色", "的")
- 教师分布:pT={"水果":0.85,"食物":0.1,"东西":0.03,"其他":0.02}p_T = \{\text{"水果"}:0.85, \text{"食物"}:0.1, \text{"东西"}:0.03, \text{"其他"}:0.02\}pT={"水果":0.85,"食物":0.1,"东西":0.03,"其他":0.02}
- 学生分布:pS={"水果":0.65,"食物":0.2,"东西":0.1,"其他":0.05}p_S = \{\text{"水果"}:0.65, \text{"食物"}:0.2, \text{"东西"}:0.1, \text{"其他"}:0.05\}pS={"水果":0.65,"食物":0.2,"东西":0.1,"其他":0.05}
- 第5个位置的KL散度计算:
D5=0.85log(0.850.65)+0.1log(0.10.2)+0.03log(0.030.1)+0.02log(0.020.05)=0.85×0.274+0.1×(−0.693)+0.03×(−1.203)+0.02×(−0.916)=0.233−0.069−0.036−0.018≈0.110\begin{align*} D_5 &= 0.85\log\left(\frac{0.85}{0.65}\right) + 0.1\log\left(\frac{0.1}{0.2}\right) + 0.03\log\left(\frac{0.03}{0.1}\right) + 0.02\log\left(\frac{0.02}{0.05}\right) \\ &= 0.85 \times 0.274 + 0.1 \times (-0.693) + 0.03 \times (-1.203) + 0.02 \times (-0.916) \\ &= 0.233 - 0.069 - 0.036 - 0.018 \\ &\approx 0.110 \end{align*} D5=0.85log(0.650.85)+0.1log(0.20.1)+0.03log(0.10.03)+0.02log(0.050.02)=0.85×0.274+0.1×(−0.693)+0.03×(−1.203)+0.02×(−0.916)=0.233−0.069−0.036−0.018≈0.110
最终序列级散度计算:
D(pT∥pSθ)(y∣x)=15(D1+D2+D3+D4+D5)=15(0.231+0.195+0.711+0.117+0.110)≈15×1.364=0.273D\left(p_T \parallel p_S^\theta\right)(y|x) = \frac{1}{5} \left(D_1 + D_2 + D_3 + D_4 + D_5\right) = \frac{1}{5} \left(0.231 + 0.195 + 0.711 + 0.117 + 0.110\right) \approx \frac{1}{5} \times 1.364 = 0.273 D(pT∥pSθ)(y∣x)=51(D1+D2+D3+D4+D5)=51(0.231+0.195+0.711+0.117+0.110)≈51×1.364=0.273
这个结果(0.273)表示在描述苹果这个任务中,学生模型与教师模型的整体差距。在训练过程中,我们会通过优化学生模型参数θ来最小化这个值,使学生的预测分布尽可能接近教师的分布。
公式与例子的对应关系:
-
外层结构 D(pT∥pSθ)(y∣x)D\!\left(p_T \parallel p_S^{\theta}\right)(y|x)D(pT∥pSθ)(y∣x)
对应例子中“最终序列级散度”的结果(≈0.273),表示“给定输入x=x=x=‘描述苹果:’和输出序列y=y=y=‘苹果是红色的水果’时,教师分布与学生分布的整体差距”。 -
1Ly\frac{1}{L_y}Ly1
例子中Ly=5L_y=5Ly=5(序列yyy有5个token),对应15\frac{1}{5}51,用于对5个位置的散度求和后做平均。 -
∑n=1Ly\sum_{n=1}^{L_y}∑n=1Ly
例子中是对n=1n=1n=1到n=5n=5n=5的散度求和(D1+D2+D3+D4+D5D_1 + D_2 + D_3 + D_4 + D_5D1+D2+D3+D4+D5),对应公式中“遍历所有token位置累加散度”的逻辑。 -
内层散度 D(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right)D(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))
例子中每个DnD_nDn(如D1D_1D1到D5D_5D5)分别对应:- n=1n=1n=1时:教师和学生在“空前缀+输入xxx”条件下的分布差异;
- n=2n=2n=2时:教师和学生在“前缀‘苹果’+输入xxx”条件下的分布差异;
- 以此类推,完全匹配公式中“逐token条件分布散度”的定义。
论文里的公式用如下表示就简单些
原来是
D(pT∥pSθ)(y∣x):=1Ly∑n=1LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T \parallel p_S^{\theta}\right)(y|x) \;:=\; \frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right) D(pT∥pSθ)(y∣x):=Ly1n=1∑LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))
更改后
LKD=1L∑n=1LDKL(pT(n)∥pS(n))\mathcal{L}_{KD} = \frac{1}{L} \sum_{n=1}^L D_{KL}(p_T(n) \parallel p_S(n))LKD=L1n=1∑LDKL(pT(n)∥pS(n))
符号说明:
- LKD\mathcal{L}_{KD}LKD:知识蒸馏(Knowledge Distillation)的总损失
- LLL:序列长度(token总数)
- nnn:第nnn个token的位置
- pT(n)p_T(n)pT(n):教师模型在第nnn步(给定前缀和输入)的token分布
- pS(n)p_S(n)pS(n):学生模型在第nnn步的token分布
- DKL(⋅∥⋅)D_{KL}(\cdot \parallel \cdot)DKL(⋅∥⋅):KL散度(衡量两分布差异)
蒸馏损失 = 每个token位置的师生分布差异的平均值
自回归分解的概率逻辑逐token散度的损失逻辑
"自回归分解"是序列概率确实是按token拆的
自回归模型)的假设是:序列的联合概率能拆解成"逐token条件概率的乘积"。
“苹果例子”(输入x="Describe: Apple"x=\text{"Describe: Apple"}x="Describe: Apple",输出序列y=("A","p","p","l","e")y=(\text{"A","p","p","l","e"})y=("A","p","p","l","e"),长度Ly=5L_y=5Ly=5)为例,自回归分解的数学表达是:
p(y∣x)=p(y1∣x)×p(y2∣y<2,x)×p(y3∣y<3,x)×p(y4∣y<4,x)×p(y5∣y<5,x)p(y|x) = p(y_1 \mid x) \times p(y_2 \mid y_{<2}, x) \times p(y_3 \mid y_{<3}, x) \times p(y_4 \mid y_{<4}, x) \times p(y_5 \mid y_{<5}, x)p(y∣x)=p(y1∣x)×p(y2∣y<2,x)×p(y3∣y<3,x)×p(y4∣y<4,x)×p(y5∣y<5,x)
意思是:“输出‘Apple’这个序列的概率”,等于"先预测第一个token‘A’的概率"×"在‘A’之后预测‘p’的概率"×"在‘A,p’之后预测‘p’的概率"……以此类推。这就是"序列概率按token拆"的具体含义。
“逐token散度"是蒸馏损失要跟着概率的拆解"同步拆”
蒸馏是让"学生模型pSθp_S^\thetapSθ的分布尽可能接近教师模型pTp_TpT的分布",衡量这种"接近度"的指标就是散度(比如KL散度)。
但问题是:我们要衡量的是"整个序列分布的接近度",而序列分布已经被自回归拆成了逐token的条件分布——由于散度(尤其是KL散度)具有"对乘积分布的可加性"(简单说:“整体分布的散度”=“各拆解后条件分布的散度之和”),所以蒸馏损失(散度)也能跟着拆成逐token的散度,再求和/平均。
对应到公式里:
公式
D(pT∥pSθ)(y∣x):=1Ly∑n=1LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T \parallel p_S^{\theta}\right)(y|x) \;:=\; \frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right) D(pT∥pSθ)(y∣x):=Ly1n=1∑LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))
或者
LKD=1L∑n=1LDKL(pT(n)∥pS(n))\mathcal{L}_{KD} = \frac{1}{L} \sum_{n=1}^L D_{KL}(p_T(n) \parallel p_S(n))LKD=L1n=1∑LDKL(pT(n)∥pS(n))
- 求和符号∑n=1Ly\sum_{n=1}^{L_y}∑n=1Ly就是"逐token拆":每个nnn对应序列中第nnn个token的位置;
- 求和项D(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\left(p_T(\cdot \mid y_{<n},x) \parallel p_S^\theta(\cdot \mid y_{<n},x)\right)D(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))就是"第nnn个token的蒸馏损失":衡量教师和学生在"基于前n−1n-1n−1个token(y<ny_{<n}y<n)和输入xxx"的条件下,对第nnn个token的概率分布差异;
- 最后除以LyL_yLy是"求平均":避免序列长度不同导致损失规模差异(比如"Apple"(5个token)和"Banana"(6个token)的损失可直接对比)。
正是把"自回归分解的概率逻辑"和"逐token散度的损失逻辑"结合起来的结果——因为序列概率按token拆了,所以衡量分布差异的散度(蒸馏损失)也按token拆了加起来,再平均得到最终的序列级蒸馏损失。
最终序列级散度就是15(D1+D2+D3+D4+D5)\frac{1}{5}(D_1+D_2+D_3+D_4+D_5)51(D1+D2+D3+D4+D5),其中D1D_1D1是"预测‘A’时的散度",D2D_2D2是"预测‘p’(基于‘A’)时的散度",完全跟着自回归的token顺序走。