当前位置: 首页 > news >正文

LLM模型的中间激活值估计

之前分析了模型参数量估计

https://blog.csdn.net/liliang199/article/details/151839842

由于LLM模型比较大,一次训练迭代,除模型参数、梯度、优化器状态外,占用显存大头就是forward计算得到的中间激活值。中间激活需要保存,以便在backward计算梯度时使用。

这里参考网络资料,分析什么是中间激活,并估计LLM训练过程的中间激活显存占用量。

1 中间激活定义

中间激活指的是forward计算得到的,并在backward需要用到的所有张量。

中间激活不包含模型参数、优化器状态,但包含dropout用到的mask矩阵。

为简化过程,只考虑激活中显存占用大头,忽略掉一些小的显存占用。

比如,layer normalization,中间激活指的是层的输入、输入均值 \mu和方差 \sigma ^{2} 。输入包含bsh个数字,均值和方差包含bs个数字,h比较大如4096,layer normalization估计中间激活仅考虑bsh部分,忽略bs部分。

2 中间激活估计

LLM训练,中间激活一般采用float16/bfloat16,每个元素2个bytes。例外是dropout mask矩阵,每个元素1个byte。为统一认知,以下分析采用单位bytes,不是元素个数。

为方便理解,采用从块到层的顺序,先self-attention块、mlp块、layer norm,然后transformer层,最后LLM,逐步分析中间激活的显存占用。

2.1 self-attention 

self-attention块的计算公式如下.

Q = xW_{Q}, K = xW_{K}, V = xW_{V}

x_{out} = softmax(\frac{QK^{T}}{\sqrt{h}})VW_{o} + x

1)激活Q、K、V

Q、K、V,共同输入x就是中间激活,x形状[b, s, h],元素个数bsh,占用显存2 * bsh = 2bsh。

2)激活QK^T

QK^T矩阵乘,需要保存中间激活Q、K,两者形状[b, s, h],占用显存大小 2 * 2 * bsh = 4bsh。

softmax函数,需要保存函数的输入QK^{T},占用显存大小为2bs^{2}a,这里a表示注意力头数。

score = softmax(\frac{QK^{T}}{\sqrt{d_k}})

Q的形状为: [b, head_num, s, per_head_hidden_size]

K^T的形状为: [b, head_num, per_head_hidden_size, s]

QK^T 的形状为: [b, head_num, s, s],元素个数为bs^2a,占用显存大小为2bs^2a

3)softmax dropout mask激活

计算完 softmax() 函数后,会进行dropout操作。需要保存一个mask矩阵,mask矩阵的形状与 QK^T 相同,占用显存大小为 bs^2a

4)激活score

scoreV,计算V上的attention,即 score⋅V ,需要保存 score ,大小为 2bs^2a;以及 V ,大小为 2bsh 。二者占用显存大小合计为 2bs^2a + 2bsh

5)输出映射激活

计算输出映射以及一个dropout操作。输入映射需要保存其输入,大小为2bsh;dropout需要保存mask矩阵,大小为bsh。二者占用显存大小合计为 3bsh。

将1)-5)中间激活显存占用相加,self-attention块总的中间激活显存占用 5bs^2a + 11bsh

2.2 mlp 

mlp计算公式如下

x = f_{gelu}(x_{out}W_1)W_2 + x_{out}

W将内在维度从h升到4h,W2将内在维度从4h降到h。

1)W1对应线性层输入x_out,占用显存 2bsh

2)激活函数f_gelu保存输入,其内在维度为4h,占用显存8bsh

3)W2对应线性层保存输入,占用显存8bsh

4)最后有一个dropout,需要保存其mask矩阵,占用显存为bsh

对于MLP,需要显存保存的中间激活位19bsh。

2.3 layer norm

self-attention和mlp分别后接一个layer normalization,每个layer norm保存输入2bsh,两个layer normalization的中间激活占用的显存大小为4bsh。

2.4 transformer

每个transformer层包含了一个self-attention块和一个MLP块,并分别对应layer normalization。

综合self-attention、mlp、layer norm,每个transformer层需要保存中间激活5bs^2a + 34bsh

2.5 LLM

对于整个LLM,包含embeding层、l层transformer、最后输出层。

embedding不需要中间激活,隐藏维h较大,层数l较深时,输出层中间激活量很小,可忽略。

因此对于l层transformer的LLM,中间激活占用显存大小近似表示为

(5bs^2a + 34bsh) * l

3 中间激活与模型参数

