当前位置: 首页 > news >正文

知识蒸馏 Knowledge Distillation 0. 基础:自回归分解与逐 token散度

知识蒸馏 Knowledge Distillation 0. 基础:自回归分解与逐 token散度

代码实践
论文 Generalized Knowledge Distillation (GKD)
On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes

序列级散度的逐 token 分解公式 或 token-level distillation loss。

在论文语境下,它是 教师分布与学生分布在序列 y 上的散度定义。

联合概率、条件概率、边缘概率

乘法法则、全概率公式、贝叶斯定理

概率链式法则(Probability Chain Rule)

序列的联合概率 分解成 基于历史的条件概率的连乘序列

公式

D⁣(pT∥pSθ)(y∣x):=1Ly∑n=1LyD⁣(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T \parallel p_S^{\theta}\right)(y|x) \;:=\; \frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right) D(pTpSθ)(yx):=Ly1n=1LyD(pT(y<n,x)pSθ(y<n,x))

符号 :=

在数学、逻辑和计算机科学中,符号 “:=” 表示 “定义为”(defined as),它的作用是给左边的新符号/表达式赋予一个明确的含义,强调“左边的内容是由右边的内容来定义的”。

为什么不用“=”(等于号)?

等于号“=”表示的是**“等价关系”**——即两边的表达式在数值、逻辑或含义上是相等的,且这种相等是基于已有的定义推导出来的。例如:

  • 2+3=52+3=52+3=5” 表示“2+32+32+3的结果等价于555”(基于加法的已有定义);
  • x=2yx=2yx=2y” 表示“xxx2y2y2y在数值上相等”(xxxyyy的含义已明确)。

而“:=”则专门用于引入新符号并规定其含义。它强调:
左边的符号是第一次出现的新符号,需要通过右边的内容来“创造”它的含义。

公式中:
D⁣(pT∥pSθ)(y∣x):=1Ly∑n=1LyD⁣(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T \parallel p_S^{\theta}\right)(y|x) \;:=\; \frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right) D(pTpSθ)(yx):=Ly1n=1LyD(pT(y<n,x)pSθ(y<n,x))

这里左边的 D(pT∥pSθ)(y∣x)D(p_T \parallel p_S^\theta)(y|x)D(pTpSθ)(yx) 是一个新引入的符号(表示“给定xxxyyy时,教师与学生分布的序列级散度”),之前没有被定义过。因此用“:=”明确:“把左边这个新符号定义为右边的平均散度表达式”。

如果换成“=”,会产生歧义:读者可能会误以为 D(pT∥pSθ)(y∣x)D(p_T \parallel p_S^\theta)(y|x)D(pTpSθ)(yx) 是一个已经被定义过的符号,现在只是在陈述它和右边表达式“相等”——但实际上这个符号是第一次出现,需要明确它的定义。

符号 ∥\parallel

公式中的 ∥\parallel 是散度(divergence)表示中的专用符号,用于分隔两个需要比较的概率分布,表示“衡量前者与后者之间的差异”。

在概率统计和信息论中,当讨论两个概率分布 pppqqq 的散度(比如KL散度、JS散度等)时,通常写作 D(p∥q)D(p \parallel q)D(pq),这里的∥\parallel可以理解为“相对于”或“与…比较”,整个表达式的含义是“分布 ppp 相对于分布 qqq 的散度”。

D⁣(pT∥pSθ)D\!\left(p_T \parallel p_S^{\theta}\right)D(pTpSθ) 中:

  • pTp_TpT 是教师模型的概率分布,
  • pSθp_S^\thetapSθ 是学生模型的概率分布,
  • ∥\parallel 分隔了这两个分布,表明散度 DDD 正在衡量“教师分布 pTp_TpT 与学生分布 pSθp_S^\thetapSθ 之间的差异”。

这个符号和几何中“平行”(比如直线 a∥ba \parallel bab)的含义完全无关,它是概率散度中的专用符号,仅用于明确“被比较的两个分布”的顺序(因为散度通常不对称,D(p∥q)≠D(q∥p)D(p \parallel q) \neq D(q \parallel p)D(pq)=D(qp),所以顺序很重要)。

完整公式

将论文中序列级散度的逐token分解公式,从外层到内层,结合"输入x=x=x=‘用5个词描述苹果:’、输出y=y=y=‘苹果是红色的水果’"的场景来说

