当前位置: 首页 > news >正文

PyTorch DDP 随机卡死复盘

PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回,三步修复 Sampler & drop_last

一次真实的分布式训练“玄学卡死”:2 卡训练偶发在 epoch 尾部停住不动,GPU 利用率掉到 0%,日志无异常。最终定位是 DistributedSampler 使用不当 + drop_last=False + 忘记 set_epoch 引发各 rank 步数不一致,导致 allreduce 永久等待。

技术环境

OS:Ubuntu 22.04

Python:3.10.13

PyTorch:2.2.2 + CUDA 12.1(torch==2.2.2+cu121)

NCCL:2.18(系统自带,未自编译)

GPU:2×RTX 3090(24GB)

启动方式:torchrun --standalone --nproc_per_node=2 train.py

Bug 现象

训练随机在某些 epoch 尾部卡住,无异常栈;nvidia-smi 显示两卡功耗接近空闲。

偶尔能看到 NCCL 打印(并不总出现):

NCCL WARN Reduce failed: … Async operation timed out

kill -SIGQUIT 打印 Python 栈后发现停在 反向传播的梯度 allreduce 上(DistributedDataParallel 内部)。

关掉 DDP(单卡训练)完全正常;把 batch_size 改小/大,卡住概率改变但仍会发生。

最小可复现(错误版)

问题点集中在 数据划分不均 + Sampler 误用:

shuffle=True 与 DistributedSampler 混用(会被忽略但容易误导)。

drop_last=False 时,最后一个小批的样本数在不同 rank 上可能不一致(当 len(dataset) 不是 world_size 的整数倍且某些数据被过滤/增强丢弃时尤其明显)。

每个 epoch 忘记调用 sampler.set_epoch(epoch),导致各 rank 的随机顺序不一致。

train_ddp_wrong.py —— 错误示例(请勿照抄)

import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class DummyDS(Dataset):
def init(self, N=1003): # 刻意设成非 world_size 整数倍
self.N = N
def len(self): return self.N
def getitem(self, i):
x = torch.randn(32, 3, 224, 224)
y = torch.randint(0, 10, (32,)) # 模拟有时会丢弃某些样本的增强(省略)
return x, y

def setup():
dist.init_process_group(“nccl”)
torch.cuda.set_device(int(os.environ[“LOCAL_RANK”]))

def main():
setup()
rank = dist.get_rank()
device = torch.device(“cuda”, int(os.environ[“LOCAL_RANK”]))
ds = DummyDS()

sampler = DistributedSampler(ds, shuffle=True, drop_last=False)  # ❌ drop_last=False
# ❌ DataLoader 里又写了 shuffle=True(被忽略,但容易误以为生效)
loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)model = torch.nn.Linear(3*224*224, 10).to(device)
model = DDP(model, device_ids=[device.index])
opt = torch.optim.SGD(model.parameters(), lr=0.1)for epoch in range(5):# ❌ 忘记 sampler.set_epoch(epoch)for x, y in loader:x = x.view(x.size(0), -1).to(device)y = y.to(device)opt.zero_grad()loss = torch.nn.functional.cross_entropy(model(x), y)loss.backward()      # 🔥 偶发卡在这里(allreduce)opt.step()if rank == 0:print(f"epoch {epoch} done")dist.destroy_process_group()

if name == “main”:
main()

触发条件(满足一两个就可能复现):

len(dataset) 不是 world_size 的整数倍。

动态数据过滤/增强(例如有时返回 None 或丢样),导致各 rank 实际步数不同。

忘记 sampler.set_epoch(epoch),各 rank 洗牌次序不同。

drop_last=False,导致最后一个 batch 在各 rank 的样本数不同。

某些自定义 collate_fn 在“空 batch”时直接 continue。

排查步骤
1)先确认“各 rank 步数一致”

在训练 loop 里加统计(不要只在 rank0 打印):

from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):
steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)

或每个 rank 各自 print,检查是否相等

我的现象:有的 epoch,rank0 比 rank1 多 1–2 个 step。

2)开启 NCCL 调试

在启动前设置:

export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1

再跑一遍,可看到某些 allreduce 一直等不到某 rank 进来。

