当前位置: 首页 > news >正文

【教程】PyTorch多机多卡分布式训练的参数说明 | 附通用启动脚本

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~

目录

torchrun

一、什么是 torchrun

二、torchrun 的核心参数讲解

三、torchrun 会自动设置的环境变量

四、torchrun 启动过程举例

机器 A(node_rank=0)上运行

机器 B(node_rank=1)上运行

五、小结表格

PyTorch

一、背景回顾

二、init_process_group

三、脚本中通常的典型写法

通用启动脚本


torchrun 与 torch.multiprocessing.spawn 的对比可以看这篇:

【知识】torchrun 与 torch.multiprocessing.spawn 的对比

torchrun

一、什么是 torchrun

torchrun 是 PyTorch 官方推荐的分布式训练启动器,它的作用是:

  • 启动 多进程分布式训练(支持多 GPU,多节点)

  • 自动设置每个进程的环境变量

  • 协调节点之间建立通信

二、torchrun 的核心参数讲解

torchrun \--nnodes=2 \--nproc_per_node=2 \--node_rank=0 \--master_addr=192.168.5.228 \--master_port=29400 \xxx.py

🔹 1. --nnodes(Number of Nodes)

  • 表示参与训练的总机器数

  • 你有几台服务器,就写几。

  • 在分布式训练中,一个 node 就是一台物理或虚拟的主机。

  • node的编号从0开始。

✅ 例子:你用 2 台机器 → --nnodes=2


🔹 2. --nproc_per_node(Processes Per Node)

  • 表示每台机器上要启动几个训练进程。

  • 一个进程对应一个 GPU,因通常设置为你机器上要用到的GPU数。

  • 因此,整个分布式环境下,总训练进程数 = nnodes * nproc_per_node

✅ 例子:每台机器用了 2 张 GPU → --nproc_per_node=2


🔹 3. --node_rank

  • 表示当前机器是第几台机器

  • 从 0 开始编号,必须每台机器都不同!

✅ 例子:

机器 IPnode_rank
192.168.5.2280
192.168.5.2291

🔹 4. --master_addr--master_port

  • 指定主节点的 IP 和端口,用于 rendezvous(进程对齐)和通信初始化。

  • 所有机器必须填写相同的值!

✅ 建议:

  • master_addr 就是你指定为主节点的那台机器的 IP

  • master_port 选一个未被占用的端口,比如 29400

三、torchrun 会自动设置的环境变量

当用 torchrun 启动后,它会自动给每个进程设置这些环境变量

环境变量含义
RANK当前进程在全局中的编号(0 ~ world_size - 1)
LOCAL_RANK当前进程在本机中的编号(0 ~ nproc_per_node - 1)
WORLD_SIZE总进程数 = nnodes * nproc_per_node

你可以在训练脚本里用 os.environ["RANK"] 来读取这些信息:

import os
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

示例分配图:

四、torchrun 启动过程举例

假设:

  • 有 2 台机器

  • 每台机器有 2 个 GPU

  • 总共会启动 4 个进程

机器 A(node_rank=0)上运行

torchrun \--nnodes=2 \--nproc_per_node=2 \--node_rank=0 \--master_addr=192.168.5.228 \--master_port=29400 \xxx.py

机器 B(node_rank=1)上运行

torchrun \--nnodes=2 \--nproc_per_node=2 \--node_rank=1 \--master_addr=192.168.5.228 \--master_port=29400 \xxx.py

torchrun 给每个进程编号的顺序(分配 RANK / LOCAL_RANK)

torchrun 按照每台机器上 node_rank 的顺序,并在每台机器上依次启动 LOCAL_RANK=0, 1, ..., n-1,最后合成 RANK。

RANK = node_rank × nproc_per_node + local_rank

Step 1:按 node_rank 升序处理(node 0 → node 1)

Step 2:每个 node 内部从 local_rank=0 开始递增