按"外层→中层→内层"拆解:
D⁣(pT⏞内层1∥pSθ⏞内层2)(y∣x)⏞外层条件⏟外层:序列级散度:=⏟定义符号1Ly⏟中层:归一化∑n=1Ly⏟中层:求和D⁣(pT(⋅∣y<n⏞内层4,x⏞内层5)⏞内层3∥pSθ(⋅∣y<n,x)⏞内层6)⏟内层:逐token散度\underbrace{D\!\left(\overbrace{p_T}^{\text{内层1}} \parallel \overbrace{p_S^{\theta}}^{\text{内层2}}\right)\overbrace{(y|x)}^{\text{外层条件}}}_{\text{外层:序列级散度}} \;\underbrace{:=\;}_{\text{定义符号}} \;\underbrace{\frac{1}{L_y}}_{\text{中层:归一化}} \underbrace{\sum_{n=1}^{L_y}}_{\text{中层:求和}} \underbrace{D\!\left(\overbrace{p_T(\cdot \mid \overbrace{y_{<n}}^{\text{内层4}}, \overbrace{x}^{\text{内层5}})}^{\text{内层3}} \parallel \overbrace{p_S^{\theta}(\cdot \mid y_{<n},x)}^{\text{内层6}}\right)}_{\text{内层:逐token散度}} 外层:序列级散度DpT内层1pSθ内层2(yx)外层条件定义符号:=中层:归一化Ly1中层:求和n=1Ly内层:逐token散度DpT(y<n内层4,x内层5)内层3pSθ(y<n,x)内层6

层级符号类型核心含义苹果场景对应示例
外层D(⋅∥⋅)D(\cdot \parallel \cdot)D()散度运算符衡量两个概率分布的差异(非对称)D(p教师∥p学生)D(p_{\text{教师}} \parallel p_{\text{学生}})D(p教师p学生)
外层(y∣x)(y|x)(yx)条件限定符限定散度计算在"给定输入xxx和序列yyy"的场景下给定x=x=x="描述苹果:"和y=y=y=“苹果是红色的水果”
外层:=:=:=定义符号左边由右边定义序列级散度 :=:=:= 逐token散度的平均
中层1Ly\frac{1}{L_y}Ly1归一化因子对散度总和进行平均,消除序列长度偏差Ly=5L_y=5Ly=5,故15\frac{1}{5}51
中层∑n=1Ly\sum_{n=1}^{L_y}n=1Ly求和运算符遍历所有token位置,累加逐token散度∑n=15\sum_{n=1}^5n=15(累加5个token的散度)
中层nnn位置索引表示当前处理的token序号(1到LyL_yLyn=3n=3n=3对应第3个token"红色"
中层LyL_yLy序列长度输出序列yyy包含的token数量yyy有5个token,故Ly=5L_y=5Ly=5
内层pTp_TpT教师分布教师模型的固定概率分布(模仿目标)教师对"红色"的概率是0.7
内层pSθp_S^\thetapSθ学生分布学生模型的可训练概率分布(θ\thetaθ是模型参数)初始时学生对"红色"的概率是0.4
内层⋅\cdot全集占位符代表词表中所有可能的下一个token“红色”“绿色”"蓝色"等所有描述性token
内层y<ny_{<n}y<n序列前缀nnn个token之前的所有token组成的子序列n=3n=3n=3时,y<3=y_{<3}=y<3=“苹果是”
内层xxx输入变量任务指令/prompt(预测的前提)x=x=x=“用5个词描述苹果:”
内层∣\mid条件分隔符表示"在…条件下"(竖线右边是前提)pT(⋅∣"苹果是",x)p_T(\cdot \mid "苹果是",x)pT("苹果是",x)即"给定前缀和输入"

逐符号拆解(从外到内)

(一)外层符号:序列级散度的整体框架

  1. D(⋅∥⋅)D(\cdot \parallel \cdot)D()
    衡量两个概率分布的差异程度,差异越大值越大,分布完全相同时为0。括号内"⋅∥⋅\cdot \parallel \cdot"是固定格式,左边为基准分布、右边为待比较分布,且不满足对称性(即D(A∥B)≠D(B∥A)D(A \parallel B) \neq D(B \parallel A)D(AB)=D(BA))。常用类型有KL散度(衡量学生近似教师的信息损失)、JSD散度(对称且有界),比如在苹果场景中,就是衡量教师对"描述苹果"的预测分布和学生预测分布的差距。

  2. D(pT∥pSθ)D\left(p_T \parallel p_S^{\theta}\right)D(pTpSθ)
    表示教师分布pTp_TpT与学生分布pSθp_S^\thetapSθ之间的散度,是未限定具体序列和输入时的泛化定义。比如泛化地描述"教师对‘描述苹果’的预测,和学生对‘描述苹果’的预测之间的差距",此时还未指定具体生成哪句话。

  3. (y∣x)(y|x)(yx)
    将散度计算限定在"给定输入xxx和输出序列yyy"的场景下,类似条件概率p(y∣x)p(y|x)p(yx)(给定xxx生成yyy的概率),这里是"给定xxxyyy的散度",而非对所有xxxyyy的期望散度。在苹果场景中,就是"给定输入‘用5个词描述苹果:’、输出序列‘苹果是红色的水果’时",限定了散度的计算范围。

  4. :=:=:=
    代表"定义",表示左边的符号由右边的公式定义,即"序列级散度"这个概念,被定义为右边"逐token散度求和再平均"的结果。区别于普通等于号======表示左右数值相等,:=:=:=表示左边是右边的缩写或定义),比如用"序列级散度 :=:=:= 逐token散度的平均",明确两者的定义关系。

