当前位置: 首页 > news >正文

Transformer朴素采样时,生成 T 个 token 需要的 FLOPs 计算推导过程

文章目录

  • Transformer 的朴素采样
  • 推理过程
    • 一、前提:生成过程的序列长度变化
    • 二、关键:Transformer单次前馈传递的FLOPs(以自注意力为主)
      • 1. 自注意力的核心计算(简化版)
      • 2. 单次前馈传递的主导复杂度
    • 三、总FLOPs:T步生成的累加复杂度
    • 总结:为什么是O(T³)?

Transformer 的朴素采样

在生成每个token时,都需要将整个历史序列输入到Transformer中。

假设你有一个 transformer 模型,输入 prompt(never gonna give you),模型会给出下一个 token 在词汇表上的 logits 分布,然后从中采样。一旦得到 token(up,),就把它加到 prompt 后面。再输入给 transformer,接着模型给出分布,采样得到 token(never)。依此类推。

缺点:时间复杂度很高。 生成T个token需要O(T³)的浮点运算次数(FLOPs)(因为,每次前馈传递的复杂度是O(T²),T个 token 要进行 T 次前馈传递,复杂度之和接近O(T³))

所以,这种方法效率低下,因为每次生成新token时都会重复计算历史token的表示。

推理过程

要理解朴素采样中“生成T个token需要O(T³) FLOPs”的计算逻辑,需从Transformer的前馈传递复杂度生成过程的序列长度变化两方面拆解,核心是自注意力机制的计算量随序列长度的增长规律。

一、前提:生成过程的序列长度变化

在朴素采样中,生成第kkk个token时,输入Transformer的序列长度为kkk(假设初始prompt长度为0,仅考虑生成的token;若包含初始prompt,逻辑相同,只是基数增加):

  • 生成第1个token:输入序列长度n=1n=1n=1(空序列或初始prompt,简化为1);
  • 生成第2个token:输入序列长度n=2n=2n=2(第1个token+新增的输入);
  • 生成第TTT个token:输入序列长度n=Tn=Tn=T(前T−1T-1T1个token+新增输入)。

二、关键:Transformer单次前馈传递的FLOPs(以自注意力为主)

Transformer的前馈传递中,自注意力机制是计算量最大的部分,其复杂度主导了整体FLOPs。对于长度为nnn的序列,自注意力的核心计算步骤及复杂度如下:

1. 自注意力的核心计算(简化版)

自注意力的核心是计算“每个位置对所有位置的注意力权重”,具体步骤:

  • 生成Query(Q)、Key(K)、Value(V):每个矩阵维度为n×dn \times dn×dnnn是序列长度,ddd是隐藏维度,如768),由输入序列通过线性层得到,复杂度为O(nd2)O(nd^2)O(nd2)(次要项,ddd是固定常数);
  • 计算注意力分数(QKTQK^TQKT):Q是n×dn \times dn×d,K是n×dn \times dn×dQKTQK^TQKT的结果是n×nn \times nn×n的矩阵,每个元素需要ddd次乘法和d−1d-1d1次加法,总复杂度为O(n2d)O(n^2d)O(n2d)(主导项);
  • 计算softmax和与V相乘:softmax对n×nn \times nn×n矩阵操作,复杂度O(n2)O(n^2)O(n2);与V(n×dn \times dn×d)相乘的结果是n×dn \times dn×d,复杂度O(n2d)O(n^2d)O(n2d)(主导项)。

2. 单次前馈传递的主导复杂度

忽略常数项(如ddd固定,可视为常数),自注意力的主导复杂度为O(n2)O(n^2)O(n2)
其他模块(如前馈网络FFN)的复杂度为O(nd2)O(nd^2)O(nd2),因ddd是固定值(如768),当nnn(序列长度)增大时,n2n^2n2的增长远快于nd2nd^2nd2,因此单次前馈传递的整体复杂度可近似为O(n2)O(n^2)O(n2)

三、总FLOPs:T步生成的累加复杂度

生成TTT个token的过程中,每一步的序列长度为n=1,2,...,Tn=1,2,...,Tn=1,2,...,T,每步的前馈传递复杂度为O(n2)O(n^2)O(n2)。因此总FLOPs是各步复杂度的总和:

总FLOPs≈∑n=1TO(n2)\text{总FLOPs} \approx \sum_{n=1}^T O(n^2) FLOPsn=1TO(n2)

根据数学公式,平方和的近似结果为:
∑n=1Tn2=T(T+1)(2T+1)6≈O(T3)\sum_{n=1}^T n^2 = \frac{T(T+1)(2T+1)}{6} \approx O(T^3) n=1Tn2=6T(T+1)(2T+1)O(T3)

总结:为什么是O(T³)?

  • 生成第kkk个token时,序列长度为kkk,单次前馈传递复杂度为O(k2)O(k^2)O(k2)(自注意力主导);
  • 生成TTT个token的总复杂度是12+22+...+T21^2 + 2^2 + ... + T^212+22+...+T2的累加,数学上近似为O(T3)O(T³)O(T3)

这就是朴素采样中“生成T个token需要O(T³) FLOPs”的核心逻辑——序列长度随生成步骤线性增长,而每步的计算量随序列长度的平方增长,最终导致总复杂度为三次方级。

http://www.dtcms.com/a/358559.html

相关文章:

  • sunset: 1渗透测试
  • 《HM-RAG: Hierarchical Multi-Agent Multimodal Retrieval Augmented Generation》
  • Java中使用正则表达式的正确打开方式
  • 《微服务架构从故障频发到自愈可控的实战突围方案》
  • C++抽象类
  • Photoshop - Ps 编辑图像
  • 在PowerPoint和WPS演示让蝴蝶一直跳8字舞
  • 干掉抽取壳!FART 自动化脱壳框架与 Execute 脱壳点解析
  • 迷你电脑用到什么型号的RJ45网口
  • 【系列08】端侧AI:构建与部署高效的本地化AI模型 第7章:架构设计与高效算子
  • 文件夹和文件一键加密,保护你的隐私
  • 计算机算术8-浮点加法
  • EVidenceModeler v2.1 安装与使用--生信工具58
  • 开发者效率白皮书:工具选型与使用方法论
  • 使用 JavaScript 构建 RAG(检索增强生成)库:原理与实现
  • 【Redisson 加锁源码解析】
  • 不使用if else ,实现石头剪刀布
  • 大数据在UI前端的应用深化研究:用户行为数据的跨平台关联分析
  • 思科ISR4300系列端口限速
  • 面试专栏
  • [光学原理与应用-333]:ZEMAX - 序列模式的设计过程
  • 基于CNN(卷积神经网络)的门牌号识别
  • 国标调查:赋能中国汽车行业高质量发展的关键支撑
  • 【C++】红黑树(详解)
  • 项目管理方法如何选择
  • 语音情感识别中的跨语言无监督领域自适应方法详解
  • 微服务搭建(SpringBoot + Dubbo + Nacos)
  • 【龙泽科技】汽车电气故障诊断仿真教学软件【迈腾380TSI】
  • 3.kafka常用命令
  • 元素滚动scrollIntoView