Transformer朴素采样时,生成 T 个 token 需要的 FLOPs 计算推导过程
文章目录
- Transformer 的朴素采样
- 推理过程
- 一、前提:生成过程的序列长度变化
- 二、关键:Transformer单次前馈传递的FLOPs(以自注意力为主)
- 1. 自注意力的核心计算(简化版)
- 2. 单次前馈传递的主导复杂度
- 三、总FLOPs:T步生成的累加复杂度
- 总结:为什么是O(T³)?
Transformer 的朴素采样
在生成每个token时,都需要将整个历史序列输入到Transformer中。
假设你有一个 transformer 模型,输入 prompt(never gonna give you
),模型会给出下一个 token 在词汇表上的 logits 分布,然后从中采样。一旦得到 token(up,
),就把它加到 prompt 后面。再输入给 transformer,接着模型给出分布,采样得到 token(never
)。依此类推。
缺点:时间复杂度很高。 生成T个token需要O(T³)的浮点运算次数(FLOPs)(因为,每次前馈传递的复杂度是O(T²),T个 token 要进行 T 次前馈传递,复杂度之和接近O(T³))
所以,这种方法效率低下,因为每次生成新token时都会重复计算历史token的表示。
推理过程
要理解朴素采样中“生成T个token需要O(T³) FLOPs”的计算逻辑,需从Transformer的前馈传递复杂度和生成过程的序列长度变化两方面拆解,核心是自注意力机制的计算量随序列长度的增长规律。
一、前提:生成过程的序列长度变化
在朴素采样中,生成第kkk个token时,输入Transformer的序列长度为kkk(假设初始prompt长度为0,仅考虑生成的token;若包含初始prompt,逻辑相同,只是基数增加):
- 生成第1个token:输入序列长度n=1n=1n=1(空序列或初始prompt,简化为1);
- 生成第2个token:输入序列长度n=2n=2n=2(第1个token+新增的输入);
- …
- 生成第TTT个token:输入序列长度n=Tn=Tn=T(前T−1T-1T−1个token+新增输入)。
二、关键:Transformer单次前馈传递的FLOPs(以自注意力为主)
Transformer的前馈传递中,自注意力机制是计算量最大的部分,其复杂度主导了整体FLOPs。对于长度为nnn的序列,自注意力的核心计算步骤及复杂度如下:
1. 自注意力的核心计算(简化版)
自注意力的核心是计算“每个位置对所有位置的注意力权重”,具体步骤:
- 生成Query(Q)、Key(K)、Value(V):每个矩阵维度为n×dn \times dn×d(nnn是序列长度,ddd是隐藏维度,如768),由输入序列通过线性层得到,复杂度为O(nd2)O(nd^2)O(nd2)(次要项,ddd是固定常数);
- 计算注意力分数(QKTQK^TQKT):Q是n×dn \times dn×d,K是n×dn \times dn×d,QKTQK^TQKT的结果是n×nn \times nn×n的矩阵,每个元素需要ddd次乘法和d−1d-1d−1次加法,总复杂度为O(n2d)O(n^2d)O(n2d)(主导项);
- 计算softmax和与V相乘:softmax对n×nn \times nn×n矩阵操作,复杂度O(n2)O(n^2)O(n2);与V(n×dn \times dn×d)相乘的结果是n×dn \times dn×d,复杂度O(n2d)O(n^2d)O(n2d)(主导项)。
2. 单次前馈传递的主导复杂度
忽略常数项(如ddd固定,可视为常数),自注意力的主导复杂度为O(n2)O(n^2)O(n2)。
其他模块(如前馈网络FFN)的复杂度为O(nd2)O(nd^2)O(nd2),因ddd是固定值(如768),当nnn(序列长度)增大时,n2n^2n2的增长远快于nd2nd^2nd2,因此单次前馈传递的整体复杂度可近似为O(n2)O(n^2)O(n2)。
三、总FLOPs:T步生成的累加复杂度
生成TTT个token的过程中,每一步的序列长度为n=1,2,...,Tn=1,2,...,Tn=1,2,...,T,每步的前馈传递复杂度为O(n2)O(n^2)O(n2)。因此总FLOPs是各步复杂度的总和:
总FLOPs≈∑n=1TO(n2)\text{总FLOPs} \approx \sum_{n=1}^T O(n^2) 总FLOPs≈n=1∑TO(n2)
根据数学公式,平方和的近似结果为:
∑n=1Tn2=T(T+1)(2T+1)6≈O(T3)\sum_{n=1}^T n^2 = \frac{T(T+1)(2T+1)}{6} \approx O(T^3) n=1∑Tn2=6T(T+1)(2T+1)≈O(T3)
总结:为什么是O(T³)?
- 生成第kkk个token时,序列长度为kkk,单次前馈传递复杂度为O(k2)O(k^2)O(k2)(自注意力主导);
- 生成TTT个token的总复杂度是12+22+...+T21^2 + 2^2 + ... + T^212+22+...+T2的累加,数学上近似为O(T3)O(T³)O(T3)。
这就是朴素采样中“生成T个token需要O(T³) FLOPs”的核心逻辑——序列长度随生成步骤线性增长,而每步的计算量随序列长度的平方增长,最终导致总复杂度为三次方级。