(二)中层符号:序列级到逐token级的拆解

  1. 1Ly\frac{1}{L_y}Ly1
    对"逐token散度的总和"进行归一化,避免长序列散度总和更大、短序列更小导致的训练偏差。其中LyL_yLy是序列yyy的token数量(必须为正整数),在苹果场景中,yyy有5个token,Ly=5L_y=5Ly=51Ly\frac{1}{L_y}Ly1就是15\frac{1}{5}51,用于将5个token的散度总和平均为"每个token的平均差距"。

  2. ∑n=1Ly\sum_{n=1}^{L_y}n=1Ly
    遍历序列yyy的所有token位置,将每个位置的"逐token散度"相加得到总散度。其中∑\sum是求和符号,nnn是位置索引(循环变量),n=1n=1n=1是求和起始位置(对应苹果场景中第一个token"苹果"),LyL_yLy是求和终止位置(对应苹果场景中第五个token"水果")。比如在苹果场景中,就是把第1到第5个token的逐token散度全部加起来。

(三)内层符号:逐token散度的核心细节

内层是公式核心,对应"每个token位置上师生分布的具体差异",即D⁣(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right)D(pT(y<n,x)pSθ(y<n,x))

  1. pTp_TpT
    TTT是"Teacher(教师)“的首字母,代表教师模型的概率分布,是训练成熟的大模型,预测结果更准确,作为学生的"模仿目标”,且训练过程中分布固定。在苹果场景中,就是大模型对"描述苹果"的预测分布,比如生成"红色"时,教师给"红色"的概率是0.7,给"蓝色"的概率接近0。

  2. pSθp_S^{\theta}pSθ
    SSS是"Student(学生)"的首字母,代表待蒸馏的小模型的概率分布;θ\thetaθ(theta)是学生模型的可训练参数(如神经网络的权重、偏置),训练中通过调整θ\thetaθ,让pSθp_S^\thetapSθ逐渐接近pTp_TpT。在苹果场景中,初始时θ\thetaθ未优化,学生给"红色"的概率只有0.4,训练后θ\thetaθ更新,这个概率会接近教师的0.7。

  3. ⋅\cdot(点号)
    代表"词表中所有可能的下一个token",明确散度比较的是"完整的概率分布",而非"某个具体token的概率"。比如在苹果场景中,n=3n=3n=3(前缀"苹果是")时,⋅\cdot代表"红色"“绿色”“蓝色”"圆形"等所有可能描述苹果特征的token,pT(⋅∣"苹果是",x)p_T(\cdot \mid "苹果是",x)pT("苹果是",x)就是教师对这些token的完整概率分布("红色"0.7、"绿色"0.2、"蓝色"0.01…)。

  4. y<ny_{<n}y<n
    表示"序列yyy的前缀",即第nnn个token之前所有token组成的子序列。其中yyy是当前讨论的输出序列,<n<n<n表示"取位置1到n−1n-1n1的token"(不包含位置nnn)。在苹果场景中,n=1n=1n=1y<1y_{<1}y<1是空序列,n=3n=3n=3y<3y_{<3}y<3是"苹果是",n=5n=5n=5y<5y_{<5}y<5是"苹果是红色的"。

  5. xxx
    代表模型的"输入信息",通常是任务指令、prompt或上下文,所有token的预测都依赖于xxx——同样的前缀"苹果是",输入x=x=x="描述苹果:"和x=x=x="描述天空:“时,模型预测的下一个token完全不同。在苹果场景中,xxx就是"用5个词描述苹果:”,是师生生成"苹果是红色的水果"的核心依据。

  6. ∣\mid(竖线)
    表示"在…的条件下",竖线右边是"预测的前提",左边是"被预测的对象",对应数学中条件概率的表示逻辑(如p(A∣B)p(A \mid B)p(AB)表示"给定BBB发生时AAA发生的概率")。在苹果场景中,pT(⋅∣"苹果是","描述苹果:")p_T(\cdot \mid "苹果是", "描述苹果:")pT("苹果是","描述苹果:")就是"给定输入‘描述苹果:’和前缀‘苹果是’时,教师对所有可能下一个token的概率分布"。

