当前位置: 首页 > news >正文

MoviiGen1.1模型脚本调用

MoviiGen1.1模型脚本调用
在分布式训练场景下,通过Python脚本设置环境变量(如PYTHONPATH)并结合torchrun启动多进程训练,需遵循以下流程和代码规范:


1. torchrun 命令解析

  • --nproc_per_node=8:表示当前节点使用8个GPU进程(每个进程对应一个GPU)。
  • PYTHONPATH=.:在命令行中临时设置环境变量,将当前目录(.)加入Python模块搜索路径。但若需在脚本内动态设置,应使用os模块。

2. Python脚本内设置环境变量

使用os.environ在脚本开头动态设置环境变量(如PYTHONPATH):

import os
# 设置PYTHONPATH为当前目录
os.environ["PYTHONPATH"] = "."  # 或追加路径:os.environ["PYTHONPATH"] += os.pathsep + "."

注意:此设置仅对当前进程有效,在分布式训练中需确保所有进程同步设置。


3. 分布式训练所需环境变量

torchrun会自动为每个进程设置以下关键变量:

  • LOCAL_RANK:当前进程在本机的GPU编号(0~7)。
  • RANK:全局进程编号(跨所有节点)。
  • WORLD_SIZE:总进程数(此处为8)。
    在脚本中通过os.environ获取这些变量:
local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

4. 完整分布式训练初始化示例

import os
import torch
import torch.distributed as dist# 1. 设置环境变量(如PYTHONPATH)
os.environ["PYTHONPATH"] = "."  # 添加当前目录到模块路径# 2. 获取torchrun自动分配的环境变量
local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])# 3. 初始化分布式进程组
dist.init_process_group(backend="nccl",  # GPU通信后端init_method="env://",  # 从环境变量读取MASTER_ADDR/MASTER_PORT
)
torch.cuda.set_device(local_rank)  # 绑定当前进程到指定GPU# 4. 构建模型与数据加载器
model = build_model().cuda()
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=64, sampler=train_sampler)# 5. 训练逻辑(确保仅rank0保存模型)
for epoch in range(epochs):for batch in dataloader:outputs = model(batch)...if global_rank == 0:  # 仅主进程保存torch.save(model.module.state_dict(), "model.pth")

5. 注意事项

  • 作用域问题os.environ设置的变量仅对当前进程有效。若需全局生效,需在命令行或操作系统中设置。
  • 多节点训练:若涉及多机(如2台x8 GPU),需额外指定--nnodes=2 --node_rank={0,1} --master_addr=<主节点IP>
  • 替代方案:使用.env文件配合python-dotenv管理环境变量(适合复杂配置):
    from dotenv import load_dotenv
    load_dotenv(".env")  # 加载.env文件中的变量
    

通过以上步骤,可在脚本内动态设置环境变量并适配torchrun的分布式训练流程,确保多进程协同工作的正确性。

http://www.dtcms.com/a/330200.html

相关文章:

  • C语言队列的实现
  • AUTOSAR进阶图解==>AUTOSAR_SWS_TTCANInterface
  • 开发避坑指南(25):MySQL不支持带有limit语句的子查询的解决方案
  • 【学习嵌入式day23-Linux编程-文件IO】
  • imx6ull-驱动开发篇22——Linux 时间管理和内核定时器
  • 力扣top100(day02-04)--二叉树 01
  • 18.10 SQuAD数据集实战:5步高效获取与预处理,BERT微调避坑指南
  • 数据分析可视化学习总结(美妆2)
  • Python解包技巧全解析
  • Python 基础语法(一)
  • 多处理器技术:并行计算的基石与架构演进
  • 疯狂星期四文案网第38天运营日记
  • 继《念念有词》后又一作品《双刃》开播 马来西亚新人演员业文Kevin挑战多面角色引期待
  • CF每日3题(1600)
  • element-ui 时间线(timeLine)内容分成左右两侧
  • npm run dev 的作用
  • Unity_2D动画
  • 游戏盾的安全作用
  • RK3568嵌入式音视频硬件编解码4K 60帧 rkmpp FFmpeg7.1 音视频开发
  • Celery+RabbitMQ+Redis
  • Traceroute命令使用大全:从原理到实战技巧
  • IPC Inter-Process Communication(进程间通信)
  • 2小时构建生产级AI项目:基于ViT的图像分类流水线(含数据清洗→模型解释→云API)(第十七章)
  • 基于Supervision工具库与YOLOv8模型的高效计算机视觉任务处理与实践
  • 1.Cursor快速入门与配置
  • Multisim的使用记录
  • GQA:从多头检查点训练广义多查询Transformer模型
  • 蒙以CourseMaker里面的录屏功能真的是完全免费的吗?
  • C#标签批量打印程序开发
  • Redis 键扫描优化:从 KEYS 到 SCAN 的优雅升级