3)检查 Sampler 与 DataLoader 参数

DistributedSampler 必须搭配 sampler.set_epoch(epoch)。

DataLoader 里不要再写 shuffle=True。

若数据不可整除,优先 drop_last=True;否则确保各 rank 最后一个 batch 大小一致(例如补齐/填充)。

解决方案(修复版)
✅ 方案 A:严格对齐 Sampler 语义 + 丢最后不齐整的 batch

train_ddp_fixed.py —— 推荐修复

import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Dataset

class DummyDS(Dataset):
def init(self, N=1003): self.N=N
def len(self): return self.N
def getitem(self, i):
x = torch.randn(32, 3, 224, 224)
y = torch.randint(0, 10, (32,))
return x, y

def setup():
dist.init_process_group(“nccl”)
torch.cuda.set_device(int(os.environ[“LOCAL_RANK”]))

def main():
setup()
rank = dist.get_rank()
device = torch.device(“cuda”, int(os.environ[“LOCAL_RANK”]))

ds = DummyDS()
# 关键 1:使用 DistributedSampler,统一交给它洗牌
sampler = DistributedSampler(ds, shuffle=True, drop_last=True)  # ✅
# 关键 2:DataLoader 里不要再写 shuffle
loader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)model = torch.nn.Linear(3*224*224, 10).to(device)
ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False)  # 如无动态分支,关掉更稳更快
opt = torch.optim.SGD(ddp.parameters(), lr=0.1)for epoch in range(5):sampler.set_epoch(epoch)  # ✅ 关键 3:每个 epoch 设置不同随机种子for x, y in loader:x = x.view(x.size(0), -1).to(device, non_blocking=True)y = y.to(device, non_blocking=True)opt.zero_grad(set_to_none=True)loss = torch.nn.functional.cross_entropy(ddp(x), y)loss.backward()opt.step()if rank == 0:print(f"epoch {epoch} ok")dist.barrier()  # ✅ 收尾同步,避免 rank 提前退出
dist.destroy_process_group()

if name == “main”:
main()

✅ 方案 B:必须保留最后一批(学术场景)

如果确实不能 drop_last=True(例如小数据集),可考虑对齐 batch 大小:

Padding/Repeat:在 collate_fn 里把最后一批补齐到一致大小;

EvenlyDistributedSampler:自定义 sampler,确保各 rank 拿到完全等长的 index 列表(对总长度做上采样)。

示例(最简单的“循环补齐”):

class EvenSampler(DistributedSampler):
def iter(self):
# 先拿到原始 index,再做均匀补齐
indices = list(super().iter())
# 使得 len(indices) 可整除 num_replicas
rem = len(indices) % self.num_replicas
if rem != 0:
pad = self.num_replicas - rem
indices += indices[:pad] # 简单重复前几个样本
return iter(indices)

✅ 方案 C:降低“意外丢样”风险

自定义 collate_fn 不要在空 batch 时 return None 或直接 continue,而应抛异常或做补齐。

数据增强/过滤若可能丢样,务必在 Dataset 内重采样,保证 getitem 总是返回有效样本。

若模型里有条件分支可能不参与反向(导致“未使用参数”),

要么收敛后改为固定分支;

要么在 DDP 里开启 find_unused_parameters=True(但会更慢,且仍需确保步数一致)。

验证

修复后,连续训练 50+ 个 epoch 未再出现挂起;

加上 dist.barrier() 收尾,脚本结束更干净;

打开 NCCL_BLOCKING_WAIT=1 时也不再报超时。

避坑总结(Checklist)

一定要 sampler.set_epoch(epoch),确保各 rank 洗牌一致。

不要在 DataLoader 再写 shuffle=True(使用 DistributedSampler 时交给 sampler)。

尽量 drop_last=True,避免尾批大小不一致;若必须保留尾批,就补齐到等长。

保证各 rank 步数完全一致:collate 不能静默丢 batch;Dataset 不要“偶发返回 None”。

按需设置 DDP 参数:无动态分支时 find_unused_parameters=False 更稳更快。

