分布式评估 AUC 乱飞
分布式评估 AUC 乱飞:DDP all_gather
导致 label/pred 错位,四步修复(Cursor × Codex × CodeBuddy 协作 Debug)
在本人的实践操作中,多卡训练时,验证 AUC/AP 时高时低,甚至比单卡差一截;换种 batch_size 或改
drop_last
后曲线又“起飞”。让本来就为黑盒模型的深度学习更加黑盒。因此,本章结合自己的debug经验来讲解。
❓ Bug 现象
单卡 GPU:AUC 稳定在 0.86±0.01
双卡 DDP 分布式训练:AUC 在 0.62~0.91 抖动;改
drop_last
、batch_size
,曲线形态改变但仍不稳虽然打印每卡本地 AUC 正常;但是做“全局汇总再算”就异常
📽️ 场景复现
比较常见的错误写法:直接
all_gather
每步的pred
/label
,忽视尾批大小不同与拼接顺序。
import torch, torch.distributed as dist def gather_step_wrong(pred, label):# pred: [B, 1], label: [B]ws = dist.get_world_size()pred_list = [torch.zeros_like(pred) for _ in range(ws)]label_list = [torch.zeros_like(label) for _ in range(ws)] # ❌ 直接 all_gather:要求每个 rank 张量同形状,否则会截断/复用缓存dist.all_gather(pred_list, pred) # <-- 尾批 B 不同 => 错位dist.all_gather(label_list, label) pred_all = torch.cat(pred_list, dim=0)label_all = torch.cat(label_list, dim=0)return pred_all, label_all
可能的触发条件
1️⃣ drop_last=False
且 len(dataset)
不是 world_size
的整数倍 → 各 rank 尾批 B
不同。
2️⃣ 验证集中过滤/采样导致每卡步数不同。
3️⃣ 步内先 shuffle
再同步导致拼接顺序不一致(极端情况下)。
4️⃣ 用 torchmetrics.AUROC
时同时在 step 与 epoch 做同步,导致重复/错序聚合。
Debug过程
1️⃣ 二分——各自计算 vs. 全局计算
在每个 rank 上各自算 AUC(仅用本地数据),数值正常。
把各 rank 的
pred/label
收齐后再算,全局 AUC 异常 → 问题在汇总阶段。
2️⃣ 长度与顺序核查(Cursor 自动标注)
在 gather
后打印:
print(rank, pred.shape[0], label.shape[0])
发现拼接后样本数对不上;有时
pred_all.size(0) != label_all.size(0)
(典型错位信号)。
3️⃣ 还原“错位”机理(ChatGPT 解释)
all_gather
要求各 rank 张量形状一致;若尾批大小不同,常见“权宜之计”是预分配最大长度再all_gather
,却忘了按各自真实长度截断,从而把padding也当成真实样本。即使长度对上,顺序也可能不一致(例如 rank1 的第 i 个样本在全局排序后不在 rank0 对应位置)。
4️⃣ 复现 MRE(Codex 生成)
Codex 生成了一个 2 卡、不同尾批的假数据脚本,一跑即现“全局 AUC 飘忽”的现象,锁定根因。
调整代码
关键点:先同步各 rank 的真实长度 → 按 max_len 进行 padding →
all_gather
→ 按长度回切 → 在 rank0 统一拼接并计算。同时固定全局顺序(按global_offset + local_index
)。
# ddp_metric_gather.py —— 可直接复用 import torch, torch.distributed as dist def gather_varlen_tensor(x: torch.Tensor, dim=0):"""变长安全 all_gather:返回 rank0 上拼接后的张量;其他 rank 返回 None"""assert x.is_cuda, "put tensors on CUDA for NCCL"world = dist.get_world_size()rank = dist.get_rank() # 1) 同步各自真实长度len_local = torch.tensor([x.size(dim)], device=x.device, dtype=torch.int64)lens = [torch.zeros_like(len_local) for _ in range(world)]dist.all_gather(lens, len_local)lens = torch.stack(lens).squeeze(-1) # [world]max_len = int(lens.max().item()) # 2) 按 max_len padding 到同形状pad_shape = list(x.shape)pad_shape[dim] = max_len - x.size(dim)pad = torch.zeros(pad_shape, device=x.device, dtype=x.dtype)x_pad = torch.cat([x, pad], dim=dim) # 3) all_gather 到各自的缓冲区gather_list = [torch.zeros_like(x_pad) for _ in range(world)]dist.all_gather(gather_list, x_pad) # 4) 仅在 rank0 回切并拼接(按 lens 截断)if rank == 0:parts = []for r in range(world):end = int(lens[r].item())slc = [slice(None)] * x.dim()slc[dim] = slice(0, end)parts.append(gather_list[r][tuple(slc)])return torch.cat(parts, dim=dim)else:return None @torch.no_grad() def gather_preds_labels(pred: torch.Tensor, label: torch.Tensor):# pred [B,1] / [B,C];label [B] / [B,C]pred_all = gather_varlen_tensor(pred, dim=0)label_all = gather_varlen_tensor(label, dim=0)# 仅 rank0 计算指标,其他 rank 返回 Noneif dist.get_rank() == 0:return pred_all.detach().cpu(), label_all.detach().cpu()return None, None
使用方式(验证/评估阶段):
model.eval() with torch.inference_mode():preds_local, labels_local = [], []for batch in val_loader:x, y = batch["img"].cuda(non_blocking=True), batch["label"].cuda(non_blocking=True)logits = model(x)prob = torch.sigmoid(logits).squeeze(-1) # [B]preds_local.append(prob)labels_local.append(y.float()) pred = torch.cat(preds_local, dim=0)lab = torch.cat(labels_local, dim=0) # 变长安全汇总pred_all, lab_all = gather_preds_labels(pred, lab) if dist.get_rank() == 0:from sklearn.metrics import roc_auc_score, average_precision_scoreauc = roc_auc_score(lab_all.numpy(), pred_all.numpy())ap = average_precision_score(lab_all.numpy(), pred_all.numpy())print(f"[Global] AUC={auc:.4f} AP={ap:.4f}")
经验总结
最后定位是 分布式汇总指标时,label 与 pred 在 all_gather
后发生错位(不同 rank 的尾批大小不同、或拼接顺序不一致),造成以错配数据计算的 AUC。本文完整复盘,并给出可直接复用的“变长安全汇总模板”