大模型训练计算显存占用
在大模型训练过程中,GPU显存中需要存储多种类型的数据,这些数据的合理管理直接影响训练效率和模型规模。需要放入GPU的关键数据类型如下:
注意: 在计算大模型训练占用的显存时,一般只计算 模型参数、梯度、优化器 的显存占用情况,模型参数、梯度、优化器 三者的参数比例一般为 1:1:3 或 1:1:2(因为有的优化器含有二阶矩,比例会相比于一般优化器要高);大模型推理时,只计算 模型参数。
一、模型参数(Parameters)
-
内容:包括神经网络的权重(weights)和偏置(bias),是模型的核心组成部分。
-
显存占用: 以LLaMA-7B模型为例,若使用FP32(32位浮点)精度存储,7B参数占用约28GB显存;使用FP16(16位浮点)则占用14GB显存。
-
混合精度训练(如FP16+FP32)可平衡计算速度和显存需求,FP16用于计算,FP32用于参数更新。
-
优化技术:通过ZeRO(零冗余优化器)将参数切分到多个GPU上,例如ZeRO-3将参数分布在所有GPU中,显存占用降低至单卡的1/N(N为GPU数量)。
二、梯度(Gradients)
-
内容:反向传播过程中计算的参数更新方向。
-
显存占用: 梯度与模型参数维度相同,使用FP16存储时占14GB;如果使用FP32存储,则占28GB(以7B模型为例)。
-
梯度累积技术可减少显存占用,但会增加训练时间。
-
优化技术:ZeRO-2将梯度切分到多GPU,显存占用减少8倍。
三、优化器状态(Optimizer States)
-
内容:包括优化器(如Adam)维护的动量(momentum)、二阶矩估计(variance)等中间状态。
-
显存占用: Adam优化器需存储FP32精度的参数、动量和二阶矩,三者共占用84GB显存(以7B模型为例)。
-
优化器状态是显存占用的最大头,占总需求的50%以上。
-
优化技术:ZeRO-1将优化器状态切分到多GPU,显存占用减少4倍。
四、激活值(Activations)
-
内容:前向传播过程中各层的中间计算结果,用于反向传播。
-
显存占用: 与批次大小(batch size)和序列长度正相关,例如处理512x512x512的3D数据时,单个样本占用134MB,32批次则需4.2GB。
-
激活值占显存比例通常低于参数和梯度,但仍需注意长序列场景下的显存爆炸问题。
-
优化技术:激活检查点(Activation Checkpointing)选择性保存部分激活值,其余通过重计算恢复,可减少30%-50%显存。
五、输入数据批次(Batch Data)
-
内容:预处理后的输入数据(如文本、图像张量),通常以批量形式加载到GPU。
-
显存占用: 数据格式影响显存需求,例如uint8比float32节省75%空间。
-
使用数据并行时,每个GPU存储部分批次数据,需注意多worker加载时的内存消耗。
-
优化技术:
• 在CPU端保持低精度数据(如uint8),GPU端实时转换为float并标准化。
• 使用高效数据加载器(如PyTorch DataLoader)减少CPU-GPU传输延迟。
显存优化策略总结
- 精度选择:优先使用混合精度(FP16/BF16)减少参数和梯度占用。
- 分布式切分:采用ZeRO-3同时切分参数、梯度和优化器状态,显存需求降低至单卡的1/N。
- 激活管理:结合检查点技术与梯度累积,平衡显存与计算开销。