计算例子

  • 输入 x="描述苹果:"x = \text{"描述苹果:"}x="描述苹果:"
  • 输出序列 y=("苹果", "是", "红色", "的", "水果")y = (\text{"苹果", "是", "红色", "的", "水果"})y=("苹果", "", "红色", "", "水果")(长度 Ly=5L_y = 5Ly=5
  • 我们使用KL散度计算,公式为:DKL(p∥q)=∑ipilog⁡(piqi)D_{KL}(p \parallel q) = \sum_{i} p_i \log\left(\frac{p_i}{q_i}\right)DKL(pq)=ipilog(qipi)

逐步计算过程:

  1. 第一步:预测第1个token “苹果”

    • 前缀:y<1=()y_{<1} = ()y<1=()(空序列)
    • 教师分布:pT={"苹果":0.9,"香蕉":0.03,"橙子":0.03,"其他":0.04}p_T = \{\text{"苹果"}:0.9, \text{"香蕉"}:0.03, \text{"橙子"}:0.03, \text{"其他"}:0.04\}pT={"苹果":0.9,"香蕉":0.03,"橙子":0.03,"其他":0.04}
    • 学生分布:pS={"苹果":0.6,"香蕉":0.15,"橙子":0.15,"其他":0.1}p_S = \{\text{"苹果"}:0.6, \text{"香蕉"}:0.15, \text{"橙子"}:0.15, \text{"其他"}:0.1\}pS={"苹果":0.6,"香蕉":0.15,"橙子":0.15,"其他":0.1}
    • 第1个位置的KL散度计算:
      D1=0.9log⁡(0.90.6)+0.03log⁡(0.030.15)+0.03log⁡(0.030.15)+0.04log⁡(0.040.1)=0.9×0.405+0.03×(−1.609)+0.03×(−1.609)+0.04×(−0.916)=0.3645−0.0483−0.0483−0.0366≈0.231\begin{align*} D_1 &= 0.9\log\left(\frac{0.9}{0.6}\right) + 0.03\log\left(\frac{0.03}{0.15}\right) + 0.03\log\left(\frac{0.03}{0.15}\right) + 0.04\log\left(\frac{0.04}{0.1}\right) \\ &= 0.9 \times 0.405 + 0.03 \times (-1.609) + 0.03 \times (-1.609) + 0.04 \times (-0.916) \\ &= 0.3645 - 0.0483 - 0.0483 - 0.0366 \\ &\approx 0.231 \end{align*} D1=0.9log(0.60.9)+0.03log(0.150.03)+0.03log(0.150.03)+0.04log(0.10.04)=0.9×0.405+0.03×(1.609)+0.03×(1.609)+0.04×(0.916)=0.36450.04830.04830.03660.231
  2. 第二步:预测第2个token “是”

    • 前缀:y<2=("苹果")y_{<2} = (\text{"苹果"})y<2=("苹果")
    • 教师分布:pT={"是":0.8,"有":0.1,"像":0.05,"其他":0.05}p_T = \{\text{"是"}:0.8, \text{"有"}:0.1, \text{"像"}:0.05, \text{"其他"}:0.05\}pT={"":0.8,"":0.1,"":0.05,"其他":0.05}
    • 学生分布:pS={"是":0.5,"有":0.25,"像":0.15,"其他":0.1}p_S = \{\text{"是"}:0.5, \text{"有"}:0.25, \text{"像"}:0.15, \text{"其他"}:0.1\}pS={"":0.5,"":0.25,"":0.15,"其他":0.1}
    • 第2个位置的KL散度计算:
      D2=0.8log⁡(0.80.5)+0.1log⁡(0.10.25)+0.05log⁡(0.050.15)+0.05log⁡(0.050.1)=0.8×0.470+0.1×(−0.916)+0.05×(−1.098)+0.05×(−0.693)=0.376−0.0916−0.0549−0.0347≈0.195\begin{align*} D_2 &= 0.8\log\left(\frac{0.8}{0.5}\right) + 0.1\log\left(\frac{0.1}{0.25}\right) + 0.05\log\left(\frac{0.05}{0.15}\right) + 0.05\log\left(\frac{0.05}{0.1}\right) \\ &= 0.8 \times 0.470 + 0.1 \times (-0.916) + 0.05 \times (-1.098) + 0.05 \times (-0.693) \\ &= 0.376 - 0.0916 - 0.0549 - 0.0347 \\ &\approx 0.195 \end{align*} D2=0.8log(0.50.8)+0.1log(0.250.1)+0.05log(0.150.05)+0.05log(0.10.05)=0.8×0.470+0.1×(0.916)+0.05×(1.098)+0.05×(0.693)=0.3760.09160.05490.03470.195
  3. 第三步:预测第3个token “红色”

    • 前缀:y<3=("苹果", "是")y_{<3} = (\text{"苹果", "是"})y<3=("苹果", "")
    • 教师分布:pT={"红色":0.7,"绿色":0.2,"圆形":0.05,"其他":0.05}p_T = \{\text{"红色"}:0.7, \text{"绿色"}:0.2, \text{"圆形"}:0.05, \text{"其他"}:0.05\}pT={"红色":0.7,"绿色":0.2,"圆形":0.05,"其他":0.05}
    • 学生分布:pS={"红色":0.4,"绿色":0.1,"蓝色":0.3,"其他":0.2}p_S = \{\text{"红色"}:0.4, \text{"绿色"}:0.1, \text{"蓝色"}:0.3, \text{"其他"}:0.2\}pS={"红色":0.4,"绿色":0.1,"蓝色":0.3,"其他":0.2}
    • 第3个位置的KL散度计算:
      D3=0.7log⁡(0.70.4)+0.2log⁡(0.20.1)+0.05log⁡(0.050)+0.05log⁡(0.050.2)(注意:学生分布中"圆形"概率为0,实际计算中会用极小值ε代替避免log(∞))≈0.7×0.559+0.2×0.693+0.05×5+0.05×(−1.386)≈0.391+0.139+0.25−0.069≈0.711\begin{align*} D_3 &= 0.7\log\left(\frac{0.7}{0.4}\right) + 0.2\log\left(\frac{0.2}{0.1}\right) + 0.05\log\left(\frac{0.05}{0}\right) + 0.05\log\left(\frac{0.05}{0.2}\right) \\ &\quad \text{(注意:学生分布中"圆形"概率为0,实际计算中会用极小值ε代替避免log(∞))} \\ &\approx 0.7 \times 0.559 + 0.2 \times 0.693 + 0.05 \times 5 + 0.05 \times (-1.386) \\ &\approx 0.391 + 0.139 + 0.25 - 0.069 \\ &\approx 0.711 \end{align*} D3=0.7log(0.40.7)+0.2log(0.10.2)+0.05log(00.05)+0.05log(0.20.05)(注意:学生分布中"圆形"概率为0,实际计算中会用极小值ε代替避免log(∞)0.7×0.559+0.2×0.693+0.05×5+0.05×(1.386)0.391+0.139+0.250.0690.711
  4. 第四步:预测第4个token “的”

    • 前缀:y<4=("苹果", "是", "红色")y_{<4} = (\text{"苹果", "是", "红色"})y<4=("苹果", "", "红色")
    • 教师分布:pT={"的":0.9,"很":0.05,"非常":0.03,"其他":0.02}p_T = \{\text{"的"}:0.9, \text{"很"}:0.05, \text{"非常"}:0.03, \text{"其他"}:0.02\}pT={"":0.9,"":0.05,"非常":0.03,"其他":0.02}
    • 学生分布:pS={"的":0.7,"很":0.15,"非常":0.1,"其他":0.05}p_S = \{\text{"的"}:0.7, \text{"很"}:0.15, \text{"非常"}:0.1, \text{"其他"}:0.05\}pS={"":0.7,"":0.15,"非常":0.1,"其他":0.05}
    • 第4个位置的KL散度计算:
      D4=0.9log⁡(0.90.7)+0.05log⁡(0.050.15)+0.03log⁡(0.030.1)+0.02log⁡(0.020.05)=0.9×0.251+0.05×(−1.098)+0.03×(−1.203)+0.02×(−0.916)=0.226−0.055−0.036−0.018≈0.117\begin{align*} D_4 &= 0.9\log\left(\frac{0.9}{0.7}\right) + 0.05\log\left(\frac{0.05}{0.15}\right) + 0.03\log\left(\frac{0.03}{0.1}\right) + 0.02\log\left(\frac{0.02}{0.05}\right) \\ &= 0.9 \times 0.251 + 0.05 \times (-1.098) + 0.03 \times (-1.203) + 0.02 \times (-0.916) \\ &= 0.226 - 0.055 - 0.036 - 0.018 \\ &\approx 0.117 \end{align*} D4=0.9log(0.70.9)+0.05log(0.150.05)+0.03log(0.10.03)+0.02log(0.050.02)=0.9×0.251+0.05×(1.098)+0.03×(1.203)+0.02×(0.916)=0.2260.0550.0360.0180.117
  5. 第五步:预测第5个token “水果”

    • 前缀:y<5=("苹果", "是", "红色", "的")y_{<5} = (\text{"苹果", "是", "红色", "的"})y<5=("苹果", "", "红色", "")
    • 教师分布:pT={"水果":0.85,"食物":0.1,"东西":0.03,"其他":0.02}p_T = \{\text{"水果"}:0.85, \text{"食物"}:0.1, \text{"东西"}:0.03, \text{"其他"}:0.02\}pT={"水果":0.85,"食物":0.1,"东西":0.03,"其他":0.02}
    • 学生分布:pS={"水果":0.65,"食物":0.2,"东西":0.1,"其他":0.05}p_S = \{\text{"水果"}:0.65, \text{"食物"}:0.2, \text{"东西"}:0.1, \text{"其他"}:0.05\}pS={"水果":0.65,"食物":0.2,"东西":0.1,"其他":0.05}
    • 第5个位置的KL散度计算:
      D5=0.85log⁡(0.850.65)+0.1log⁡(0.10.2)+0.03log⁡(0.030.1)+0.02log⁡(0.020.05)=0.85×0.274+0.1×(−0.693)+0.03×(−1.203)+0.02×(−0.916)=0.233−0.069−0.036−0.018≈0.110\begin{align*} D_5 &= 0.85\log\left(\frac{0.85}{0.65}\right) + 0.1\log\left(\frac{0.1}{0.2}\right) + 0.03\log\left(\frac{0.03}{0.1}\right) + 0.02\log\left(\frac{0.02}{0.05}\right) \\ &= 0.85 \times 0.274 + 0.1 \times (-0.693) + 0.03 \times (-1.203) + 0.02 \times (-0.916) \\ &= 0.233 - 0.069 - 0.036 - 0.018 \\ &\approx 0.110 \end{align*} D5=0.85log(0.650.85)+0.1log(0.20.1)+0.03log(0.10.03)+0.02log(0.050.02)=0.85×0.274+0.1×(0.693)+0.03×(1.203)+0.02×(0.916)=0.2330.0690.0360.0180.110

最终序列级散度计算:

D(pT∥pSθ)(y∣x)=15(D1+D2+D3+D4+D5)=15(0.231+0.195+0.711+0.117+0.110)≈15×1.364=0.273D\left(p_T \parallel p_S^\theta\right)(y|x) = \frac{1}{5} \left(D_1 + D_2 + D_3 + D_4 + D_5\right) = \frac{1}{5} \left(0.231 + 0.195 + 0.711 + 0.117 + 0.110\right) \approx \frac{1}{5} \times 1.364 = 0.273 D(pTpSθ)(yx)=51(D1+D2+D3+D4+D5)=51(0.231+0.195+0.711+0.117+0.110)51×1.364=0.273

这个结果(0.273)表示在描述苹果这个任务中,学生模型与教师模型的整体差距。在训练过程中,我们会通过优化学生模型参数θ来最小化这个值,使学生的预测分布尽可能接近教师的分布。

公式与例子的对应关系:

  1. 外层结构 D⁣(pT∥pSθ)(y∣x)D\!\left(p_T \parallel p_S^{\theta}\right)(y|x)D(pTpSθ)(yx)
    对应例子中“最终序列级散度”的结果(≈0.273),表示“给定输入x=x=x=‘描述苹果:’和输出序列y=y=y=‘苹果是红色的水果’时,教师分布与学生分布的整体差距”。

  2. 1Ly\frac{1}{L_y}Ly1
    例子中Ly=5L_y=5Ly=5(序列yyy有5个token),对应15\frac{1}{5}51,用于对5个位置的散度求和后做平均。

  3. ∑n=1Ly\sum_{n=1}^{L_y}n=1Ly
    例子中是对n=1n=1n=1n=5n=5n=5的散度求和(D1+D2+D3+D4+D5D_1 + D_2 + D_3 + D_4 + D_5D1+D2+D3+D4+D5),对应公式中“遍历所有token位置累加散度”的逻辑。

  4. 内层散度 D⁣(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right)D(pT(y<n,x)pSθ(y<n,x))
    例子中每个DnD_nDn(如D1D_1D1D5D_5D5)分别对应:

    • n=1n=1n=1时:教师和学生在“空前缀+输入xxx”条件下的分布差异;
    • n=2n=2n=2时:教师和学生在“前缀‘苹果’+输入xxx”条件下的分布差异;
    • 以此类推,完全匹配公式中“逐token条件分布散度”的定义。

论文里的公式用如下表示就简单些

原来是
D⁣(pT∥pSθ)(y∣x):=1Ly∑n=1LyD⁣(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T \parallel p_S^{\theta}\right)(y|x) \;:=\; \frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right) D(pTpSθ)(yx):=Ly1n=1LyD(pT(y<n,x)pSθ(y<n,x))
更改后

LKD=1L∑n=1LDKL(pT(n)∥pS(n))\mathcal{L}_{KD} = \frac{1}{L} \sum_{n=1}^L D_{KL}(p_T(n) \parallel p_S(n))LKD=L1n=1LDKL(pT(n)pS(n))

符号说明:

  • LKD\mathcal{L}_{KD}LKD:知识蒸馏(Knowledge Distillation)的总损失
  • LLL:序列长度(token总数)
  • nnn:第nnn个token的位置
  • pT(n)p_T(n)pT(n):教师模型在第nnn步(给定前缀和输入)的token分布
  • pS(n)p_S(n)pS(n):学生模型在第nnn步的token分布
  • DKL(⋅∥⋅)D_{KL}(\cdot \parallel \cdot)DKL():KL散度(衡量两分布差异)

蒸馏损失 = 每个token位置的师生分布差异的平均值

自回归分解的概率逻辑逐token散度的损失逻辑

"自回归分解"是序列概率确实是按token拆的

自回归模型)的假设是:序列的联合概率能拆解成"逐token条件概率的乘积"
“苹果例子”(输入x="Describe: Apple"x=\text{"Describe: Apple"}x="Describe: Apple",输出序列y=("A","p","p","l","e")y=(\text{"A","p","p","l","e"})y=("A","p","p","l","e"),长度Ly=5L_y=5Ly=5)为例,自回归分解的数学表达是:
p(y∣x)=p(y1∣x)×p(y2∣y<2,x)×p(y3∣y<3,x)×p(y4∣y<4,x)×p(y5∣y<5,x)p(y|x) = p(y_1 \mid x) \times p(y_2 \mid y_{<2}, x) \times p(y_3 \mid y_{<3}, x) \times p(y_4 \mid y_{<4}, x) \times p(y_5 \mid y_{<5}, x)p(yx)=p(y1x)×p(y2y<2,x)×p(y3y<3,x)×p(y4y<4,x)×p(y5y<5,x)
意思是:“输出‘Apple’这个序列的概率”,等于"先预测第一个token‘A’的概率"×"在‘A’之后预测‘p’的概率"×"在‘A,p’之后预测‘p’的概率"……以此类推。这就是"序列概率按token拆"的具体含义。