一次训练迭代中,到底是模型参数、梯度、优化器状态占用显存多,还是中间激活占用显存多,他们与输入数据的大小有没有关系。

3.1 原理分析

首先,模型参数、梯度、优化器状态占用显存大小,与优化器类型和模型参数量有关,与输入数据的大小无关。

其次,但中间激活占用显存大小与输入数据大小(批次大小和序列长度)正相关,随批次大小和序列长度的增大,中间激活占用的显存会同步增大。

训练遇到显存OOM,即显存不够或Out Of Memory,常用做法是减少批次大小,这种方式的目的就是减少中间激活占用的显存,因为训练样本中序列程度是不能变的,通用模型参数、梯度、优化器占用显存也是类似。

3.2 GPT3-175B示例

这里以GPT3-175B为例,直观对比模型参数与中间激活的显存大小。

GPT3-175B,参数量约175B,层数96,隐藏维度12288,注意力头数96。

假设采用混合精度训练,模型参数和中间激活均采用float16数据类型,每个元素占2个bytes。

3.2.1 模型参数

GPT3的模型参数量为175B,占用显存大小约为 

2 \times 175 \times 10^9 = 350 GB

GPT3-175B大约需要占用350GB显存。

3.2.2 中间激活

LLM中间激活占用显存大小计算公式如下。

(5bs^2h + 34bsh) \times l

假设训练数据序列长度s为2048,以此为基础对比不同批次大小b时,中间激活占用的显存大小。

当b=1时,中间激活占用显存275414777856 bytes,约275 GB,模型参数的0.79倍。

当b=16时,中间激活占用显存4406636445696 bytes,约4406 GB,大约模型参数的12倍。

当b=64时,中间激活占用显存17626545782784 bytes,17626GB,大约模型参数的50倍。

当b=256时,中间激活占用显存70506183131136 bytes,70506GB,大约模型参数的201倍。

数据显示,随批次大小增大,中间激活占用显存远超模型参数。

3.2.3 激活重计算技术

为有效利用显存训练模型,通常采用激活重计算技术来减少中间激活,理论上可将中间激活显存占用从O(n)减少到 O(sqrt(n)),但需要增加一次额外的forward计算。

激活重计算技术,本质上是“时间换空间”。

reference

---

LLM模型的计算量估计

https://blog.csdn.net/liliang199/article/details/152081156

LLM模型的参数量估计

https://blog.csdn.net/liliang199/article/details/151839842

分析transformer模型的参数量、计算量、中间激活、KV cache

https://zhuanlan.zhihu.com/p/624740065

LLM模型的计算量与参数量的关系

https://blog.csdn.net/liliang199/article/details/152095274

http://www.dtcms.com/a/420342.html

相关文章:

  • 网站做哪些比较赚钱方法网站策划与建设阶段的推广方法
  • 企业品牌网站建设网站背景素材
  • LlamaIndex智能体Agents开发-记忆管理
  • idea学习日记10: 字符串相关类的底层原理
  • 瑞幸咖啡网络营销策划方案沧州百度seo
  • 2025年智慧差旅平台推荐
  • 静态网页模板免费网站富源县住房和城乡建设局网站
  • python建设网站全国网站建设人员数量
  • 海外云服务器数据同步,如何确保全球业务数据一致性
  • iframe通信
  • win8风格手机网站模板如何进外贸大公司网站
  • 个人能为公司网站备案吗微信制作小程序的软件
  • 电商关于信用卡支付小记
  • java-IO流-缓冲流
  • SpringBoot实现简单图形验证码
  • platform设备驱动实验
  • 建最便宜的网站要多少钱wordpress 移动端页码
  • 图论算法刷题的第四十七天
  • 牛客周赛 Round 111(小红的阶梯/小红的数组取数/小红抽卡/小红的好数对/小芳的排列构造小红的排列构造)
  • 东莞网站建设服务湖南网站开发 d岚鸿
  • 计划书网站推广的目录怎么做一级a做爰片啪网站
  • 【剑斩OFFER】优雅的解法——三数之和
  • C++之拷贝构造(浅拷贝与深拷贝)、this指针、内联函数
  • 销售网站开发步骤网站宝二级域名怎么设置
  • 上海浦东建筑建设网站手机端网站制作
  • 【高并发内存池——项目】page cache 回收内存
  • 深圳网站建设便宜信科网络企业网站建设有几种
  • 济南网站建设92jzh知名的设计公司网站
  • MySQL表的内外连接(重点)
  • 使用 SynMatrix 的同轴腔滤波器设计和优化