当前位置: 首页 > news >正文

分布式机器学习之流水线并行GPipe:借助数据并行来实现模型并行计算

分布式机器学习之流水线并行GPipe:借助数据并行来实现模型并行计算

本文开始介绍流水线并行(Pipeline Parallel),流水线并行可以算是模型并行的一种,不同于之前的张量并行对一层内的张量权重进行切分,主要适用于 Transformer 类模型,流水线并行是在整个模型所有层之间进行切分,可以适用于任意模型。但是由于需要切分多组尽可能负载均衡的模型层,流水线并行最适合由同样的层堆积起来的序列式模型。比如 LLM,是由多个相同的 decoder layer 堆积而成,就非常适合用 PP。对于结构比较复杂的,非对称类型的模型,如果要用 PP,需要仔细地分析测试,找到负载最均衡的切分方式。本文首先介绍一篇流水线并行中比较简单、常用的方案:GPipe。

切分 micro-batch 来实现模型并行计算

当模型尺寸越来越大,单个 GPU 的显存不足以放下整个模型,我们考虑将模型拆分到多个 GPU 上,除了之前介绍的张量并行在层内切分权重矩阵之外,我们还可以将模型按层切分成多组,每一组分别放到不同的 GPU 上进行计算。

我们知道,深度学习一般使用 BP 算法来计算梯度、更新参数,整个训练过程分为前向计算和反向传播两个阶段,前向计算是从模型输入到输出,然后计算 loss,反向传播则是从 loss 到输出再到各层参数计算梯度,两个阶段的计算都有一定的依赖关系,比如模型第 iii 层反向传播的计算依赖于第 i+1i+1i+1 层的计算结果。当我们按层切分模型,这种依赖关系会对整个系统的计算调度和并行性造成一些挑战。如何更好地进行计算调度,也正是一系列流水线并行研究工作的核心。

下图展示了按层切分的方式以及前向反向计算的依赖关系。

在这里插入图片描述

由于计算依赖关系的存在,在实际进行训练时,前向过程中每一层的计算都必须等待前一层的计算完成,反向传播过程也是如此(只是依赖关系反了过来)。如下图所示,这就是对模型进行按层切分之后,实际计算的过程。

我们发现,按照下图的做法,在每一时刻,都只有一个 GPU 在进行计算,而其他 GPU 都处于闲置状态,也就是说,整个系统多个 GPU 间实际上完全没有进行并行计算(除了参数更新阶段)。只是利用了多 GPU 的显存来放置更大的模型,而没有用到多 GPU 的算力来加速计算。这种低效的调度策略,在图中直观地体现为 GPU bubble(空白部分)比较多,甚至反而 GPU 越多,每个 GPU 的计算利用率就越低。

在这里插入图片描述

如何设计调度策略,才能减少 GPU bubble,提升整体的计算效率呢?我们提到,之所以在将模型不同层切分到不同 GPU 上之后仍需进行顺序的计算,是因为前向计算和反向传播各自有一定的依赖关系,所以每一个 GPU 上的计算必须等到它依赖的前一个 GPU 计算完成,才能开始计算。如何才能打破这种依赖关系的限制呢?

GPipe 的方案是:将数据切分为更细粒度的 micro batch(微批次),从数据并行的角度解耦 GPU 计算任务的依赖关系,从而实现同一时刻多 GPU 的并行计算。具体来说,GPipe 将每一个数据批次切分为多个微批次(F0→[F0,0,F0,1,F0,2,F0,3]F_{0}\rightarrow [F_{0,0},F_{0,1},F_{0,2},F_{0,3}]F0[F0,0,F0,1,F0,2,F0,3]),这样,设备 1 在计算完 F0,0F_{0,0}F0,0 之后,F1,0F_{1,0}F1,0 计算所依赖的结果就已经有了,设备 2 就可以开始计算 F1,0F_{1,0}F1,0,而同时设备 1 可以继续计算下一个微批量 F0,1F_{0,1}F0,1,这样就实现了多 GPU 的并行计算。反向计算也是类似的。

模型不同层之间存在依赖关系,但是(同一批次内)不同数据之间的计算任务是没有依赖关系的。GPipe 正是观察到这一点,并通过更细粒度的数据批次划分,解耦计算任务的依赖关系,实现了多 GPU 并行计算。这样来看,GPipe 其实严格来说是一种 “模型并行+数据并行” 的并行计算方案。这种思路类似于 CPU 的流水线指令执行,因此这里称为流水线并行。

在这里插入图片描述

re-materialization

对于显存占用的优化,GPipe 采用了 re-materialization 技术,有时也称为重计算或梯度检查点技术。具体来说,就是在前向计算过程中,丢弃一部分中间计算的结果,节省一些显存,在反向传播需要这些中间结果计算梯度时,重新计算出来,是一种时间换空间的思想。要注意,这项技术并不是 Pipeline Parallel 特有的,在一般的模型训练中我们也经常用它来降低峰值显存占用。最早提出应该是 tianqi 的工作。接下来具体介绍下梯度检查点技术(图示非原创,忘记哪里找的了 😦 )。