本质上:torchrun 是主从结构调度的

  • 所有 node 启动后,都会和 master_addr 通信。

  • master 会统一收集所有 node 的状态。

  • 每个 node 根据你给的 node_rank 自行派生 local_rank=0~n-1

  • 所有节点通过 RANK = node_rank * nproc_per_node + local_rank 得到自己的全局编号。

这个机制是 可预测、可控、可复现 的。


📦 node_rank=0 (机器 1)
    ├── local_rank=0 → RANK=0
    └── local_rank=1 → RANK=1

📦 node_rank=1 (机器 2)
    ├── local_rank=0 → RANK=2
    └── local_rank=1 → RANK=3

最终分配:

Node RankLocal RankGlobal Rank (RANK)使用 GPU
0000
0111
1020
1131

五、小结表格

参数作用设置方式
--nnodes总节点数你写在命令里
--nproc_per_node每台节点的进程数(= GPU 数)你写在命令里
--node_rank当前机器编号(0开始)每台机器唯一
--master_addr主节点 IP(所有节点需一致)你设置
--master_port主节点端口(所有节点需一致)你设置
RANK当前进程在所有进程中的编号torchrun 自动设置
LOCAL_RANK当前进程在本节点上的编号torchrun 自动设置
WORLD_SIZE总进程数 = nnodes * nproc_per_node自动设置

PyTorch

PyTorch 的分布式通信是如何通过 init_process_grouptorchrun 生成的环境变量配合起来工作的。

一、背景回顾

你已经用 torchrun 启动了多个训练进程,并且 torchrun 为每个进程自动设置了这些环境变量:

变量名含义
RANK当前进程的全局编号(从 0 开始)
LOCAL_RANK本机上的编号(一般等于 GPU ID)
WORLD_SIZE总进程数
MASTER_ADDR主节点的 IP
MASTER_PORT主节点用于通信的端口

那么 这些变量是如何参与进程通信初始化的? 这就涉及到 PyTorch 的核心函数:


二、init_process_group

torch.distributed.init_process_group 是 PyTorch 初始化分布式通信的入口:

torch.distributed.init_process_group(backend="nccl",  # 或者 "gloo"、"mpi"init_method="env://",  # 通过环境变量读取设置
)

关键点:

  • backend="nccl":推荐用于 GPU 分布式通信(高性能)

  • init_method="env://":表示通过环境变量来初始化


你不需要自己设置 RANK / WORLD_SIZE / MASTER_ADDR,只要写:

import torch.distributed as distdist.init_process_group(backend="nccl", init_method="env://")

PyTorch 会自动去环境中读这些变量:

  • RANK → 当前进程编号

  • WORLD_SIZE → 总进程数

  • MASTER_ADDRMASTER_PORT → 主节点 IP 和端口

然后就能正确初始化所有通信进程。

三、脚本中通常的典型写法

import os
import torch# 初始化 PyTorch 分布式通信环境
torch.distributed.init_process_group(backend="nccl", init_method="env://")# 获取全局/本地 rank、world size
rank = int(os.environ.get("RANK", -1))
local_rank = int(os.environ.get("LOCAL_RANK", -1))
world_size = int(os.environ.get("WORLD_SIZE", -1))# 设置 GPU 显卡绑定
torch.cuda.set_device(local_rank)
device = torch.device("cuda")# 打印绑定信息
print(f"[RANK {rank} | LOCAL_RANK {local_rank}] Using CUDA device {torch.cuda.current_device()}: {torch.cuda.get_device_name(torch.cuda.current_device())} | World size: {world_size}")

这段代码在所有进程中都一样写,但每个进程启动时带的环境变量不同,所以最终 ranklocal_rankworld_size 就自然不同了。

通用启动脚本

