当前位置: 首页 > news >正文

大模型微调中显存占用和训练时间的影响因素

BatchSize

显存占用:与batch_size呈线性关系,可理解为 M t o t a l = M f i x e d + B a t c h S i z e ∗ M p e r − s a m p l e M_{total}=M_{fixed}+BatchSize*M_{per-sample} Mtotal=Mfixed+BatchSizeMpersample,其中 M f i x e d M_{fixed} Mfixed指的是模型本身固定占用的显存(由参数数量决定)和优化器状态(也由参数数量决定)

总训练时间:理论上与BatchSize无关(总数不变,单步训练时间增加,总步数减少),但实际中随BatchSize越大,总时间可能减少(硬件并行效率提升),直到显存或硬件并行能力达到瓶颈。

截断长度(输入序列分词后的最大长度,即每条样本被大模型读取的最大长度)

1. 显存占用

在大型语言模型(如 Transformer)中,显存占用主要与模型的激活值(Activations)有关,而激活值的大小受到输入序列长度(即截断长度)的直接影响。以下是逐步分析:

激活值的定义

激活值是指模型在正向传播过程中每一层计算出的中间结果,通常存储在显存中,以便反向传播时计算梯度。对于 Transformer 模型,激活值主要与注意力机制(Self-Attention)和前馈网络(Feed-Forward Network, FFN)的计算相关。

显存占用的组成

显存占用主要包括:

  • 模型参数(权重和偏置):与模型规模(层数、隐藏维度)相关,与截断长度无关。
  • 激活值:与输入序列长度(截断长度 L L L)、批次大小(batch size B B B)、隐藏维度(hidden size H H H)和层数( N N N)成正比。
  • 梯度(训练时):与参数量和激活值大小相关。

对于激活值部分,显存占用主要来源于:

  1. 注意力机制:计算 Q ⋅ K T Q \cdot K^T QKT的注意力分数矩阵,尺寸为 ( B , L , L ) (B, L, L) (B,L,L),每层需要存储。
  2. 中间张量:如 V V V的加权和、前馈层的输出等。
数学表达式

假设: L L L:截断长度(序列长度), B B B:批次大小, H H H:隐藏维度, N N N:模型层数, P P P:浮点数精度(如 FP32 为 4 字节,FP16 为 2 字节)

激活值的显存占用近似为:
显存 激活值 ≈ N ⋅ B ⋅ L ⋅ H ⋅ P + N ⋅ B ⋅ L 2 ⋅ P \text{显存}_{\text{激活值}} \approx N \cdot B \cdot L \cdot H \cdot P + N \cdot B \cdot L^2 \cdot P 显存激活值NBLHP+NBL2P

  • 第一项 N ⋅ B ⋅ L ⋅ H ⋅ P N \cdot B \cdot L \cdot H \cdot P NBLHP:表示每层的线性张量(如 Q , K , V Q, K, V Q,K,V或 FFN 输出)的显存占用。
  • 第二项 N ⋅ B ⋅ L 2 ⋅ P N \cdot B \cdot L^2 \cdot P NBL2P:表示注意力分数矩阵的显存占用(仅在标准注意力机制中显著,若使用优化如 FlashAttention,则可能减少)。

结论:显存占用与截断长度 L L L呈线性( O ( L ) O(L) O(L))到二次方( O ( L 2 ) O(L^2) O(L2))的关系,具体取决于注意力机制的实现方式。


2. 训练时间

训练时间主要与计算量(FLOPs,浮点运算次数)和硬件并行能力有关,而截断长度会影响计算量。

计算量的组成
  1. 注意力机制:每层的计算量与 L 2 L^2 L2相关,因为需要计算 L × L L \times L L×L的注意力矩阵。
  2. 前馈网络:每层的计算量与 L L L线性相关,因为对每个 token 独立计算。