开 NCCL 调试:NCCL_DEBUG=INFO、NCCL_ASYNC_ERROR_HANDLING=1、NCCL_BLOCKING_WAIT=1,排障高效。

收尾同步:退出前 dist.barrier(),避免某 rank 早退影响他人。

最简复现先做整除长度:把 len(dataset) 设为 k * world_size,观察是否立刻恢复。

以上是这次 DDP 卡死问题从现象 → 排查 → 解决的完整记录。这个坑非常高频,尤其在课程项目/科研代码里常被忽视。希望这篇复盘能让你在分布式训练时少掉一把汗


文章转载自:

http://RXBVXOco.mcwrg.cn
http://mvd7lE6S.mcwrg.cn
http://S7cLPeki.mcwrg.cn
http://iRKANDAN.mcwrg.cn
http://a36jBraH.mcwrg.cn
http://mohUKh3P.mcwrg.cn
http://gqPR1jT1.mcwrg.cn
http://kS9j0sV2.mcwrg.cn
http://jCgLC1vD.mcwrg.cn
http://3k6R9vbx.mcwrg.cn
http://fntrbJHs.mcwrg.cn
http://Kn3W6QMy.mcwrg.cn
http://wflD0yhu.mcwrg.cn
http://K19SngAP.mcwrg.cn
http://zaXtnHjW.mcwrg.cn
http://KuetlyKP.mcwrg.cn
http://FyNG3zcD.mcwrg.cn
http://owhi6YVA.mcwrg.cn
http://s2mr5KaV.mcwrg.cn
http://mUQ5R2Eq.mcwrg.cn
http://85vupBWe.mcwrg.cn
http://fAoKYt8z.mcwrg.cn
http://4z9q52hl.mcwrg.cn
http://ySQkfWOP.mcwrg.cn
http://bKkIXsnm.mcwrg.cn
http://LI4AZKkh.mcwrg.cn
http://l4ob8lsF.mcwrg.cn
http://rrhouKQR.mcwrg.cn
http://P4ANOwQk.mcwrg.cn
http://ooRPQKia.mcwrg.cn
http://www.dtcms.com/a/366984.html

相关文章:

  • JVM 类加载全过程
  • 关于IDEA构建Gradle项目时报错“contentRootData“ is null的一次排查
  • devcpp 5.11的详细安装步骤
  • 高效菜单管理页面:一键增删改查
  • jmeter压测工具使用详情
  • finally 与 return的执行顺序
  • Java String vs StringBuilder vs StringBuffer:一个性能优化的探险故事
  • 邦芒干货:新入职场的人必须要知道的三大事情
  • JY-H818|科智立RFID高频读写器产品参数解析
  • LVDS系列27:Xilinx 7系 OSERDESE2原语(三)
  • [晕事]今天做了件晕事91,glibc,rand之前必须设置种子
  • C语言内存精讲系列(七):深入解析 x86 实模式
  • 远场代码学习_FDTD_farfield
  • 五、插值与拟合
  • 今天我们继续学习Linux中的shell脚本流程控制内容
  • 大模型微调之LORA核心逻辑
  • React笔记_组件之间进行数据传递
  • 《Java餐厅的待客之道:BIO, NIO, AIO三种服务模式的进化》
  • 【OpenHarmony文件管理子系统】文件访问接口解析
  • sealos部署k8s
  • (C题|NIPT 的时点选择与胎儿的异常判定)2025年高教杯全国大学生数学建模国赛解题思路|完整代码论文集合
  • 25高教社杯数模国赛【C题国一学长思路+问题分析】第二弹
  • 数学建模25c
  • 互联网大厂Java面试场景与问题解答
  • LeetCode 刷题【64. 最小路径和】
  • Rust+slint实现一个登录demo
  • Rust 文件操作终极实战指南:从基础读写到进阶锁控,一文搞定所有 IO 场景
  • 代码随想录算法训练营第二十八天 | 买卖股票的最佳实际、跳跃游戏、K次取反后最大化的数组和
  • 2025全国大学生数学建模C题保姆级思路模型(持续更新):NIPT 的时点选择与胎儿的异常判定
  • 2025反爬虫之战札记:从robots.txt到多层防御的攻防进化史