#!/bin/bash# 设置基本参数
MASTER_ADDR=192.168.5.228           # 主机IP
MASTER_PORT=29400                   # 主机端口
NNODES=2                            # 参与训练的总机器数
NPROC_PER_NODE=2                    # 每台机器上的进程数# 所有网卡的IP地址,用于筛选
ALL_LOCAL_IPS=$(hostname -I)
# 根据本机 IP 配置通信接口
if [[ "$ALL_LOCAL_IPS" == *"192.168.5.228"* ]]; thenNODE_RANK=0                       # 表示当前机器是第0台机器IFNAME=ens1f1np1  mytorchrun=~/anaconda3/envs/dglv2/bin/torchrun
elif [[ "$ALL_LOCAL_IPS" == *"192.168.5.229"* ]]; thenNODE_RANK=1                       # 表示当前机器是第1台机器IFNAME=ens2f1np1mytorchrun=/opt/software/anaconda3/envs/dglv2/bin/torchrun
elseexit 1
fi# 设置 RDMA 接口
export NCCL_IB_DISABLE=0            # 是否禁用InfiniBand
export NCCL_IB_HCA=mlx5_1           # 使用哪个RDMA接口进行通信
export NCCL_SOCKET_IFNAME=$IFNAME   # 使用哪个网卡进行通信
export NCCL_DEBUG=INFO              # 可选:调试用
export GLOO_IB_DISABLE=0            # 是否禁用InfiniBand
export GLOO_SOCKET_IFNAME=$IFNAME   # 使用哪个网卡进行通信
export PYTHONUNBUFFERED=1           # 实时输出日志# 启动分布式任务
$mytorchrun \--nnodes=$NNODES \--nproc_per_node=$NPROC_PER_NODE \--node_rank=$NODE_RANK \--master_addr=$MASTER_ADDR \--master_port=$MASTER_PORT \cluster.py## 如果想获取准确报错位置,可以加以下内容,这样可以同步所有 CUDA 操作,错误不会“延迟触发”,你会看到确切是哪一行代码出了问题:
## CUDA_LAUNCH_BLOCKING=1 torchrun ...

相关文章:

  • 网盘文件下载功能需求分析与技术方案选择:全面解析与最佳实践
  • windows修改远程端口
  • OCP中的OCS operator介绍及应用示例
  • 如何将 Vue-FastAPI-Admin 项目的数据库从 SQLite 切换到 MySQL?
  • 量子纠缠物理本质、技术实现、应用场景及前沿研究
  • Web三漏洞学习(其一:文件上传漏洞)
  • 25.4.15学习总结
  • 代码随想录第18天:二叉树
  • 04-Seata 深度解析:从分布式事务原理到 Seata 实战落地
  • Arduino+ESP826601s模块连接阿里云并实现温湿度数据上报
  • 【leetcode hot 100 72】编辑距离
  • MCP认证难题破解指南
  • 单片机非耦合业务逻辑框架
  • Sentinel源码—2.Context和处理链的初始化二
  • (51单片机)LCD显示日期时间时钟(DS1302时钟模块教学)(LCD1602教程)
  • STM32提高篇: 以太网通讯
  • S06-Kep的跨通道传输
  • 二极管详解:特性参数、选型要点与分类
  • 【正点原子STM32MP257连载】第四章 ATK-DLMP257B功能测试——CAN、CAN FD测试 #FDCAN
  • Qt/C++学习系列之QTreeWidget的简单使用记录
  • 内蒙古公开宣判144件毁林毁草刑案,单起非法占用林地逾250亩
  • 迪卡侬回应出售中国业务30%股份传闻:始终扎根中国长期发展
  • 澎湃回声丨23岁小伙“被精神病”8年续:今日将被移出“重精”管理系统
  • 坚持科技创新引领,赢得未来发展新优势
  • 光明网评“泉州梦嘉商贸楼不到5年便成危楼”:监管是否尽职尽责?
  • 牛市早报|今年第二批810亿元超长期特别国债资金下达,支持消费品以旧换新