当前位置: 首页 > news >正文

二、TorchRec中的分片

TorchRec中的分片


文章目录

  • TorchRec中的分片
  • 前言
  • 一、Planner
  • 二、EmbeddingTable 的分片
    • TorchRec 中所有可用的分片类型列表
  • 三、使用 TorchRec 分片模块进行分布式训练
    • TorchRec 在三个主要阶段处理此问题
  • 四、DistributedModelParallel(分布式模型并行)
  • 总结


前言

  • 我们来了解TorchRec架构中是如何分片的

一、Planner

  • TorchRec planner 帮助确定模型的最佳分片配置。

  • 它评估嵌入表分片的多种可能性,并优化性能。

  • planner 执行以下操作:

    • 评估硬件的内存约束。
    • 根据内存获取(例如嵌入查找)估算计算需求。
    • 解决特定于数据​​的因素。
    • 考虑其他硬件细节,例如带宽,以生成最佳分片计划。

二、EmbeddingTable 的分片

  • TorchRec sharder 为各种用例提供了多种分片策略,我们概述了一些分片策略及其工作原理,以及它们的优点和局限性。通常,我们建议使用 TorchRec planner 为您生成分片计划,因为它将为模型中的每个嵌入表找到最佳分片策略。
  • 每个分片策略都确定如何进行表拆分、是否应拆分表以及如何拆分、是否保留某些表的一个或几个副本等等。分片结果中的每个表片段,无论是一个嵌入表还是其中的一部分,都称为分片。
  • 可视化 TorchRec 中提供的不同分片方案下表分片的放置

不同分片方案下表分片的放置

TorchRec 中所有可用的分片类型列表

  • 表式 (TW):顾名思义,嵌入表作为一个整体保留并放置在一个 rank 上。
  • 列式 (CW):表沿 emb_dim 维度拆分,例如,emb_dim=256 拆分为 4 个分片:[64, 64, 64, 64]。
  • 行式 (RW):表沿 hash_size 维度拆分,通常在所有 rank 之间均匀拆分。
  • 表式-行式 (TWRW):表放置在一个主机上,在该主机上的 rank 之间进行行式拆分。
  • 网格分片 (GS):表是 CW 分片的,每个 CW 分片都以 TWRW 方式放置在主机上。
  • 数据并行 (DP):每个 rank 保留表的副本。

分片后,模块将转换为它们自身的分片版本,在 TorchRec 中称为 ShardedEmbeddingCollectionShardedEmbeddingBagCollection。这些模块处理输入数据的通信、嵌入查找和梯度。

三、使用 TorchRec 分片模块进行分布式训练

  • 有许多可用的分片策略,我们如何确定使用哪一个?
    • 每种分片方案都有相关的成本,这与模型大小和 GPU 数量相结合,决定了哪种分片策略最适合模型。
  • 在没有分片的情况下,每个 GPU 保留嵌入表的副本 (DP),主要成本是计算,其中每个 GPU 在前向传递中查找其内存中的嵌入向量,并在后向传递中更新梯度。
  • 使用分片时,会增加通信成本:
    • 每个 GPU 都需要向其他 GPU 请求嵌入向量查找,并通信计算出的梯度。这通常被称为 all2all 通信。
    • 在 TorchRec 中,对于给定 GPU 上的输入数据,我们确定数据每个部分的嵌入分片所在的位置,并将其发送到目标 GPU。
    • 然后,目标 GPU 将嵌入向量返回给原始 GPU。在后向传递中,梯度被发送回目标 GPU,并且分片会通过优化器进行相应的更新。
  • 如上所述,分片需要我们通信输入数据和嵌入查找。

TorchRec 在三个主要阶段处理此问题

我们将此称为分片嵌入模块前向传递,该传递用于 TorchRec 模型的训练和推理

  • 特征 All to All / 输入分布 (input_dist)

    • 将输入数据(以 KeyedJaggedTensor 的形式)通信到包含相关嵌入表分片的适当设备
  • 嵌入查找

    • 使用特征 all to all 交换后形成的新输入数据查找嵌入
  • 嵌入 All to All/输出分布 (output_dist)

    • 将嵌入查找数据通信回请求它的适当设备(根据设备接收到的输入数据)
  • 后向传递执行相同的操作,但顺序相反。

四、DistributedModelParallel(分布式模型并行)

  • 以上所有内容最终汇集成 TorchRec 用于分片和集成计划的主要入口点。
  • 在高层次上,DistributedModelParallel 执行以下操作:
    • 通过设置进程组和分配设备类型来初始化环境。
    • 如果没有提供 sharder,则使用默认的 sharder,默认 sharder 包括 EmbeddingBagCollectionSharder
    • 接收提供的分片计划,如果未提供,则生成一个。
    • 创建模块的分片版本,并用它们替换原始模块,例如,将 EmbeddingCollection 转换为 ShardedEmbeddingCollection
    • 默认情况下,使用 DistributedDataParallel 包装 DistributedModelParallel,使模块既是模型并行又是数据并行。

总结

  • 对TorchRec中的分块策略进行了解。

相关文章:

  • 智能检索知识库​
  • 从入门到实战!Vue-router 的深度探索与高效应用
  • 数据结构与算法之ACM Fellow-算法4.3 最小生成树
  • docx文档转为pdf文件响应前端
  • 01-算法打卡-数组-二分查找-leetcode(704)-第一天
  • 两大奇妙的波-机械波-电磁波
  • 3D打印革新制造范式:CASAIM 3D打印解决方案
  • redis的基本使用
  • 大模型day1 - 什么是GPT
  • freecad内部python来源 + pip install 装包
  • 应用安全系列之四十五:日志伪造(Log_Forging)之三
  • DeepSeek实战:如何用AI工具提升销售转化率?
  • newspaper公共库获取每个 URL 对应的新闻内容,并将提取的新闻正文保存到一个文件中
  • 数字集成电路中时延不可综合与时间单位介绍
  • 用实体识别模型提取每一条事实性句子的关键词(实体),并保存到 JSON 文件中
  • JVM不同环境不同参数配置文件覆盖
  • C++中作用域(public,private,protected
  • CSS 过渡与变形:让交互更丝滑
  • STM32中Hz和时间的转换
  • context上下文(一)
  • 做个网站要钱吗/论坛seo教程
  • 我的世界充值网站怎么做/seo平台有哪些
  • 道外网站建设/千锋教育培训机构可靠吗
  • 夏天做哪些网站致富/长沙百度推广运营公司
  • 网站首页图片轮转/外包服务公司
  • 做旅游平台网站找哪家好/搜索引擎调词工具