“逐token散度"是蒸馏损失要跟着概率的拆解"同步拆”

蒸馏是让"学生模型pSθp_S^\thetapSθ的分布尽可能接近教师模型pTp_TpT的分布",衡量这种"接近度"的指标就是散度(比如KL散度)。
但问题是:我们要衡量的是"整个序列分布的接近度",而序列分布已经被自回归拆成了逐token的条件分布——由于散度(尤其是KL散度)具有"对乘积分布的可加性"(简单说:“整体分布的散度”=“各拆解后条件分布的散度之和”),所以蒸馏损失(散度)也能跟着拆成逐token的散度,再求和/平均

对应到公式里:

公式

D⁣(pT∥pSθ)(y∣x):=1Ly∑n=1LyD⁣(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\!\left(p_T \parallel p_S^{\theta}\right)(y|x) \;:=\; \frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot \mid y_{<n},x)\,\parallel\,p_S^{\theta}(\cdot \mid y_{<n},x)\right) D(pTpSθ)(yx):=Ly1n=1LyD(pT(y<n,x)pSθ(y<n,x))

或者
LKD=1L∑n=1LDKL(pT(n)∥pS(n))\mathcal{L}_{KD} = \frac{1}{L} \sum_{n=1}^L D_{KL}(p_T(n) \parallel p_S(n))LKD=L1n=1LDKL(pT(n)pS(n))

  • 求和符号∑n=1Ly\sum_{n=1}^{L_y}n=1Ly就是"逐token拆":每个nnn对应序列中第nnn个token的位置;
  • 求和项D(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x))D\left(p_T(\cdot \mid y_{<n},x) \parallel p_S^\theta(\cdot \mid y_{<n},x)\right)D(pT(y<n,x)pSθ(y<n,x))就是"第nnn个token的蒸馏损失":衡量教师和学生在"基于前n−1n-1n1个token(y<ny_{<n}y<n)和输入xxx"的条件下,对第nnn个token的概率分布差异;
  • 最后除以LyL_yLy是"求平均":避免序列长度不同导致损失规模差异(比如"Apple"(5个token)和"Banana"(6个token)的损失可直接对比)。

