「GRPO训练参数详解:理解Batch构成与生成数量的关系」
💡 背景介绍
在复现 VLM-R1 项目并运行 GRPO 算法时,遇到如下报错:
ValueError: The global train batch size (4 x 1) must be evenly divisible by the number of generations per prompt (8). Given the current train batch size, the valid values for the number of generations are: [2, 4].
为此,我深入梳理了 GRPO 训练中的关键参数:nproc_per_node、per_device_train_batch_size、gradient_accumulation_steps 与 num_generations。
这些参数共同决定了全局 batch size,并需满足特定的整除约束,否则训练会中断。本文将逐一解析它们的含义、相互关系及合法配置方式,帮助更高效地调试多卡多进程训练中的常见问题。
1️⃣ 1. nproc_per_node:每个节点的 GPU 数量
这是
torchrun的参数,不是 Python 脚本的参数。表示你启动的并行进程数,每个进程通常绑定一个 GPU,例如:
torchrun --nproc_per_node=4 ...
表示你会使用 4 个 GPU,每个 GPU 运行一份模型的副本(通过 DDP 等机制通信同步梯度)。
2️⃣ 2. per_device_train_batch_size:每张 GPU 上的 batch size
表示**每个进程(即每张 GPU)**上处理多少个样本。
总的训练 batch size =
nproc_per_node × per_device_train_batch_size,例如:
| nproc_per_node | per_device_train_batch_size | 总 batch size |
|---|---|---|
| 2 | 4 | 8 |
| 4 | 2 | 8 |
3️⃣ 3. gradient_accumulation_steps:梯度累积步数
为了显存节省,每
gradient_accumulation_steps次前向+反向传播后,才做一次权重更新。有效全局 batch size = 总 batch size × accumulation steps,比如:
| nproc | batch/gpu | grad_acc_steps | 有效 batch |
|---|---|---|---|
| 2 | 4 | 1 | 8 |
| 2 | 4 | 2 | 16 |
4️⃣ 4. num_generations(generations_per_prompt):每个 prompt 生成多少个候选输出
是 RLHF / GRPO 中的超参数,表示对每条样本(prompt)采样几个 candidate generation。
通常用于 reward 比较排序、RL 采样等。
🔗 它们之间的约束关系
以下是核心公式和常见限制:
✅ 全局训练 batch size:
global_batch_size = nproc_per_node × per_device_train_batch_size × gradient_accumulation_steps
✅ 与 num_generations 的关系:
某些 RL 模型(比如 GRPO)要求:
global_batch_size % num_generations == 0
因为每个 prompt 生成
num_generations个 candidate,需要 batch 中刚好包含多个完整 prompt。如果不能整除,会在构造训练 batch 时出错。
✅ 举个例子
假设你设置:
torchrun --nproc_per_node=2
--per_device_train_batch_size=2
--gradient_accumulation_steps=2
--num_generations=4
那么计算如下:
global_batch_size = 2 × 2 × 2 = 8num_generations = 48 % 4 == 0✅ 合法
如果你设置 num_generations = 3,就会报错 ❌,因为无法对每组 3 个 candidate 匹配成完整的 prompt block。
💡 常用搭配建议
| GPU 数 | batch/gpu | grad_acc_steps | num_gen | 是否兼容 | 有效 batch size |
|---|---|---|---|---|---|
| 1 | 2 | 2 | 2 | ✅ | 4 |
| 2 | 2 | 2 | 4 | ✅ | 8 |
| 2 | 1 | 1 | 3 | ❌ | 2(不能整除) |
| 4 | 2 | 1 | 4 | ✅ | 8 |
🔍 一句话总结它们的关系:
你必须确保
global_batch_size = nproc_per_node × per_device_train_batch_size × gradient_accumulation_steps能被num_generations整除,否则会在 RL 或 reward 比较中 batch 构造失败。
