详解 Transformer 激活值的内存占用公式
文章目录
- 激活值的内存公式
- 首先明确变量含义
- 左边项:sbh×34sbh \times 34sbh×34(MLP及点乘操作的激活值)
- 右边项:5abs25abs^25abs2(softmax及注意力的二次项)
- 1. 注意力分数矩阵(核心二次项)
- 2. softmax的中间激活值
- 3. 其他二次项
- 总和:约5abs2abs^2abs2
- 总结
激活值的内存公式
首先明确变量含义
在Transformer模型的内存分析中,这些变量通常表示:
- sss:序列长度(sequence length,输入文本的token数量);
- bbb:批次大小(batch size,一次训练的样本数);
- hhh:隐藏层维度(hidden dimension,每个token的特征向量维度);
- aaa:注意力头数(number of attention heads,多头注意力的头数量)。
左边项:sbh×34sbh \times 34sbh×34(MLP及点乘操作的激活值)
Transformer的每个编码器/解码器层包含多头注意力和MLP两个核心模块,这两个模块会产生大量中间激活值(需要临时存储的张量),这些激活值的总内存可以汇总为sbh×34sbh \times 34sbh×34,具体拆解如下:
1. 多头注意力模块的激活值(约12sbhsbhsbh)
多头注意力的核心计算流程为:
输入xxx(形状b×s×hb \times s \times hb×s×h)→ 线性变换生成Q,K,VQ, K, VQ,K,V → 计算注意力分数 → 与VVV加权求和 → 输出线性变换。
其中需要存储的激活值包括:
- Q,K,VQ, K, VQ,K,V:每个都是b×s×hb \times s \times hb×s×h(总3sbhsbhsbh);
- 注意力输出的中间结果(与VVV加权求和后,未经过最终线性变换):b×s×hb \times s \times hb×s×h(1sbhsbhsbh);
- 多头注意力的最终输出(经过线性变换后):b×s×hb \times s \times hb×s×h(1sbhsbhsbh);
- 层归一化(LayerNorm)的中间变量(如归一化前的残差、均值、方差等):约2sbhsbhsbh;
- 其他点乘操作(如QQQ与KTK^TKT的中间结果,虽然是二次项,但此处“点乘”可能指线性变换的矩阵乘法输出):约5sbhsbhsbh(不同实现细节可能有差异)。
2. MLP模块的激活值(约22sbhsbhsbh)
MLP通常由“线性变换→激活函数→线性变换”组成,且中间维度会扩展(通常为4h4h4h),激活值包括:
- 第一个线性变换的输出(扩展到4h4h4h):b×s×4hb \times s \times 4hb×s×4h(4sbhsbhsbh);
- 激活函数(如GELU)的输出(与上一步同形状):b×s×4hb \times s \times 4hb×s×4h(4sbhsbhsbh);
- 第二个线性变换的输出(还原到hhh):b×s×hb \times s \times hb×s×h(1sbhsbhsbh);
- 层归一化的中间变量(残差、均值、方差等):约2sbhsbhsbh;
- 其他辅助计算(如dropout的掩码、临时缓存等):约11sbhsbhsbh(不同框架实现差异较大)。
总和:约34sbhsbhsbh
多头注意力(12sbhsbhsbh)+ MLP(22sbhsbhsbh)的激活值总和约为34sbhsbhsbh,这就是左边项的来源。
右边项:5abs25abs^25abs2(softmax及注意力的二次项)
注意力机制中存在与序列长度sss相关的二次项激活值(形状含s×ss \times ss×s),这些是内存消耗的“大头”,具体来源如下:
1. 注意力分数矩阵(核心二次项)
多头注意力中,QQQ(b×a×s×h/ab \times a \times s \times h/ab×a×s×h/a)与KTK^TKT(b×a×h/a×sb \times a \times h/a \times sb×a×h/a×s)的点积会生成注意力分数矩阵,形状为b×a×s×sb \times a \times s \times sb×a×s×s(每个头、每个样本都有一个s×ss \times ss×s的矩阵),其内存为b×a×s×s=abs2b \times a \times s \times s = abs^2b×a×s×s=abs2。
2. softmax的中间激活值
对注意力分数矩阵应用softmax后,结果仍为b×a×s×sb \times a \times s \times sb×a×s×s(与输入同形状),需要额外存储,内存也是abs2abs^2abs2。
3. 其他二次项
- 注意力权重(softmax输出)与VVV(b×a×s×h/ab \times a \times s \times h/ab×a×s×h/a)相乘的中间结果(未拼接多头前):约2abs22abs^22abs2(不同实现的临时缓存);
- 掩码(mask)相关的临时张量(如填充掩码、因果掩码):约abs2abs^2abs2。
总和:约5abs2abs^2abs2
上述二次项激活值总和约为5abs2abs^2abs2,即sbh×5ashsbh \times 5\frac{as}{h}sbh×5has(推导:ash×sbh=abs2\frac{as}{h} \times sbh = abs^2has×sbh=abs2)。
总结
激活值的内存公式是对Transformer层中两类核心激活值的汇总:
- 左边34sbh34sbh34sbh:来自MLP和注意力中的“线性变换输出”(与sss、bbb、hhh线性相关);
- 右边5abs25abs^25abs2:来自注意力机制中的“二次项”(与s2s^2s2相关,是长序列场景下的内存瓶颈)。
这两类激活值共同决定了Transformer在训练/推理时的内存占用,尤其是当sss很大时(如长文本),二次项abs2abs^2abs2会成为主导因素。