正是把"自回归分解的概率逻辑"和"逐token散度的损失逻辑"结合起来的结果——因为序列概率按token拆了,所以衡量分布差异的散度(蒸馏损失)也按token拆了加起来,再平均得到最终的序列级蒸馏损失

最终序列级散度就是15(D1+D2+D3+D4+D5)\frac{1}{5}(D_1+D_2+D_3+D_4+D_5)51(D1+D2+D3+D4+D5),其中D1D_1D1是"预测‘A’时的散度",D2D_2D2是"预测‘p’(基于‘A’)时的散度",完全跟着自回归的token顺序走。

http://www.dtcms.com/a/348296.html

相关文章:

  • 重学python之mro
  • 【科研绘图系列】R语言浮游植物初级生产力与光照强度的关系
  • 28.原型
  • 详解triton.jit及PTX
  • 目标检测数据集 第006期-基于yolo标注格式的汽车事故检测数据集(含免费分享)
  • vue 自定义文件选择器组件- 原生 input实现
  • 一文学习和掌握网关SpringCloudGateway
  • Java基础知识(五)
  • 南科大C++ 第二章知识储备
  • 电脑深度清理软件,免费磁盘优化工具
  • Shell脚本-如何生成随机数
  • 设置接收超时(SO_RCVTIMEO)
  • 8月精选!Windows 11 25H2 【版本号:26200.5733】
  • 牛市阶段投资指南
  • ffmpeg强大的滤镜功能
  • SingleFile网页保存插件本地安装(QQ浏览器)
  • 【图像处理基石】如何把非笑脸转为笑脸?
  • ffmpeg 问答系列-> mux 部分
  • 启动Flink SQL Client并连接到YARN集群会话
  • Node.js自研ORM框架深度解析与实践
  • 柱状图中最大的矩形+单调栈
  • STM32 入门实录:macOS 下从 0 到点亮 LED
  • Java全栈开发面试实录:从基础到实战的深度探讨
  • 微服务-19.什么是网关
  • 【论文阅读】AI 赋能基于模型的系统工程研究现状与展望
  • Redis--day12--黑马点评--附近商铺用户签到UV统计
  • Excel 表格 - 合并单元格、清除单元格格式
  • 包裹堆叠场景漏检率↓79%!陌讯多目标追踪算法在智慧物流的实践优化
  • EXCEL实现复制后倒序粘贴
  • 暗影哨兵:安全运维的隐秘防线