总计算量(FLOPs)近似为:
FLOPs ≈ N ⋅ B ⋅ ( 2 ⋅ L 2 ⋅ H + 4 ⋅ L ⋅ H 2 ) \text{FLOPs} \approx N \cdot B \cdot (2 \cdot L^2 \cdot H + 4 \cdot L \cdot H^2) FLOPsNB(2L2H+4LH2)

  • 2 ⋅ L 2 ⋅ H 2 \cdot L^2 \cdot H 2L2H:注意力机制的矩阵乘法(如 Q ⋅ K T Q \cdot K^T QKT softmax ⋅ V \text{softmax} \cdot V softmaxV),
  • 4 ⋅ L ⋅ H 2 4 \cdot L \cdot H^2 4LH2:前馈网络的计算(假设 FFN 隐藏层维度为 4 H 4H 4H)。
训练时间

训练时间与 FLOPs 成正比,同时受硬件并行能力(如 GPU 的计算核心数)影响。假设每秒浮点运算能力为 F GPU F_{\text{GPU}} FGPU(单位:FLOPs/s),则单次前向+反向传播的训练时间为:
时间 ≈ FLOPs F GPU ≈ N ⋅ B ⋅ ( 2 ⋅ L 2 ⋅ H + 4 ⋅ L ⋅ H 2 ) F GPU \text{时间} \approx \frac{\text{FLOPs}}{F_{\text{GPU}}} \approx \frac{N \cdot B \cdot (2 \cdot L^2 \cdot H + 4 \cdot L \cdot H^2)}{F_{\text{GPU}}} 时间FGPUFLOPsFGPUNB(2L2H+4LH2)

结论:训练时间与截断长度 L L L呈线性( O ( L ) O(L) O(L))到二次方( O ( L 2 ) O(L^2) O(L2))的关系,具体取决于注意力机制的计算占比。


3. 总结

  • 显存占用:与 L L L O ( L ) O(L) O(L) O ( L 2 ) O(L^2) O(L2)关系,取决于是否存储完整的注意力矩阵。
  • 训练时间:与 L L L O ( L ) O(L) O(L) O ( L 2 ) O(L^2) O(L2)关系,注意力机制的二次项通常更显著。

1

假设某模型大小为5GB,推理所需显存也为5GB,普通Lora微调(FP16)所需显存为5GB*2=10GB,8bit的QLora量化为5GB/2=2.5GB,4bit的QLora量化为5GB/4=1.25GB

相关文章:

  • OTP单片机调试工具之—单线数据编码
  • RCore学习记录001
  • 微信小程序threejs三维开发
  • 如何解决pymilvus中offset参数不生效的问题?
  • AI与人的智能,改变一生的思维模型【7】易得性偏差
  • 在 WSL中批量执行InSAR任务-stackSentinel.py
  • MySQL数据库知识总结
  • Redis7——进阶篇(六)
  • 小脑萎缩会致命吗?
  • Vue Router 中的导航守卫是什么?
  • 有了大语言模型还需要 RAG 做什么
  • AP AR
  • 二叉树_4_面试题汇总
  • AlphaGo 家族:从「偷看棋谱」到「自创宇宙套路」的 1008 天
  • 神经网络的基本知识
  • 生态安全的范式
  • LoRa数传、点对点通信、Mesh网络、ZigBee以及图传技术的区别和特点
  • zend server试用分析
  • 架构思维:软件建模与架构设计的关键要点
  • request模块基本使用方法
  • 浙江推动人工智能终端消费:家居机器人纳入以旧换新补贴范围
  • 浙江一教师被指殴打并威胁小学生,教育局通报涉事人被行拘
  • 2人恶意传播刘国梁谣言被处罚,媒体:以法律利剑劈谣斩邪,加快推进依法治体
  • 新华时评:博物馆正以可亲可近替代“高冷范儿”
  • 首次带人形机器人走科技节红毯,傅利叶顾捷:机器人行业没包袱,很多事都能从零开始
  • 联合国:欢迎俄乌伊斯坦布尔会谈,希望实现全面停火