模型训练时,机器学习框架会将前向过程中的激活值保存下来,用于反向传播时模型参数梯度的计算。一般情况下,我们会保存过程中的每一个激活值,这样计算梯度的速度是最快的。即如下图所示,图中第一行圆圈表示激活值,第二行表示梯度,箭头表示数值计算的依赖关系,圆圈被涂成紫色表示将其存在显存中。

在这里插入图片描述

在常规情况下,这没什么问题,是训练效率最高的方式。然而,当模型或者 batch size 比较大,显存捉襟见肘时,我们就希望在前向过程中将一些暂时不会用到激活值丢弃掉,从而节省一定的显存。

如下图所示,可以看到,最后一个节点梯度的计算仅依赖与输出值和最后一个激活值,之前过程中的激活值暂时用不到,可以在前向过程中丢弃掉。当计算倒数第二个节点的梯度时,依赖倒数第一个节点的梯度和倒数第二个节点的激活值,这时由于没有保存该激活值,我们需要重新执行一遍前向过程,计算该值。以此类推,虽然节省了大量的显存空间,但是计算每个节点梯度时我们都需要重新将其之前的网络整个前向一遍,带来的额外计算开销太大。

在这里插入图片描述

我们进一步希望既能节省一定的显存,又能保持耗时不要增加太多,有什么办法呢?这就该梯度检查点(Gradient Checkpointing)登场了。具体来说,我们按照一定策略保存前向过程中一些节点处的激活值,以便反向传播中重新计算某节点的激活值时不用从头开始,而可以从离他最近的有保存节点开始,从而减小耗时的增加。同时对于其他大部分节点的激活值,我们还是丢掉,以达到节省显存的目的。这个过程如下图所示。这样,整体上我们就能在时间开销和空间开销之间达到平衡。关于保存检查点的策略,最佳选择是将 n\sqrt{n}n 个节点标记为 checkpoint,这样每个节点最多重新计算一次,整体上相当于只多了一次前向传播。

在这里插入图片描述

模型训练过程中,中间激活占的显存大概是 O(N×L×D)\mathcal{O}(N\times L\times D)O(N×L×D),其中 N,L,DN,L,DN,L,D 分别表示 batch size、模型层数和模型维度数。因此,在 batch size 或模型很大的时候,尤其是 batch size 大的时候,节省的显存量是很可观的。

在 GPipe 中,正好适合将多个 GPU 边界处的中间结果作为检查点。

总结

流水线并行和张量并行都是模型并行的方法。其中张量并行是对层内的权重张量进行切分,主要针对 Transformer 类模型设计,流水线并行适用于所有模型,但需要仔细设计切分方案,尽量保证负载均衡。

GPipe 是流水线并行的经典工作,它通过数据并行解耦计算任务的依赖关系,实现了多 GPU 的并行计算。并通过梯度检查点技术来降低峰值显存占用。

http://www.dtcms.com/a/338590.html

相关文章:

  • JVM之Java内存区域与内存溢出异常
  • 微服务-06.微服务拆分-拆分原则
  • 117. 软件构建,拓扑排序,47. 参加科学大会,dijkstra算法
  • webpack》》Plugin 原理
  • VSCode 从安装到精通:下载安装与快捷键全指南
  • 视觉采集模块的用法
  • 企业知识管理革命:RAG系统在大型组织中的落地实践
  • 大数据数据库 —— 初见loTDB
  • 最新研究进展:2023-2025年神经机器翻译突破性成果
  • 【无标题】基于大数据+Python的共享单车骑行数据分析关系可视化 基于Spark+Hadoop的共享单车使用情况监测与数据可视化
  • AI 药物发现:化学分子到机器学习数值特征的转化——打通“化学空间”与“模型空间”关键路径
  • 大语言模型基本架构
  • 全网首发CentOS 7.6安装openGauss 6.0.2 LTS企业版(单机)
  • Linux------《零基础到联网:CentOS 7 在 VMware Workstation 中的全流程安装与 NAT 网络配置实战》
  • vue3实现实现手机/PC端录音:recorder-core
  • Apache IoTDB(4):深度解析时序数据库 IoTDB 在Kubernetes 集群中的部署与实践指南
  • Chrome原生工具网页长截图方法
  • 实现Johnson SU分布的参数计算和优化过程
  • STM32 vscode 环境, 官方插件
  • 进程通信:进程池的实现
  • JUC之CompletableFuture【上】
  • PythonDay31
  • 力扣(电话号码的字母组合)
  • 如何安全删除GitHub中的敏感文件?git-filter-repo操作全解析
  • STM32 定时器(主从模式实现 3路PWM相位差)
  • c#联合halcon的基础教程(案例:亮度计算、角度计算和缺陷检测)(含halcon代码)
  • 运维监控prometheus+grafana
  • 深入理解Java中的四类引用:强、软、弱、虚引用
  • 【科研绘图系列】R语言绘制多组火山图
  • 第六天~提取Arxml中CAN Node节点信息Creat_ECU