当前位置: 首页 > news >正文

详解 Transformer 激活值的内存占用公式

激活值的内存公式

在这里插入图片描述

首先明确变量含义

在Transformer模型的内存分析中,这些变量通常表示:

  • sss:序列长度(sequence length,输入文本的token数量);
  • bbb:批次大小(batch size,一次训练的样本数);
  • hhh:隐藏层维度(hidden dimension,每个token的特征向量维度);
  • aaa:注意力头数(number of attention heads,多头注意力的头数量)。

左边项:sbh×34sbh \times 34sbh×34(MLP及点乘操作的激活值)

Transformer的每个编码器/解码器层包含多头注意力MLP两个核心模块,这两个模块会产生大量中间激活值(需要临时存储的张量),这些激活值的总内存可以汇总为sbh×34sbh \times 34sbh×34,具体拆解如下:

1. 多头注意力模块的激活值(约12sbhsbhsbh

多头注意力的核心计算流程为:
输入xxx(形状b×s×hb \times s \times hb×s×h)→ 线性变换生成Q,K,VQ, K, VQ,K,V → 计算注意力分数 → 与VVV加权求和 → 输出线性变换。
其中需要存储的激活值包括:

  • Q,K,VQ, K, VQ,K,V:每个都是b×s×hb \times s \times hb×s×h(总3sbhsbhsbh);
  • 注意力输出的中间结果(与VVV加权求和后,未经过最终线性变换):b×s×hb \times s \times hb×s×h(1sbhsbhsbh);
  • 多头注意力的最终输出(经过线性变换后):b×s×hb \times s \times hb×s×h(1sbhsbhsbh);
  • 层归一化(LayerNorm)的中间变量(如归一化前的残差、均值、方差等):约2sbhsbhsbh
  • 其他点乘操作(如QQQKTK^TKT的中间结果,虽然是二次项,但此处“点乘”可能指线性变换的矩阵乘法输出):约5sbhsbhsbh(不同实现细节可能有差异)。

2. MLP模块的激活值(约22sbhsbhsbh

MLP通常由“线性变换→激活函数→线性变换”组成,且中间维度会扩展(通常为4h4h4h),激活值包括:

  • 第一个线性变换的输出(扩展到4h4h4h):b×s×4hb \times s \times 4hb×s×4h(4sbhsbhsbh);
  • 激活函数(如GELU)的输出(与上一步同形状):b×s×4hb \times s \times 4hb×s×4h(4sbhsbhsbh);
  • 第二个线性变换的输出(还原到hhh):b×s×hb \times s \times hb×s×h(1sbhsbhsbh);
  • 层归一化的中间变量(残差、均值、方差等):约2sbhsbhsbh
  • 其他辅助计算(如dropout的掩码、临时缓存等):约11sbhsbhsbh(不同框架实现差异较大)。

总和:约34sbhsbhsbh

多头注意力(12sbhsbhsbh)+ MLP(22sbhsbhsbh)的激活值总和约为34sbhsbhsbh,这就是左边项的来源。

右边项:5abs25abs^25abs2(softmax及注意力的二次项)

注意力机制中存在与序列长度sss相关的二次项激活值(形状含s×ss \times ss×s),这些是内存消耗的“大头”,具体来源如下:

1. 注意力分数矩阵(核心二次项)

多头注意力中,QQQb×a×s×h/ab \times a \times s \times h/ab×a×s×h/a)与KTK^TKTb×a×h/a×sb \times a \times h/a \times sb×a×h/a×s)的点积会生成注意力分数矩阵,形状为b×a×s×sb \times a \times s \times sb×a×s×s(每个头、每个样本都有一个s×ss \times ss×s的矩阵),其内存为b×a×s×s=abs2b \times a \times s \times s = abs^2b×a×s×s=abs2

2. softmax的中间激活值

对注意力分数矩阵应用softmax后,结果仍为b×a×s×sb \times a \times s \times sb×a×s×s(与输入同形状),需要额外存储,内存也是abs2abs^2abs2

3. 其他二次项

  • 注意力权重(softmax输出)与VVVb×a×s×h/ab \times a \times s \times h/ab×a×s×h/a)相乘的中间结果(未拼接多头前):约2abs22abs^22abs2(不同实现的临时缓存);
  • 掩码(mask)相关的临时张量(如填充掩码、因果掩码):约abs2abs^2abs2

总和:约5abs2abs^2abs2

上述二次项激活值总和约为5abs2abs^2abs2,即sbh×5ashsbh \times 5\frac{as}{h}sbh×5has(推导:ash×sbh=abs2\frac{as}{h} \times sbh = abs^2has×sbh=abs2)。

总结

激活值的内存公式是对Transformer层中两类核心激活值的汇总:

  • 左边34sbh34sbh34sbh:来自MLP和注意力中的“线性变换输出”(与sssbbbhhh线性相关);
  • 右边5abs25abs^25abs2:来自注意力机制中的“二次项”(与s2s^2s2相关,是长序列场景下的内存瓶颈)。

这两类激活值共同决定了Transformer在训练/推理时的内存占用,尤其是当sss很大时(如长文本),二次项abs2abs^2abs2会成为主导因素。

http://www.dtcms.com/a/350866.html

相关文章:

  • SOME/IP-SD报文中 Entry Format(条目格式)-理解笔记5
  • 算法题记录01:
  • 0826xd
  • Trip Footprints 旅行App开发全流程解析
  • UALink是什么?
  • 数字化转型:概念性名词浅谈(第四十二讲)
  • 牛客周赛 Round 106(小苯的方格覆盖/小苯的数字折叠/ 小苯的波浪加密器/小苯的数字变换/小苯的洞数组构造/ 小苯的数组计数)
  • 撤回git 提交
  • 算法训练营day62 图论⑪ Floyd 算法精讲、A star算法、最短路算法总结篇
  • C# 中常见的 五大泛型约束
  • [系统架构设计师]应用数学(二十一)
  • 云计算学习笔记——Linux用户和组的归属权限管理、附加权限、ACL策略管理篇
  • 联邦雪框架FedML自学---第四篇---案例一
  • 浅谈:运用幂的性质
  • 程序的“烽火台”:信号的产生与传递
  • 【基础-单选】使用http发起网络请求,需要以下哪种权限?
  • C6.2:小信号、交流电流增益分析
  • 立轴式小型混凝土搅拌机的设计含14张CAD
  • 客户生命周期价值帮助HelloFresh优化其营销支出
  • 快速了解工业相机中的连续采集、软触发、硬触发和同步触发以及PTP同步触发
  • Spring介绍
  • Linux iptables 防火墙
  • Linux网络编程基础API
  • [灵动微电子六步换向(方波控制)方案MM32BIN560C] 六步换向实现和规律
  • PostgreSQL诊断系列(2/6):锁问题排查全攻略——揪出“阻塞元凶”
  • RK3568 Linux驱动学习——pinctrl和gpio子系统
  • onnx入门教程(四)——ONNX 模型的修改与调试
  • Day24: NumPy 奥德赛:用科学计算的魔法征服数据宇宙!
  • 32.Ansible平台搭建
  • 2024年09月 Python(二级)真题解析#中国电子学会#全国青少年软件编程等级考试