在训练和推理过程中 对 token 数量的处理方式的差异
训练时要求每个 batch 的 token 数相等
推理时对 token 数量无严格要求
在训练和推理过程中,对 token 数量的处理方式不同,主要源于它们的目标和计算效率的差异。以下是详细解释:
1. 训练时要求每个 batch 的 token 数相等
在训练过程中,固定每个 batch 的 token 数量(而非句子数量)是为了实现以下目标:
(1) 计算效率与硬件优化
-
批处理(Batching):GPU/TPU 等硬件对规整的张量(固定形状)计算效率最高。如果每个 batch 的 token 数不等,需要对短序列填充(Padding)到最长序列的长度,导致大量无效计算(填充部分参与计算但无意义)。
-
显存利用率:固定 token 数量可以更精确控制显存占用,避免因某些 batch 过长导致显存溢出(OOM)。
(2) 稳定的梯度更新
-
训练动态性:如果每个 batch 的 token 数波动较大,梯度更新的幅度会不稳定(例如,一个 batch 包含 1000 token,另一个包含 5000 token,后者对模型参数的更新影响更大)。
-
学习率调整:固定 token 数量后,学习率的设计和优化(如 Adam 的梯度归一化)会更一致,避免因 batch 内 token 数差异引入噪声。
(3) 实现方法
-
动态批处理(Dynamic Batching):将训练数据按相似长度分组,尽量减少填充(如 Hugging Face 的
DataCollatorForSeq2Seq
)。 -
固定 token 数:例如,设定每个 batch 总 token 数为
4096
,可能包含 8 条长度为 512 的样本,或 16 条长度为 256 的样本。
2. 推理时对 token 数量无严格要求
在推理过程中,动态调整输入长度是常见做法,原因包括:
(1) 无需梯度更新
-
推理是单次前向计算,不涉及反向传播和参数优化,因此无需保持 batch 间的一致性。
-
即使输入长度不同,只需对当前输入填充到其实际长度即可,计算完成后丢弃填充部分。
(2) 实际场景需求
-
输入长度可变:用户请求的文本长度天然不同(例如短问题 vs 长文档)。
-
流式处理:某些场景需要逐步生成 token(如聊天机器人),无法预先固定长度。
(3) 性能优化手段
-
KV Cache:在自回归生成(如 GPT)中,缓存已计算的 Key-Value 矩阵,避免重复计算历史 token,此时动态长度对效率影响较小。
-
内存管理:推理框架(如 vLLM、TGI)会优化显存分配,支持可变长度输入。
3. 特殊情况与注意事项
-
训练时的可变长度:某些框架支持“动态批处理”(如 NVIDIA 的
DALI
),在 batch 内填充到最小必要长度,但仍需控制总 token 数以避免显存问题。 -
推理时的批处理:如果需要同时处理多个请求(如 API 服务),仍会通过动态批处理(如 ORT 的
padding+masking
)提升吞吐量,但容忍长度差异。
阶段 | 目标 | Token 数量处理 | 原因 |
---|---|---|---|
训练 | 高效学习参数 | 固定每个 batch 的 token 数 | 硬件效率、梯度稳定性 |
推理 | 灵活响应请求 | 动态适应输入长度 | 无需梯度、实际需求 |
这种差异本质上是训练追求稳定性与效率,而推理追求灵活性与实用性的结果。