论文阅读:speculative decoding
Fast Inference from Transformers via Speculative Decoding
论文地址:https://arxiv.org/pdf/2211.17192
speculative sampling
为了从分布 p ( x ) p(x) p(x) 中采样,我们实际上是从分布 q ( x ) q(x) q(x) 中采样 x x x,如果 q ( x ) ≤ p ( x ) q(x) \leq p(x) q(x)≤p(x),则保留该样本;如果 q ( x ) > p ( x ) q(x) > p(x) q(x)>p(x),则以概率 1 − p ( x ) q ( x ) 1 - \frac{p(x)}{q(x)} 1−q(x)p(x) 拒绝该样本,并重新从调整后的分布 p ′ ( x ) = norm ( max ( 0 , p ( x ) − q ( x ) ) ) p'(x) = \text{norm}(\max(0, p(x)-q(x))) p′(x)=norm(max(0,p(x)−q(x))) 中采样。对于任何分布 p ( x ) p(x) p(x) 和 q ( x ) q(x) q(x),以及以此方式采样的 x x x,确实有 x ∼ p ( x ) x \sim p(x) x∼p(x)。
给定通过在条件前缀上运行 M q M_q Mq 获得的分布 q ( x ) q(x) q(x),我们可以采样一个标记 x 1 ∼ q ( x ) x_1 \sim q(x) x1∼q(x)。然后,我们通过在前缀上运行 M p M_p Mp 来计算分布 p ( x ) p(x) p(x),同时并行地推测性地计算下一个标记 x 2 x_2 x2 的分布,即在前缀上追加 x 1 x_1 x1 后运行 M p M_p Mp。一旦两项计算都完成,我们就按上述方式处理:如果 x 1 x_1 x1 被拒绝,我们丢弃 x 2 x_2 x2 的计算,并从调整后的分布中重新采样 x 1 x_1 x1;如果 x 1 x_1 x1 被接受,我们就保留两个标记。算法 1 将这一想法推广为一次采样 1 到 γ + 1 \gamma + 1 γ+1 个标记。
分析
有几个证明需要注意一下:
单次算法期望能生成的token
-
单次算法期望能生成的token数量服从几何分布,但是求和项是有限制的,这里推导下
-
接受率β的定义
设目标模型分布为p(x)
,草稿模型分布为q(x)
。草稿模型生成的单个token被目标模型接受的概率为:
β = ∑ x min ( q ( x ) , p ( x ) ) \beta = \sum_x \min\left(q(x), p(x)\right) β=x∑min(q(x),p(x))
- 拒绝率α的定义
α = 1 − β = 1 − ∑ x min ( p ( x ) , q ( x ) ) x \alpha = 1 - \beta = 1 - \sum_x \min(p(x), q(x)) x α=1−β=1−x∑min(p(x),q(x))x
-
假设每个token的接受事件独立且同分布(i.i.d.),草稿模型一次生成
K
个token: -
首次拒绝发生在位置
r
的概率为:P ( r ) = ( 1 − β ) β r − 1 ( 1 ≤ r ≤ K ) P(r) = (1-\beta) \beta^{r-1} \quad (1 \leq r \leq K) P(r)=(1−β)βr−1(1≤r≤K)
所有token均被接受 的概率为: β K \beta^K βK
-
综上期望能生成的token数量为:
γ = ∑ r = 1 K r ⋅ P ( r ) ⏟ 拒绝前生成的token + K ⋅ β K ⏟ 全接受时生成K个token \gamma = \underbrace{\sum_{r=1}^K r \cdot P(r)}_{\text{拒绝前生成的token}} + \underbrace{K \cdot \beta^K}_{\text{全接受时生成K个token}} γ=拒绝前生成的token r=1∑Kr⋅P(r)+全接受时生成K个token K⋅βK
代入 P ( r ) P(r) P(r) 后展开:
γ = ∑ r = 1 K r ⋅ ( 1 − β ) β r − 1 + K β K \gamma = \sum_{r=1}^K r \cdot (1-\beta) \beta^{r-1} + K \beta^K γ=r=1∑Kr⋅(1−β)βr−1+KβK
- 几何级数求和
几何级数求和公式为:
对 ∑ r = 1 K r β r − 1 \sum_{r=1}^K r \beta^{r-1} ∑r=1Krβr−1 求和处理:
- 令 S = ∑ r = 1 K β r − 1 S = \sum_{r=1}^K \beta^{r-1} S=∑r=1Kβr−1:
S = 1 + β + β 2 + ⋯ + β K − 1 = 1 − β K 1 − β S = 1 + \beta + \beta^2 + \cdots + \beta^{K-1} = \frac{1-\beta^K}{1-\beta} S=1+β+β2+⋯+βK−1=1−β1−βK
- 对 S S S 求导:
∑ r = 1 K r β r − 1 = d d β ( ∑ r = 0 K β r ) = d d β ( 1 − β K + 1 1 − β ) = 1 − ( K + 1 ) β K + K β K + 1 ( 1 − β ) 2 \sum_{r=1}^K r \beta^{r-1} = \frac{d}{d\beta} \left( \sum_{r=0}^K \beta^r \right) = \frac{d}{d\beta} \left( \frac{1-\beta^{K+1}}{1-\beta} \right) = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} ∑r=1Krβr−1=dβd(∑r=0Kβr)=dβd(1−β1−βK+1)=(1−β)21−(K+1)βK+KβK+1
- 代入γ表达式:
γ = ( 1 − β ) ⋅ 1 − ( K + 1 ) β K + K β K + 1 ( 1 − β ) 2 + K β K = 1 − ( K + 1 ) β K + K β K + 1 1 − β + K β K \gamma = (1-\beta) \cdot \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} + K\beta^K = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{1-\beta} + K\beta^K γ=(1−β)⋅(1−β)21−(K+1)βK+KβK+1+KβK=1−β1−(K+1)βK+KβK+1+KβK
- 化简:
γ = 1 − β K 1 − β \gamma = \frac{1 - \beta^K}{1-\beta} γ=1−β1−βK
物理意义:
- 当 K → ∞ K \to \infty K→∞时, γ → 1 1 − β = 1 α \gamma \to \frac{1}{1-\beta} = \frac{1}{\alpha} γ→1−β1=α1(理想无限长草稿)。
- 例如 β \beta β = 0.8` 时, γ max = 5 \gamma_{\text{max}} = 5 γmax=5,即平均每次生成5个token。
得证
Walltime的时间优化
定理 3.8:算法 1 在总运行时间上的预期改进因子为
‘ 1 − α γ + 1 ( 1 − α ) ( γ c + 1 ) ‘ `\frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)}` ‘(1−α)(γc+1)1−αγ+1‘
证明:
记运行目标模型 M p M_p Mp 单步的成本为 T T T。
算法 1 的单次运行成本为 T c γ + T Tc\gamma + T Tcγ+T(其中 c γ T c\gamma T cγT用于运行近似模型 M q M_q Mq γ \gamma γ 次, T T T 用于运行 M p M_p Mp 一次)。
根据单次算法期望能生成的token算法推导,单次运行平均生成 token 数量为 1 − α γ + 1 1 − α \dfrac{1 - \alpha^{\gamma + 1}}{1 - \alpha} 1−α1−αγ+1。
因此,使用算法 1 生成单个 token 的总体预期成本为:
( c γ + 1 ) ( 1 − α ) 1 − α γ + 1 T ‘ \frac{(c\gamma + 1)(1 - \alpha)}{1 - \alpha^{\gamma + 1}}T` 1−αγ+1(cγ+1)(1−α)T‘
由于标准解码算法生成单个 token 的成本为 T
,
比较可得上述改进因子。∎
(注:符号 “∎” 表示证明结束)
关键术语说明:
英文术语 | 中文翻译 | 符号 | 含义 |
---|---|---|---|
walltime | 总运行时间 | - | 算法从启动到结束的时钟时间 |
expected improvement factor | 预期改进因子 | - | 优化后时间开销的缩减比例 |
cost per step | 单步成本 | T T T | 目标模型 M p M_p Mp 推理一个 token 的时间 |
approximation model | 近似模型 | M q M_q Mq | 快速但低精度的草稿模型 |
tokens | 标记(Token) | - | 模型生成的基本文本单位 |
rejection rate | 拒绝率 | α \alpha α | 草稿模型 M q M_q Mq 的 token 被目标模型 M p M_p Mp 拒绝的概率 |
γ \gamma γ | 生成长度 | γ \gamma γ | 草稿模型单次运行的 token 生成数 |
cost ratio | 成本比 | c c c | M q M_q Mq 与 M p M_p Mp 的单步时间比值( 0 < c < 1 0 < c < 1 0<c<1) |
公式解析:
- 改进因子
1 − α γ + 1 ( 1 − α ) ( γ c + 1 ) \frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)} (1−α)(γc+1)1−αγ+1
- 分子 1 − α γ + 1 1 - \alpha^{\gamma+1} 1−αγ+1:草稿模型连续生成
\gamma
个 token 均未被拒绝的概率补偿 - 分母 ( 1 − α ) (1-\alpha) (1−α):单 token 接受率, γ c + 1 \gamma c + 1 γc+1:草稿+验证的总时间成本
该值 >1 时表示加速,值越大加速效果越显著
- 单 token 成本公式
( c γ + 1 ) ( 1 − α ) 1 − α γ + 1 T \frac{(c\gamma+1)(1-\alpha)}{1-\alpha^{\gamma+1}}T 1−αγ+1(cγ+1)(1−α)T
- 分子 ( c γ + 1 ) ( 1 − α ) T (c\gamma+1)(1-\alpha)T (cγ+1)(1−α)T:草稿生成+验证的实际计算量
- 分母 1 − α γ + 1 1-\alpha^{\gamma+1} 1−αγ+1:有效 token 产出的概率加权
操作数计算
操作数的计算量也是类似的,直接贴结论了
( 1 − α ) ( γ c ^ + γ + 1 ) 1 − α γ + 1 \frac{(1-\alpha)(\gamma \hat{c}+\gamma+1)}{1-\alpha^{\gamma+1}} 1−αγ+1(1−α)(γc^+γ+1)
Reference
https://arxiv.org/pdf/2211.17192