【假设微调1B模型,一个模型参数是16bit,计算需要多少显存?】
好的,作为资深AI专家,我将为您详细拆解全量微调 (Full Fine-Tuning) 和高效微调 (LoRA, QLoRA) 的显存占用计算过程。
第一部分:全量微调 (Full Fine-Tuning) 1B 模型
对于一个参数量为 1B (10亿) 的模型,进行全量微调时,显存占用主要由以下四部分组成:
- 模型权重 (Model Weights)
- 梯度 (Gradients)
- 优化器状态 (Optimizer States)
- 前向激活 (Forward Activations)
我们通常使用 字节 (Bytes) 作为单位。1B parameters = 1e9 parameters
。
1. 模型权重 (FP16)
在训练时,为了计算效率和精度,我们通常使用混合精度训练。模型权重保存在显存中,通常以 16-bit 浮点数 (FP16) 格式存储。
1 parameter
占2 bytes
。- 计算公式:
Model Weights = 2 * Number of Parameters
- 计算:
2 bytes/param * 1e9 params = 2e9 bytes ≈ 2 GB
2. 梯度 (Gradients)
在反向传播过程中,每个参数都会计算出一个梯度,用于更新权重。梯度通常也以 FP16 格式存储。
1 gradient
占2 bytes
。- 计算公式:
Gradients = 2 * Number of Parameters
- 计算:
2 bytes/param * 1e9 params = 2e9 bytes ≈ 2 GB
3. 优化器状态 (Optimizer States)
优化器状态是显存占用的大头。以最常用的 AdamW 优化器为例,它为每个参数需要维护两个状态:
- 一阶动量 (m):FP32格式,占
4 bytes
。 - 二阶动量 (v):FP32格式,占
4 bytes
。 - 主权重副本 (Master Weight Copy):为了提升优化精度,AdamW 还会在 FP32 中保存一份模型权重的副本,占
4 bytes
。
- 每个参数在 AdamW 优化器下占用的显存:
4 (m) + 4 (v) + 4 (master weights) = 12 bytes
。 - 计算公式:
Optimizer States = 12 * Number of Parameters
- 计算:
12 bytes/param * 1e9 params = 12e9 bytes ≈ 12 GB
注意:如果使用像 SGD 这样更简单的优化器(只需要动量,约
8 bytes/param
),显存会少一些,但 Adam/AdamW 是当前的主流选择。
4. 前向激活 (Forward Activations / Activations)
在训练的前向传播过程中,需要保存中间计算结果(激活值),以便在反向传播时计算梯度。这部分是最难精确估算的,因为它严重依赖于:
- 模型结构 (Transformer, CNN, RNN)
- 序列长度 (Sequence Length)
- 批次大小 (Batch Size)
- 激活检查点 (Gradient Checkpointing) 技术
一个广泛使用的 经验估算公式 来自 OpenAI 的论文《Scaling Laws for Neural Language Models》:
- Activations (Bytes) ≈
Seq_Len * Batch_Size * Hidden_Dim * (34 + (5 * Seq_Len * Attn_Heads) / Hidden_Dim))
为了简化计算,我们通常认为激活所占用的显存大约是 模型权重的 1 到 3 倍。对于一个 1B 的 Transformer 模型,一个合理的估计是:
- Activations ≈ 1 * Model Weights (如果使用了梯度检查点技术)
- Activations ≈ 2-3 * Model Weights (如果未使用梯度检查点技术)
我们取一个中间值进行估算:
- 计算公式 (保守估计):
Activations ≈ 2 * Model Weights
- 计算:
2 * 2 GB = 4 GB
全量微调总显存估算
将以上四部分相加:
- Model Weights: ~2 GB
- Gradients: ~2 GB
- Optimizer States: ~12 GB
- Activations: ~4 GB (保守估计)
- 总计 (Estimated Total VRAM):
2 + 2 + 12 + 4 = 20 GB
结论:全量微调一个 1B 模型,显存需求大约在 20GB 以上。考虑到 CUDA 上下文等额外开销,建议使用 至少 24GB 显存 的显卡(如 RTX 3090, RTX 4090, RTX 3090 Ti, A5000)才能稳妥地进行。
第二部分:高效微调 (Parameter-Efficient Fine-Tuning)
高效微调的核心思想是冻结原始模型的绝大部分参数,只引入和训练一小部分额外参数,从而极大减少需要存储的梯度值和优化器状态。
1. LoRA (Low-Rank Adaptation)
原理:在模型的线性层(如 Attention 的 QKV 投影)旁注入一个低秩分解的旁路矩阵(Adapter)。假设原始矩阵维度是 d x d
,LoRA 将其分解为 B (d x r)
和 A (r x d)
,其中 r << d
(秩 r
通常很小,如 8, 16, 64)。
- 可训练参数量:
2 * (LoRA 模块数量) * d * r
- 对于 1B 模型,主要 LoRA 模块集中在 Attention 的 QKV 和 MLP 的上下投影层。假设我们只对 Attention 的 QKV 投影应用 LoRA,那么可训练参数量大约为原始参数量的 0.1% 到 1%。我们取
r=8
,可训练参数量约为4 Million (4e6)
。
- 对于 1B 模型,主要 LoRA 模块集中在 Attention 的 QKV 和 MLP 的上下投影层。假设我们只对 Attention 的 QKV 投影应用 LoRA,那么可训练参数量大约为原始参数量的 0.1% 到 1%。我们取
显存计算:
- Model Weights: 原始 1B FP16 权重被冻结,仍需加载到显存。~2 GB。
- Gradients: 只计算 LoRA 参数的梯度。
2 bytes/param * 4e6 params ≈ 8 MB
。 - Optimizer States: 只对 LoRA 参数使用 AdamW 优化器。
12 bytes/param * 4e6 params ≈ 48 MB
。 - Activations: 由于前向传播仍然需要计算原始模型的完整图,激活值显存占用与全量微调几乎相同。这是我们使用 LoRA 也无法大幅减少的部分,仍然是 ~4 GB。
LoRA 总显存估算:
2 GB (Weights) + 4 GB (Activations) + ~0.05 GB (Gradients + Optimizer States) ≈ 6.05 GB
结论:使用 LoRA 微调 1B 模型,显存需求大幅降低至约 6-8 GB。这使得在 12GB 甚至 8GB 的消费级显卡上微调大模型成为可能。
2. QLoRA (Quantized LoRA)
QLoRA 是 LoRA 的进一步优化,它通过引入 4-bit 量化来极致地降低显存占用。
原理:
- 4-bit 量化权重: 将原始 FP16 的模型权重量化成 4-bit 格式(如 NF4),然后即时反量化到 FP16 进行计算。权重存储占用减少为原来的 1/4。
- 分页优化器: 利用 CPU RAM 来处理优化器状态可能出现的显存峰值。
- 双重量化: 对量化常数进行二次量化,进一步节省空间。
显存计算:
- Model Weights: 原始 1B 权重以 4-bit 形式存储。
0.5 bytes/param * 1e9 params = 0.5e9 bytes ≈ 0.5 GB
。- (注意:计算时仍需一份反量化的 FP16 副本,但QLoRA的巧妙设计使其可以按需动态完成,峰值显存占用主要还是这 0.5 GB 的 4-bit 存储)。
- Gradients: 同 LoRA,只计算 LoRA 参数的梯度。
2 bytes/param * 4e6 params ≈ 8 MB
。 - Optimizer States: 同 LoRA,只对 LoRA 参数使用优化器。
12 bytes/param * 4e6 params ≈ 48 MB
。QLoRA 的分页优化器特性可以防止这部分在显存中爆掉。 - Activations: 仍然是最大的开销。QLoRA 无法减少这部分。仍然是 ~4 GB。
QLoRA 总显存估算:
0.5 GB (4-bit Weights) + 4 GB (Activations) + ~0.05 GB (Gradients + Optimizer States) ≈ 4.55 GB
结论:使用 QLoRA 微调 1B 模型,显存需求可以进一步降低至约 5-6 GB。这几乎让任何一款现代的消费级显卡(如 RTX 3060 12G, RTX 2060 12G)都能胜任微调 1B 模型的任务。
总结对比
微调方法 | 模型权重 | 梯度 | 优化器状态 | 激活值 | 总计显存 (估算) |
---|---|---|---|---|---|
全量微调 (AdamW) | 2 GB | 2 GB | 12 GB | 4 GB | ~20 GB |
LoRA | 2 GB (FP16) | ~8 MB | ~48 MB | 4 GB | ~6.1 GB |
QLoRA | 0.5 GB (4-bit) | ~8 MB | ~48 MB | 4 GB | ~4.6 GB |
核心洞察:
- 全量微调的显存杀手是优化器状态。
- LoRA 通过大幅减少可训练参数量,几乎消灭了梯度和优化器状态的显存占用。
- QLoRA 在此基础上,通过量化模型权重,进一步攻克了模型加载本身的显存问题。
- 激活值是高效微调中难以压缩的部分,它成为了微调超大规模模型(如 30B+)时的新瓶颈。对此,梯度检查点 (Gradient Checkpointing) 是必须使用的技术,它可以用计算时间换显存空间,将激活值显存占用减少到约
模型参数大小的 1倍
。
推荐文章:LLMem: Estimating GPU Memory Usage for Fine-Tuning Pre-Trained LLMs
英文全称:
- QLoRA:Efficient Finetuning of Quantized LLMs
- LoRA: Low-Rank Adaptation of Large Language Models
技术原文:Training language models to follow instructions with human feedback