当前位置: 首页 > news >正文

PyTorch 分布式训练(Distributed Data Parallel, DDP)简介

PyTorch 分布式训练(Distributed Data Parallel, DDP)

一、DDP 核心概念

torch.nn.parallel.DistributedDataParallel

1. DDP 是什么?

Distributed Data Parallel (DDP) 是 PyTorch 提供的分布式训练接口,DistributedDataParallel相比 DataParallel 具有以下优势:

  • 多进程而非多线程:避免 Python GIL 限制
  • 更高的效率:每个 GPU 有独立的进程,减少通信开销
  • 更好的扩展性:支持多机多卡训练
  • 更均衡的负载:无主 GPU 瓶颈问题

2. 核心组件

  • 进程组 (Process Group):管理进程间通信
  • NCCL 后端:NVIDIA 优化的 GPU 通信库
  • Ring-AllReduce:高效的梯度同步算法

在这里插入图片描述

二、完整 DDP 训练 Demo

  • 官方DDP Dem参考

1. 基础训练脚本 (ddp_demo.py)

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.cuda.amp import GradScaler

def setup(rank, world_size):
    """初始化分布式环境"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """清理分布式环境"""
    dist.destroy_process_group()

class SimpleModel(nn.Module):
    """简单的CNN模型"""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc = nn.Linear(9216, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        return self.fc(x)

def prepare_dataloader(rank, world_size, batch_size=32):
    """准备分布式数据加载器"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return loader

def train(rank, world_size, epochs=2):
    """训练函数"""
    setup(rank, world_size)
    
    # 设置当前设备
    torch.cuda.set_device(rank)
    
    # 初始化模型、优化器等
    model = SimpleModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(ddp_model.parameters())
    scaler = GradScaler()  # 混合精度训练
    criterion = nn.CrossEntropyLoss()
    train_loader = prepare_dataloader(rank, world_size)
    
    for epoch in range(epochs):
        ddp_model.train()
        train_loader.sampler.set_epoch(epoch)  # 确保每个epoch有不同的shuffle
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(rank), target.to(rank)
            
            optimizer.zero_grad()
            
            # 混合精度训练
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                output = ddp_model(data)
                loss = criterion(output, target)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            if batch_idx % 100 == 0:
                print(f"Rank {rank}, Epoch {epoch}, Batch {batch_idx}, Loss {loss.item():.4f}")
    
    cleanup()

if __name__ == "__main__":
    # 单机多卡启动时,torchrun会自动设置这些环境变量
    rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    train(rank, world_size)

2. 启动训练

使用 torchrun 启动分布式训练(推荐 PyTorch 1.9+):

# 单机4卡训练
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12355 ddp_demo.py

3. 关键组件解析

3.1 分布式数据采样 (DistributedSampler)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
  • 确保每个 GPU 处理不同的数据子集
  • 自动处理数据分片和 epoch 间的 shuffle
3.2 模型包装 (DDP)
ddp_model = DDP(model, device_ids=[rank])
  • 自动处理梯度同步
  • 透明地包装模型,使用方式与普通模型一致
3.3 混合精度训练 (AMP)
scaler = GradScaler()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    # 前向计算
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
  • 减少显存占用,加速训练
  • 自动管理 float16/float32 转换

三、DDP 最佳实践

  1. 数据加载

    • 必须使用 DistributedSampler
    • 每个 epoch 前调用 sampler.set_epoch(epoch) 保证 shuffle 正确性
  2. 模型保存

    if rank == 0:  # 只在主进程保存
        torch.save(model.state_dict(), "model.pth")
    
  3. 多机训练

    # 机器1 (主节点)
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=IP1 --master_port=12355 ddp_demo.py
    
    # 机器2
    torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr=IP1 --master_port=12355 ddp_demo.py
    
  4. 性能调优

    • 调整 batch_size 使各 GPU 负载均衡
    • 使用 pin_memory=True 加速数据加载
    • 考虑梯度累积减少通信频率

四、常见问题解决

  1. CUDA 内存不足

    • 减少 batch_size
    • 使用梯度累积
    for i, (data, target) in enumerate(train_loader):
        if i % 2 == 0:
            optimizer.zero_grad()
        # 前向和反向...
        if i % 2 == 1:
            optimizer.step()
    
  2. 进程同步失败

    • 检查所有节点的 MASTER_ADDRMASTER_PORT 一致
    • 确保防火墙开放相应端口
  3. 精度问题

    • 混合精度训练时出现 NaN:调整 GradScaler 参数
    scaler = GradScaler(init_scale=1024, growth_factor=2.0)
    

相关文章:

  • 快速入门 JSON 数据格式
  • wireshark开启对https密文抓包
  • 【工具使用-编译器】VScode(Ubuntu)使用
  • 【Android15 ShellTransitions】(九)结束动画+Android原生ANR问题分析
  • 深度学习篇---回归分类任务的损失函数
  • Git(八)如何在同一台电脑登录两个Git
  • 45 55跳跃游戏解题记录
  • static方法使用bean的方式
  • Nodejs上传文件的问题
  • 【JavaScript】JavaScript Promises实践指南
  • UE5 UE4 右键/最大化-菜单-不显示/闪/黑色/黑屏--修复方法
  • DBeaver配置postgresql数据库连接驱动
  • PHP开发者2025生存指南
  • Android 蓝牙/Wi-Fi通信协议之:经典蓝牙(BT 2.1/3.0+)介绍
  • CentOS 7 安装 EMQX (MQTT)
  • IP第一次笔记
  • 学习中学习的小tips(主要是学习苍穹外卖的一些学习)
  • MetInfo6.0.0目录遍历漏洞原理分析
  • Kubernetes 结点排水卡住的原因及解决方案
  • Python 学习路线推荐
  • 耶路撒冷发生山火,以防长宣布紧急状态
  • 中国强镇密码丨洪泽湖畔的蒋坝,如何打破古镇刻板印象
  • 新希望一季度归母净利润4.45亿,上年同期为-19.34亿
  • “乐购浦东”消费券明起发放,多个商家同期推出折扣促销活动
  • 豆神教育:2024年净利润1.37亿元,同比增长334%
  • 十四届全国人大常委会举行第四十三次委员长会议 ,听取有关草案和议案审议情况汇报