当前位置: 首页 > news >正文

Megatron-LM(模型并行)

Megatron-LM: Training Multi-Billion Parameter Language Models Using
Model Parallelism

1. 技术设计原则

Megatron-LM 提出轻量级层内模型并行,无需定制编译器或修改框架,仅通过在 PyTorch 原生代码中插入少量通信操作(如all-reduce)实现,且与流水线模型并行正交互补,可灵活组合。

2.背景:矩阵分块计算

参考:https://www.bilibili.com/video/BV1HdXtY9EuF/?share_source=copy_web&vd_source=0f3d85b09673431159069a2a9a3da50c

在这里插入图片描述
矩阵XXXYYY相乘,即计算XYXYXY,有两种分块运算方式:

  1. YYY拆分为[Y1,Y2][Y_1,Y_2][Y1,Y2]XXX不变
    XmnYnk=Xmn[Y1nk2,Y2nk2]=[XY1,XY2]X^{mn}Y^{nk}=X^{mn}[Y_1^{n\frac{k}{2}},Y_2^{n\frac{k}{2}}]=[XY_1,XY_2]XmnYnk=Xmn[Y1n2k,Y2n2k]=[XY1,XY2]
  2. YYY拆分为 [Y1Y2]\begin{bmatrix} Y_1 \\ Y_2 \end{bmatrix}[Y1Y2],把XXX拆分为[X1,X2][X_1,X_2][X1,X2]
    XmnYnk=[X1mn2,X2mn2]×[Y1n2kY2n2k]=(X1Y1)mk+(X2Y2)mkX^{mn}Y^{nk}= \begin{bmatrix}X_1^{m \frac{n}{2}},X_2^{m \frac{n}{2}} \end{bmatrix} \times \begin{bmatrix} Y_1^{\frac{n}{2} k} \\ Y_2^{\frac{n}{2} k} \end{bmatrix} = (X_1Y_1)^{mk} + (X_2Y_2)^{m k}XmnYnk=[X1m2n,X2m2n]×[Y12nkY22nk]=(X1Y1)mk+(X2Y2)mk

3. 关键模块并行化实现

这一部分的图解为作者自己根据理解画的,如果有错误请指正

(1)前馈网络层(MLP)

公式:FFN(X)=σ(XA)BFFN(X)=\sigma(XA)BFFN(X)=σ(XA)B
设序列长度为lll,隐藏层维度为ddd,前馈网络的隐藏层维度为dFFNd_{FFN}dFFNA∈Rd×dFFN,B∈RdFFN×dA \in \mathbb{R}^{d \times d_{FFN}}, B \in \mathbb{R}^{d_{FFN} \times d}ARd×dFFN,BRdFFN×d
张量并行-前馈网络层

权重矩阵拆分策略:
第一层线性层的权重矩阵按列拆分(A=[A1,A2]A=[A_1,A_2]A=[A1,A2]

  • 使 GeLU 非线性激活可在各 GPU 上独立计算,避免中间同步;
  • 即使得GeLU(XA)=[GeLU(XA1),⋯,GeLU(XAn)]\text{GeLU}(XA)=[ \text{GeLU}(XA_1),\cdots ,\text{GeLU}(XA_n)]GeLU(XA)=[GeLU(XA1),,GeLU(XAn)]

第二层线性层的权重矩阵按行拆分,直接接收 GeLU 输出

  • 在前向传播时仅需对第二层的输出做一次 All-Reduce 聚合输出。
  • 在反向传播时仅需在返回到输入时做一次 All-Reduce 聚合梯度。

通信优化:整个 MLP 模块仅需 2 次 all-reduce 操作(前向1次、反向1次),无额外同步点。

(2)多头注意力(Multi-Head Attention)模块

在这里插入图片描述
注意力头拆分:

  • 将 Q、K、V 对应的权重矩阵按拆分,每个 GPU 负责部分注意力头的计算,无需中间通信;
    注意力输出层权重按拆分,直接接收并行计算结果,仅需在反向传播聚合梯度。

优势:充分利用注意力头的天然并行性,每个 GPU 仅处理部分头的计算,降低单设备内存压力。

(3)输入层与输出层优化

张量并行-输入层
输入嵌入:

  • 按词汇表维度列拆分嵌入矩阵(E=[E1,E2]E=[E_1,E_2]E=[E1,E2]
  • 通过 g 算子(前向 all-reduce)聚合结果,避免单 GPU 存储完整词汇表。

在这里插入图片描述

输出层与损失计算:

  • 融合最终线性层的输出与交叉熵损失计算,直接在各 GPU 上计算局部损失后聚合
  • 无需传输大规模 logits,减少通信量从b×s×vb×s×vb×s×vb×sb×sb×s
  • bbb 为批次大小、sss 为序列长度、vvv 为词汇表大小

4.混合并行策略(模型+数据并行)

混合并行策略
GPU分组:

  • 将 GPU 划分为模型并行组(如8个GPU一组,共同承载一个模型)和数据并行组(不同模型并行组中同位置 GPU 组成,负责梯度同步)
  • 总 GPU 数 = 模型并行度 × 数据并行度
http://www.dtcms.com/a/359287.html

相关文章:

  • 【ACP】2025-最新-疑难题解析- 练习二汇总
  • STFT和梅尔频谱图
  • 项目管理的关键成功因素
  • 119、【OS】【Nuttx】【周边】效果呈现方案解析:变量展开
  • 【从零开始java学习|第十篇】面向对象
  • 【Blender】二次元人物制作【一】:二次元角色头部建模
  • Gray Code (格雷码)
  • 2025.8.30项目二基于UDP的TFTP文件传输
  • 【ICO】快速制作ICON教材/使用icofx3快速制作ico
  • 【多项式】快速沃尔什变换 (FWT)
  • 复现 RoboDK 机器人校准功能(以Staubli TX2‑90L / TX200机械臂为测试对象)
  • 关于铭飞平台企业官网模板使用中常到的问题、企业官网的百度认证以及IDEA编辑启动器的快捷方法/Apipost本地和云端没法同步的问题解决
  • 如何改变传统教育的消费习惯-第三代结束-第四代开启
  • 数值分析——数据误差对函数值的影响
  • 数据治理进阶——26页如何进行数据治理【附全文阅读】
  • 项目管理方法论有哪些流派
  • TuringComplete游戏攻略(一、基础逻辑电路)
  • Python(五)Python_C API详细
  • 嵌入式Linux输入子系统驱动开发
  • [光学原理与应用-332]:ZEMAX - 序列模式与非序列模式的本质、比较
  • FPGA 实现FOC 无刷电机控制器
  • 电子健康记录风险评分与多基因风险评分的互补性与跨系统推广性研究
  • 洛谷 P1395 会议 -普及/提高-
  • 吴恩达机器学习(四)
  • 10. 函数和匿名函数(二)
  • 深入理解 shared_ptr 与 weak_ptr:访问控制与线程安全
  • 广东省省考备考(第九十天8.30)——判断推理(第十节课)
  • Java多线程初阶
  • C++讲解---如何设计一个类
  • 防火墙技术(三):状态检测和会话机制