LLM模型的中间激活值估计
之前分析了模型参数量估计
https://blog.csdn.net/liliang199/article/details/151839842
由于LLM模型比较大,一次训练迭代,除模型参数、梯度、优化器状态外,占用显存大头就是forward计算得到的中间激活值。中间激活需要保存,以便在backward计算梯度时使用。
这里参考网络资料,分析什么是中间激活,并估计LLM训练过程的中间激活显存占用量。
1 中间激活定义
中间激活指的是forward计算得到的,并在backward需要用到的所有张量。
中间激活不包含模型参数、优化器状态,但包含dropout用到的mask矩阵。
为简化过程,只考虑激活中显存占用大头,忽略掉一些小的显存占用。
比如,layer normalization,中间激活指的是层的输入、输入均值 和方差
。输入包含bsh个数字,均值和方差包含bs个数字,h比较大如4096,layer normalization估计中间激活仅考虑bsh部分,忽略bs部分。
2 中间激活估计
LLM训练,中间激活一般采用float16/bfloat16,每个元素2个bytes。例外是dropout mask矩阵,每个元素1个byte。为统一认知,以下分析采用单位bytes,不是元素个数。
为方便理解,采用从块到层的顺序,先self-attention块、mlp块、layer norm,然后transformer层,最后LLM,逐步分析中间激活的显存占用。
2.1 self-attention
self-attention块的计算公式如下.
1)激活Q、K、V
Q、K、V,共同输入x就是中间激活,x形状[b, s, h],元素个数bsh,占用显存2 * bsh = 2bsh。
2)激活QK^T
QK^T矩阵乘,需要保存中间激活Q、K,两者形状[b, s, h],占用显存大小 2 * 2 * bsh = 4bsh。
softmax函数,需要保存函数的输入,占用显存大小为
,这里a表示注意力头数。
Q的形状为: [b, head_num, s, per_head_hidden_size]
K^T的形状为: [b, head_num, per_head_hidden_size, s]
QK^T 的形状为: [b, head_num, s, s],元素个数为,占用显存大小为
3)softmax dropout mask激活
计算完 softmax() 函数后,会进行dropout操作。需要保存一个mask矩阵,mask矩阵的形状与 QK^T 相同,占用显存大小为 。
4)激活score
scoreV,计算V上的attention,即 score⋅V ,需要保存 score ,大小为 ;以及 V ,大小为 2bsh 。二者占用显存大小合计为
。
5)输出映射激活
计算输出映射以及一个dropout操作。输入映射需要保存其输入,大小为2bsh;dropout需要保存mask矩阵,大小为bsh。二者占用显存大小合计为 3bsh。
将1)-5)中间激活显存占用相加,self-attention块总的中间激活显存占用 。
2.2 mlp
mlp计算公式如下
W将内在维度从h升到4h,W2将内在维度从4h降到h。
1)W1对应线性层输入x_out,占用显存 2bsh
2)激活函数f_gelu保存输入,其内在维度为4h,占用显存8bsh
3)W2对应线性层保存输入,占用显存8bsh
4)最后有一个dropout,需要保存其mask矩阵,占用显存为bsh
对于MLP,需要显存保存的中间激活位19bsh。
2.3 layer norm
self-attention和mlp分别后接一个layer normalization,每个layer norm保存输入2bsh,两个layer normalization的中间激活占用的显存大小为4bsh。
2.4 transformer
每个transformer层包含了一个self-attention块和一个MLP块,并分别对应layer normalization。
综合self-attention、mlp、layer norm,每个transformer层需要保存中间激活
2.5 LLM
对于整个LLM,包含embeding层、l层transformer、最后输出层。
embedding不需要中间激活,隐藏维h较大,层数l较深时,输出层中间激活量很小,可忽略。
因此对于l层transformer的LLM,中间激活占用显存大小近似表示为
3 中间激活与模型参数
一次训练迭代中,到底是模型参数、梯度、优化器状态占用显存多,还是中间激活占用显存多,他们与输入数据的大小有没有关系。
3.1 原理分析
首先,模型参数、梯度、优化器状态占用显存大小,与优化器类型和模型参数量有关,与输入数据的大小无关。
其次,但中间激活占用显存大小与输入数据大小(批次大小和序列长度)正相关,随批次大小和序列长度的增大,中间激活占用的显存会同步增大。
训练遇到显存OOM,即显存不够或Out Of Memory,常用做法是减少批次大小,这种方式的目的就是减少中间激活占用的显存,因为训练样本中序列程度是不能变的,通用模型参数、梯度、优化器占用显存也是类似。
3.2 GPT3-175B示例
这里以GPT3-175B为例,直观对比模型参数与中间激活的显存大小。
GPT3-175B,参数量约175B,层数96,隐藏维度12288,注意力头数96。
假设采用混合精度训练,模型参数和中间激活均采用float16数据类型,每个元素占2个bytes。
3.2.1 模型参数
GPT3的模型参数量为175B,占用显存大小约为
GPT3-175B大约需要占用350GB显存。
3.2.2 中间激活
LLM中间激活占用显存大小计算公式如下。
假设训练数据序列长度s为2048,以此为基础对比不同批次大小b时,中间激活占用的显存大小。
当b=1时,中间激活占用显存275414777856 bytes,约275 GB,模型参数的0.79倍。
当b=16时,中间激活占用显存4406636445696 bytes,约4406 GB,大约模型参数的12倍。
当b=64时,中间激活占用显存17626545782784 bytes,17626GB,大约模型参数的50倍。
当b=256时,中间激活占用显存70506183131136 bytes,70506GB,大约模型参数的201倍。
数据显示,随批次大小增大,中间激活占用显存远超模型参数。
3.2.3 激活重计算技术
为有效利用显存训练模型,通常采用激活重计算技术来减少中间激活,理论上可将中间激活显存占用从O(n)减少到 O(sqrt(n)),但需要增加一次额外的forward计算。
激活重计算技术,本质上是“时间换空间”。
reference
---
LLM模型的计算量估计
https://blog.csdn.net/liliang199/article/details/152081156
LLM模型的参数量估计
https://blog.csdn.net/liliang199/article/details/151839842
分析transformer模型的参数量、计算量、中间激活、KV cache
https://zhuanlan.zhihu.com/p/624740065
LLM模型的计算量与参数量的关系
https://blog.csdn.net/liliang199/article/details/152095274