当前位置: 首页 > news >正文

斯坦福大学 | CS336 | 从零开始构建语言模型 | Spring 2025 | 笔记 | Lecture 5: GPUs

目录

    • 前言
    • 1. Outline and goals
    • 2. Part1: GPUs in depth - how they work and important parts
      • 2.1 Parallel scaling
      • 2.2 How is a GPU different from a CPU?
      • 2.3 Anatomy of a GPU (execution units and memory)
      • 2.4 Execution model and memory model of a GPU
      • 2.5 What about TPUs?
      • 2.6 Strengths of the GPU model
      • 2.7 Recap: GPUs – what are they and how do they work
    • 3. Part2: Making ML workloads fast on a GPU
    • 4. Part2: How do we make GPUs go fast?
      • 4.1 Control divergence (not a memory issue)
      • 4.2 Trick 1: Low precision computation
      • 4.3 Trick2: Operator fusion
      • 4.4 Trick3: recomputation
      • 4.5 Trick (?) 4: Memory coalescing and DRAM
      • 4.6 Trick 5: tiling
      • 4.7 Recap of part 2: making ML workloads go fast
    • 5. Part 3: Using what we know to understand Flash Attention
    • 6. Recap for the whole lecture
    • 结语
    • 参考

前言

学习斯坦福的 CS336 课程,本篇文章记录课程第五讲:GPU,记录下个人学习笔记,仅供自己参考😄

website:https://stanford-cs336.github.io/spring2025

video:https://www.youtube.com/playlist?list=PLoROMvodv4rOY23Y0BoGoBGgQ1zmU_MT_

materials:https://github.com/stanford-cs336/spring2025-lectures

course material:https://github.com/stanford-cs336/spring2025-lectures/tree/main/nonexecutable/5-GPUs.pdf

1. Outline and goals

今天我们要讨论的是 GPU,GPU 是驱动我们语言模型运转的核心引擎,因此正确理解 GPU 至关重要。如果你从未深入研究过支撑模型运行的硬件设备,它们确实会显得神秘莫测,今天的目标就是帮助大家揭开 CUDA 和 GPU 的神秘面纱

在这里插入图片描述

首先要澄清的是,大家不必完全理解上面这张图表,关键问题是:GPU 为何会变慢?它们的性能下降往往以非常隐晦的方式呈现,我们会在课程尾声尝试解读这张图表

随着矩阵乘法运算规模的增大,你可能会遇见计算速度变慢或变快等情况,这些难以预测的波动曲线会让人困惑。GPU 在某些特定数值倍数时运算飞快,而在其他倍数时却异常缓慢,这种现象确实令人费解,我们将尝试解析其中的原理

另一个关键点在于:我们需要掌握如何设计高效算法,大家应该都听说过 FlashAttention,这项技术通过巧妙优化 Transformer 内部的注意力机制使得处理更长上下文成为可能。

在这里插入图片描述

因此,或许你可以尝试开发类似 FlashAttention 这样的新型算法或创新实现方案,要达成这个目标我们需要理解哪些基础运算模块和核心组件呢?

这就是我们今天的两大学习目标。第一个目标是听完讲座后,你能对 GPU 的工作原理驾轻就熟,你需要掌握其运行机制;第二个目标是你要能游刃有余地优化算法中的关键环节,当你设计新架构时,应该能自信地尝试用 CUDA 进行加速

下面有一些优质资源可供大家参考学习:

在这里插入图片描述

首先是 Horace He 的技术博客,他分享了许多有趣的 GPU 冷知识,值得深入研读。举个例子,为什么充满零元素的矩阵乘法会比非零矩阵运算更快?具体原理可以去他的博客一探究竟

在这里插入图片描述

此外 CUDA 开发者论坛 gpu-mode 和谷歌发布的权威《TPU 技术手册》 scaling-book 也强烈推荐,毕竟本次讲座分享虽然力求全面,但终究只是对硬件原理的浅层探讨,因此今天我们将仅聚焦于硬件架构中的非并行计算组件

今天我们将深入剖析 GPU 这类单加速器的核心工作原理,重点解析其关键组件的工作机制。我们也会用极简短的篇幅提及 TPU—因为在某些方面它们的设计理念与 GPU 存在高度相似性,因此在 GPU 处的论述同样适用于 TPU

当我们掌握了 GPU 的硬件架构与执行模型后就能深入剖析其性能特性:理解 GPU 在特定工作负载下实现高速运算的机制以及导致性能瓶颈的根源。我们将系统性地解析其性能表现,最后环节将采用近乎实战演练的方式展开,我们将逐步解析 FlashAttention 的实现原理。

我们将整合所有已掌握的核心原理,带你逐层剖析 FlashAttention 的实现逻辑,你会看到这些知识如何环环相扣,最终形成完整的解决方案,以上就是今天讲座的要点。

2. Part1: GPUs in depth - how they work and important parts

2.1 Parallel scaling

大家大多都学过自然语言处理课程,如今的自然语言处理课程中都会涉及一定程度的 扩展定律(scaling laws) 内容,这里先做一个背景铺垫

在这里插入图片描述

众所周知,更强的算力有助于训练大语言模型,上面这张图表展示的是预训练阶段的扩展定律 [Kaplan+ 2020] 。业界普遍认为,算力越强,数据处理能力就越高,就能消化更多训练数据,从而训练出更庞大的模型。因此不难理解,深度学习确实至关重要,但真正推动性能跃升的,是更快的硬件,更高的利用率以及更优的并行化方案,由此可见,理解硬件原理至关重要

当我们开始思考算力扩展时,自然会追问如何才能实现算力扩展?如何让模型训练得更快?在半导体工艺发展的早期阶段,当我们思考如何提升 CPU 性能时,其速度提升主要遵循 登纳德缩放定律(Dennard scaling),根据摩尔定律,芯片上的晶体管数量每年都会翻倍,这种晶体管数量的倍增最终催生了登纳德缩放效应

随着晶体管尺寸不断缩小,芯片能够在更低功耗下实现更高时钟频率,从而持续提升性能。而这一黄金定律在 1980 至 2000 年间逐渐走向终结

在这里插入图片描述

正如 Hennessy 和 Patterson 这张图表所示,图中蓝点代表单线程性能,其增长曲线已明显趋于平缓。当然,晶体管数量并未真正开始下降,芯片的晶体管密度确实仍在持续攀升,但这已无济于事,单线程的吞吐量并未因此获得提升。因此,这意味着我们无法单纯依靠绝对计算速度的提升,我们必须通过并行扩展来弥补这一局限

深度学习与神经网络的发展轨迹正在追求单线程加速(即单纯提升绝对计算速度)转向并行扩展(通过同时处理大量计算任务实现性能突破)

在这里插入图片描述

这张由 Bill Dally 在其主题演讲中展示的计算性能增长图表令人尤为着迷,它揭示了从早期 K20 到 H100 显卡,每秒整数运算次数呈现出的超指数级增长轨迹,这条增长曲线呈现出惊人的指数级乃至超指数级攀升态势,因此我们必须深入掌握如何驾驭这条增长曲线,才能充分释放语言模型的全部潜力,这将成为我们追求的核心目标

2.2 How is a GPU different from a CPU?

中央处理器(CPU) 是大家开始编程时都会接触的基础组件,它采用的就是这种执行模型:程序运行时,会依次执行指令,在单线程环境下,程序会严格按顺序逐步执行,要支持这种执行模型,需要具备哪些条件呢?首先需要强大的控制单元,关键在于高速执行能力,由于存在大量分支跳转和条件控制逻辑,系统必须确保快速处理这些操作

在这里插入图片描述

如图所示,CPU 会将芯片的很大一部分面积用于 实现大规模控制逻辑和分支预测功能,由于线程数量有限,CPU 必须通过极高的单线程执行速度来弥补这一局限。如今确实出现了配备大量核心的 CPU 架构,但与 GPU 相比,其核心数量仍显微不足道

而 GPU 则截然不同,它配备了数量极其庞大的 计算单元(ALU 矩阵),图中这些绿色小方块代表的就是计算单元,芯片中用于控制逻辑的区域占比则小得多,仅需少量控制逻辑即可调度海量并行运算单元协同合作

从架构设计理念来看,上图清晰展示了 CPU 与 GPU 的核心差异,CPU 强调控制逻辑,而 GPU 侧重并行计算单元(绿色阵列),但若探究其设计目标,二者从根本理念上就存在差异:CPU 为通用计算优化控制流,GPU 为并行计算优化数据吞吐

可以说,CPU 的设计哲学是极致 优化单线程延迟—就像精密的瑞士钟表匠,专注于让每个齿轮(指令)以最短路径高效运转,短时间完成单个任务,如同短跑运动员,每一纳秒的延迟都至关重要。

在这里插入图片描述

如图所示任务 T1 至 T4,CPU 的处理策略是让每个任务像接力赛跑,T1 刚冲过终点,T2 立即起跑,通过指令流水线与分支预测等技术,确保每个任务以最短时延相继完成。因此当需要快速获得单个任务结果时,比如 T1 就能像离弦之箭般瞬间完成,这正是 CPU 通过深度流水线、乱序执行等设计创造的响应速度优势

在 GPU 中,优化的核心目标是 实现高吞吐量,它不执着于单个任务的完成速度,而是追求整体任务集的吞吐效率。就像工厂流水线,虽然每件产品的制造周期不变,但通过并行处理百万件产品,最终实现总产能的最大化。

为支撑这一目标,系统会创建海量线程,这些线程能够以纳秒级速度在休眠与唤醒状态间无缝切换。最终,尽管每个独立任务耗时更长,但 GPU 却能先于 CPU 完成 T1 到 T4 所有工作负载,就像快速分拣中心同时处理百万包裹,虽然单个包裹分拣速度不如专人递送,但整体吞吐量呈碾压性优势,因此它们的设计理念与目标定位存在本质差异

2.3 Anatomy of a GPU (execution units and memory)

GPU 背后至关重要的设计思想在于 它通过大量流式处理器(SM)并行执行运算,你可以把流式处理器想象成一个基本运算单元,在使用 Triton 这类编程工具时,其操作将直接作用于流式多处理器(SM)层级,在每个 SM 内部都包含多个流式处理器(SP),这些流式处理器能够并行执行大量线程

可以这样理解:流式多处理器 SM 内部集成了一组控制逻辑单元,它能自主决定执行内容。例如,它能够执行分支操作,流处理器 SP 的工作原理是对同一指令并行应用于多组不同数据。因此,你可以进行海量的并行计算,在该模型中,每个流式多处理器 SM 都是独立的控制粒度单元,单个流处理器 SP 即可独立完成大量运算

在这里插入图片描述

以 NVIDIA A100 这款 GPU 为例,其内部集成了 128 个 SM,这比大多数 CPU 核心的数量要多得多,这些系统内部都将集成大量流处理器和专用矩阵乘法运算单元,这就是基本的计算模型。

这里有两个关键点需要注意,可以把 GPU 想象成独立的计算机,但事实上,计算能力只是我们需要关注的两大核心要素之一。在当前阶段,内存的重要性甚至可能超越计算本身。从程序在 GPU 上的运行性能来看,内存的关键地位将持续凸显,要真正理解内存机制,就必须掌握 GPU 芯片的物理架构的运行环境下,内存与计算单元之间的物理距离开始产生决定性影响

这就像在百米赛跑中,起跑线位置的微小差异会显著影响最终成绩,接下来我们将展示硬件元件的物理排布方式,这种空间布局直接决定了我们应该如何优化内存访问策略以提升性能,就像城市道路规划会影响交通效率那样

内存单元距离流式多处理器 SM 越近,其访问速度就越快,这就像在图书馆中,伸手可及的书架总比需要走动的区域取书更高效。因此 GPU 内部存在极高速的内存层级,比如 L1 缓存和共享内存,它们直接内建于流式多处理器 SM 内部,就像在赛车引擎舱配置了触手可及的高性能燃料罐,这种内置存储的运行速度堪称极致

像寄存器这类需要频繁读写的数据最适合放在 L1 缓存和共享内存中,这就像把最常用的工具放在工作台伸手可及的位置

在这里插入图片描述

如上图所示,这些绿色区域就是流式多处理器,而 L2 缓存则像环绕在它们周围的环形高速通道,而图中这些蓝色区域则代表位于 GPU 芯片上紧邻流式多处理器的 L2 缓存,就像给每个计算单元都配备了专属的高速数据中转站

没错,它们虽然不在流式多处理器内部,但物理位置上依然靠得很近,就像在计算核心隔壁建了个专属数据仓库,确保能以最短路径快速存取。而这些缓存运算速度仍然相当快,虽然速度比片上缓存慢了约 10 倍,但整体性能依然出色,就像从超跑换成了跑车,虽不是极致速度,但完全能满足高性能计算的需求

而在芯片外部,以 3090 显卡或 PCIe 版本的 A100 为例,这些设备就像给超级计算机装上了外挂引擎,通过 PCIe 通道与主机系统保持高速数据交互

在这里插入图片描述

这是一张 PCIe 接口的 A100 显卡,其中放置着 GPU 芯片,而实际上 DRAM 内存就紧邻其侧,因此数据必须实际离开芯片,通过物理连接进行传输。从这张芯片示意图可以看到,边缘这些黄色接口就是连接器,这些都是高带宽内存(HBM)的连接接口,这些接口连接着位于 GPU 芯片外部的 DRAM 内存芯片。

在这里插入图片描述

从访问速度就能看到这种设计的局限,流式多处理器 SM 的片上内存访问速度极快,仅需约 20 个时钟周期,而访问二级缓存或全局内存则需要 200 到 300 个时钟周期,这种 10 倍速差会造成严重的性能瓶颈

因此,当某个计算任务需要访问全局内存时,SM 很可能会陷入无工作可做的闲置状态,此时所有矩阵乘法运算都已执行完毕,处理器只能进入空转状态,因此硬件利用率将显著下降。这将成为关键主题—从某种意义上说,内存设计才是真正的核心考量

2.4 Execution model and memory model of a GPU

理解 GPU 工作机制的核心要义在作业 2 中,你将实际编写针对 GPU 的高性能代码,因此必须深入理解 GPU 的实际执行模型工作机制,虽然这有些复杂,但并非难以理解,需要从三个粒度层面进行考量,具体分为线程块(blocks)、线程束(warps)以及线程(threads)三个层级,其粒度依次细化

线程块(block) 是线程的大规模集合体,每个线程块会被分配到流式多处理器 SM 上执行,可以将每个 SM 视为独立的工作单元,它是自主运行的独立单元。线程块会被分配到 SM 上进行处理,这就是最基本的处理单元

而在这些线程块内部,则包含大量并行线程,每个线程都对应着需要执行的具体计算任务。这些线程在执行时会以 线程束(warp) 为单位分组运行。这种执行单元就称为线程束(warp)

线程块(block)本质上就是线程的集合体,每次执行时,系统会从线程块中提取连续的 32 个线程作为一组来运行,这种 32 线程的集合就构成了一个线程束(warp)

在这里插入图片描述

如图所示,我们可以直观地看到这个执行过程,上面存在多个线程块,每个线程块会被分配到不同的 SM 上执行,在每个线程块内部,又会包含多个不同的线程束,而每个线程束则由大量线程组成,所有这些线程将在不同数据上同步执行相同的指令,这就是其基本执行模型

目前看来可能还略显晦涩,这些线程块,线程束和线程究竟是什么,它们将深刻影响性能表现并决定我们后续设计 CUDA 内核的方式,希望你能牢记这一点,后续讲解中我们还会适时回顾

以上就是 GPU 的 逻辑执行模型 概览,若能理解这一点,你就掌握了 GPU 执行计算的底层逻辑。GPU 同样具备 逻辑内存模型

在这里插入图片描述

上图展示的并非实体硬件结构,这仅是理解 GPU 编程的思维模型,寄存器是其中的核心组件。这些寄存器速度极快,专用于存储单个数值型数据,其内存架构包含局部内存、共享内存和全局内存三个层级,这种层级结构呈现存储容量递增、访问速度逐级递减的特性

程序代码可直接写入全局内存,此外还能写入常量内存,不过这种操作并不常见,每个线程均可访问自身的寄存器和共享内存,但跨线程块的数据交互必须通过全局内存实现,这一点在实际应用中至关重要。

因此在实际编程中,理想的线程设计应确保每个线程仅处理少量固定数据,因此需要先将这些少量数据加载到共享内存中,所有线程都能高效地访问这块共享内存执行计算,这种设计堪称理想地执行模型。

反之,若某个线程需要频繁访问分散的数据,就不得不调用全局内存,这将导致严重的性能瓶颈,这个核心理念将贯穿我们讨论 GPU 各类运算方式的始终,相信这个概念已经阐述清楚了

以上就是 GPU 架构最核心的四个要点概述

2.5 What about TPUs?

OK,现在插播一个技术支线—TPU,从宏观架构来看,TPU 与 GPU 的设计理念高度相似,虽然你可能永远不会直接操作 TPU,但理解这类替代性加速器的工作原理很有必要,它们在很多方面都与常见加速器极为相似

在这里插入图片描述

上面这张示意图展示了 TPU 的架构组成,这里有个关键组件,那就是所谓的 张量核心(tensor core),从概念上讲,你可以把张量核心类比为流式多处理器 SM,它们的功能定位非常相似。

这些组件本质上都是能够独立处理数据的原子化运算单元,其中包含一个标量单元,本质上就是控制中枢,它还能执行类似 CPU 的通用计算任务。配备的向量单元则专门处理向量运算,当需要对向量进行逐元素运算时,这个单元就是最佳执行者。

芯片还内置了专门用于矩阵乘法的超大专用计算区域,这个专门计算区被称为 MXU(矩阵乘法单元),同时配备了专为向量运算和流式多处理器 SM 优化的高速存储器。这两类存储器在芯片内部(或张量核心内存中)都能实现极速存取

此外,芯片外部还配置了高带宽内存(HBM),相信你已注意到它与流式多处理器 SM 的相似之处,外慢内快的存储层次:外部采用常规内存,而芯片内部集成高速存储器,芯片内部还集成了专用于矩阵乘法运算的硬件加速单元,其核心架构设计理念高度一致

具体差异将在下次的并行计算专题课上详细讲解,各加速器间的互联架构存在细微差异,此外 TPU 没有 warp 概念,只存在 block 级并行,其设计目标就是高效执行矩阵乘法

2.6 Strengths of the GPU model

GPU 之所以能取得如此大的成功,关键在于其卓越的可扩展性,只需增加更多流式多处理器就能轻松获得更强的计算能力,无需通过提高时钟频率来提升性能,自然也就避免了随之而来的散热难题

在这里插入图片描述

从编程角度看,CUDA 虽然令人望而生畏,但由于其编程模型设计合理,实际编写代码并没有想象中那么困难。具体来说,每个 SM 内部的工作机制是这样的:单个线程会对多组不同数据执行相同的指令操作,这种设计理念在概念层面非常容易理解

这种架构尤其适合处理矩阵运算等基础操作场景,这正是其简洁架构的精妙之处。最终,这些线程都极其轻量化,可以随时被暂停或重启,因此当需要等待其他线程执行或是需要终止当前任务并启动新进程时,这些轻量级线程都能轻松应对,这意味着线程几乎不占用额外状态资源,可以随时启停,从而使 GPU 能在每个 SM 内实现极高的利用率

图形处理器(GPU)在其发展初期,这种处理器长期未被应用于科学计算领域,但由于其可编程特性,研究人员最终成功利用早期 NVIDIA GPU 实现了高速矩阵乘法运算

在这里插入图片描述

这篇开创性论文 [Larsen+ 2001] 展示了如何通过纹理缓冲区等图形硬件特性实现快速矩阵乘法,堪称利用显卡进行高性能计算的早期典范,即便缺乏专门的矩阵乘法硬件支持,研究人员仍成功探索出了实现方案

但时至今日,NVIDIA 等企业已深刻认识到矩阵乘法运算的特殊重要性,在深度学习领域,矩阵乘法运算构成了最主要的计算负载,因此从某种意义上说,矩阵乘法堪称是 “天选运算”

在这里插入图片描述

这张图表展示了 NVIDIA 各代 GPU 的每秒万亿次浮点运算性能(TFLOPs),橙色折现代表矩阵乘法运算性能(matmul FLOPs)即执行矩阵乘法时可达到的算力水平,蓝色折线则对应非矩阵乘法运算性能(non-matmul FLOPs),这正是 NVIDIA 开始引入专为矩阵乘法优化的 张量核心(tensor core) 所带来的革命性突破

你会看到矩阵乘法性能与理论峰值性能之间存在巨大差距,因此,如果要设计任何神经网络架构,必须确保主要计算负载由矩阵乘法构成,我们在架构讲座部分也提到过这一点。因为在 GPU 上执行矩阵运算的速度比其他任何操作都要快好几个数量级,如果你设计一个不基于矩阵乘法的神经网络架构,那将会面临极其严重的性能瓶颈

最后想让大家记住一个基本事实,矩阵乘法运算速度极快是首要原则,但同样关键的是要理解 GPU 各组件之间的性能比例关系

在这里插入图片描述

这张图表清晰地展示了 GPU 各组件或者说大语言模型训练系统中各模块的性能扩展速度对比,蓝色折线代表 GPU 与主机之间的数据传输带宽即 GPU 所连接的服务器主机,无论是 PCIe 总线、NVLink 高速互联还是其他先进连接技术,这些技术的性能虽在提升,但增速相对缓慢,这张图表采用标准化处理,以初代互联技术为基准(1x)呈现带宽增长趋势

绿色折线则对应全局内存的访问速度,从 GDDR 显存升级到 HBM2E 高带宽内存,速度实现了质的飞跃,速度快了 100 倍,但这种扩展速度依然较慢

图中灰色线代表算力扩展趋势,这里显示的是浮点运算次数,特指矩阵乘法运算的 FLOPs 值,这反映了算力提升的实际速度,这种增速堪称惊人,速度提升了 1 到 10 万倍

在算力提升的初期阶段,计算性能的瓶颈往往集中在浮点运算能力上,当时的矩阵乘法运算完全受限于浮点运算能力的不足,而如今最先进的 A100 显卡,其图像处理器已实现惊人的运算速度,当前性能瓶颈已转向内存带宽,因为内存性能的提升速度远跟不上计算单元的演进,这种发展趋势在未来仍将持续

动态随机存取存储器(DRAM)的性能提升面临重大技术挑战,这种性能差距将持续扩大。因此,设计硬件友好型算法时,必须将内存优化作为核心考量,我们会反复强调这一点,这是 GPU 架构的核心议题之一

2.7 Recap: GPUs – what are they and how do they work

上面我们为大家介绍了 GPU 的技术要点,简而言之,GPU 本质上是超大规模并行计算架构,它们采用单指令多线程(SIMT)架构实现并行计算,GPU 内部由大量被称为流式多处理器 SM 的计算单元构成,这些单元相当于核心处理器

计算性能与矩阵乘法运算速度呈指数级提升,其增速已远超内存带宽的发展,这一特性正是 GPU 架构设计的核心特征之一,但 GPU 确实配备了高速缓存存储器,并非所有环节都存在性能瓶颈,优化空间始终存在。

GPU 采用分层存储架构,因此某些层级的存储单元具有极高的访问速度,而其他层级的存储则相对较慢,若能巧妙利用这种分层存储结构,我们就能实现极高的运算效率

这就是 GPU 架构的核心要点,只要牢记这些特征,就能清晰理解接下来要讨论的性能的优化要素

3. Part2: Making ML workloads fast on a GPU

接下来我们要做的就是让机器学习任务在 GPU 上跑得飞快,接下来我们将从下面这张图表开始讲解:

在这里插入图片描述

我们的目标之一就是要准确理解这张图表所表述的内容,现在我们要做的是进行方阵的乘法运算,横轴表示参与乘法运算的方阵的维度大小,纵轴则代表每秒执行的计算操作次数

因此,纵轴可以理解为 硬件利用率,随着矩阵规模不断扩大,硬件利用率会持续提升,毕竟更大的计算量能让硬件火力全开,这样就能完全抵销任务调度等操作带来的额外开销,但实际运算过程中会出现各种异常现象

可以看到图中出现了 4 条不同的性能曲线,这些曲线都是呈现出难以预测的波动形态,因此我们需要深入解析这些曲线波动的具体成因

在这里插入图片描述

首先观察这张图表时,你会发现它的走势大致呈现这样的特征,如果你学过计算机体系结构课程,应该能认出这就是典型的 屋顶线模型(roofline model),屋顶线模型的核心观点是:当我们考察吞吐量或利用率时,会发现存在两种典型状态,第一种是内存瓶颈状态,对应上图中左侧曲线部分,第二种是计算瓶颈状态,对应图中右侧曲线部分

可以这样理解,右侧区域表示计算单元已达到满载状态,所有矩阵乘法单元都在持续进行运算,而这条对角线上出现的是典型的内存带宽瓶颈

因此,我们的计算能力受限于运算强度,即每字节数据能承载的浮点运算次数(FLOPs/byte),我们需要避免陷入左侧这种受限于内存带宽的情况,我们更希望处于右侧区域,这样就能充分释放所有计算单元的性能潜力,这本质上就是性能优化的终极目标

理想情况下,这个屋顶模型的走势应该如图所示,对角线部分代表着内存带宽限制下的性能上限,而顶部这条水平线则代表计算单元的理论峰值性能,这算是解开了性能优化的第一环

4. Part2: How do we make GPUs go fast?

事实证明,这个问题远比想象中复杂,简而言之,我们必须确保避免不必要的内存访问,要尽可能减少对低速全局内存的访问次数。但事实证明,要实现这一点,我们需要运用一系列复杂的优化技巧,稍有不慎就会踩中性能陷阱,导致程序运行效率大幅下降

首先要明确的是,性能瓶颈往往不在内存本身,这里简单提一下,这种情况并不常见,我们先把这个因素排除掉。接下来我们将重点讨论真正影响 GPU 性能的五个核心要素:

  • 1. Control divergence (not a memory bottleneck…)
  • 2. Low precision computation
  • 3. Operator fusion
  • 4. Recomputation
  • 5. Coalescing memory
  • 6. Tiling

4.1 Control divergence (not a memory issue)

首先要讨论的是条件分支问题,正如之前提到的,GPU 采用 SIMT(单指令多线程)并行架构,因此,同一个线程束(warp)中的所有线程都会执行相同的指令,但处理的是不同的数据,那么当我们编写下面这样的代码时会发生什么呢?

在这里插入图片描述

这里有一个 if 条件判断语句,当线程索引小于 4 时,执行特定操作,若线程索引大于或等于 4,则执行其他操作。这个简单的条件判断模型在 GPU 上运行时,前四个线程将执行 A 指令,其余四个本该执行 else 分支的线程将被迫暂停,随后这另外四个线程将被激活,执行 X 指令,原先的四个线程将进入休眠状态,整个执行过程就在这些指令间交替切换

为什么会这样呢?原因在于这些不同的线程无法同时执行 A 和 X 指令,正如我们反复强调的,所有线程必须执行相同的指令,因此,单个 warp 内的条件语句可能造成严重的性能损耗,它们会强制暂停所有未执行主控制流的线程

这是我们要讲的唯一与内存无关的要点,道理其实很明显,在大规模并行计算单元中应该尽量避免使用条件判断语句,不过既然已经解决了这个问题,接下来我们要讨论的技巧基本上都与内存优化有关

4.2 Trick 1: Low precision computation

首先要说的是 降低计算精度,这可是个关键技巧,这个技巧应该贯穿整个优化过程。说到这里,让我们回顾一下 Bill Dally 的这张图表,其中暗藏玄机

在这里插入图片描述

这张图表看起来相当漂亮,因为各项数值节节攀升,但若细究这些年来推动 GPU 发展的真正动力,你会发现核心因素其实是数值表示方式的演进,从 FP32 到 FP16 再到 INT8,精度不断演进,仅通过不断降低 GPU 运算精度,就能实现数量级的性能提升

让我们来解释一下这为何如此关键,当计算数据和权重等元素的位数减少时,需要传输的比特量就会大幅降低,因此即便从全局内存中读取这些数据,传输瓶颈的影响也会显著减弱

举个简单的例子,我们以简单的逐元素运算为例分析其算术强度,这里用 ReLU 函数来演示,即 x = max ⁡ ( 0 , x ) x=\max (0,x) x=max(0,x)。现在对一个长度为 n n n 的向量执行该运算,假设我们直接使用 32 位浮点数(FP32)进行处理,那么需要进行多少次内存访问呢?

首先需要读取输入向量 x x x,然后需要写入 x < 0 x<0 x<0 的判断结果,这些操作都使用 32 位浮点格式,这样总共涉及约 8 字节的数据量,总共执行了多少次运算?这里需要执行一次 x < 0 x<0 x<0 的比较运算,执行了一次浮点运算(FLOP)

这意味着每执行一次浮点运算需要处理 8 字节数据,若改用 16 位浮点运算,虽然计算强度(FLOPs intensity)保持不变,但内存访问量会发生变化,现在每浮点运算(FLOP)对应 4 字节数据量。从某种意义上说,在能够使用 16 位浮点运算的前提下,相当于免费获得了双倍内存带宽,这正是许多系统设计的核心考量

本次作业的重点是:大家需要尝试混合精度或低精度训练各种方案,关键在于并非网络结构和训练算法的所有环节都适合采用低精度计算,以矩阵乘法为例说明:

在这里插入图片描述

在混合精度矩阵乘法运算中,通常将输入数据设为 16 位精度,这些属于低精度计算范畴,而实际矩阵乘法则需保持完整的 32 位精度,这样做非常有必要,因为在累加部分和(partial sums)的中间计算过程中,保持高精度至关重要,因此需要使用 FP32 累加器来完成这一过程,随后张量核心会输出 FP32 结果

在这里插入图片描述

你可以根据需要将其降级转换回 16 位精度,我们的输入数据采用 16 位精度,但像累加这类运算,我们最好使用 32 位精度来完成。优化手段多种多样,部分运算可采用 16 位存储格式,而某些运算则需要更高精度,因此需要根据需求选择保持 FP32 或 FP16 精度

对于需要更大数值范围的运算(如指数函数),则需采用更高精度的格式,若动态范围不足,可能导致数值爆炸或归零问题,因此这类运算适合采用 bf16 格式,要确保模型在低精度训练时保持稳定需要极其精细的工程实现,但若能实现低精度训练,效果将非常显著,当内存成为瓶颈时,从 32 位降至 16 位可使吞吐量直接翻倍

4.3 Trick2: Operator fusion

另一个关键点,也是很多人提到要编写 CUDA 内核时真正考虑的,在于 操作融合(Operator fusion),这个概念既直观又有趣,是很容易想到的优化方向,理解 GPU 和内存工作机制的一个形象比喻,可以参考 horace 提出的工厂流水线示意图:

在这里插入图片描述

想象你有一座工厂,这座工厂就相当于计算单元,它接收小方块零件作为输入,然后输出小三角零件,如果你只增加计算单元(工厂),但连接内存与计算单元的传送带(带宽)有限,那么第二座工厂就根本无法投入使用,整个系统的性能仍然受限于数据从内存传输到计算单元的速度,这就形成了性能瓶颈

当然,这个道理你可能早就明白,虽然我们一直强调瓶颈问题,但有个更隐蔽的性能陷阱往往被忽视,这种模式在不知不觉中产生大量计算开销,那就是左值计算模式

在这里插入图片描述

设想这张图的左侧代表内存位置,而右侧则是你的计算单元,计算过程从一个方块开始,这些方块数据需要从内存传输到计算单元,执行相应运算操作,将方块转换为三角形态,完成转换的三角数据被写回内存。

此时发现仍需调用这些三角数据,于是再次将数据传回计算单元,三角数据随即被转换为圆形数据,如此循环往复,数据在计算单元与内存间往复传输,循环不已,这或许可称为最原始的实现方式,如左侧图所示

若直接在 GPU 上采用这种原始计算方式,并将结果直接写回全局内存,最终会得到这样的效果。细数数据往返传输的次数,这种做法的效率实在堪忧,这会导致巨大的内存开销。现在如果你观察右侧图时应该能意识到,左侧图这种计算方式真的可行吗?

由于不存在数据依赖关系,计算流程完全可以从方形到三角再到圆形,最后输出矩阵结果。整个计算过程都可以完全保留在运算单元内部完成,这就是右侧示意图所示的情况,这正是融合内核的思维模型

数据会依次经历一系列连续的操作处理,与其将数据写回存储单元,我们的策略是尽可能在同一个计算单元内完成所有运算,直到必须将结果传回内存时才进行数据转移,这就是内核融合的核心思想。

举个简单的例子:如果直接编写未经优化的代码,可能会产生一系列低效的核函数调用,比如编写一个简单的神经网络模块时遇到的情况:

在这里插入图片描述

假设我们编写了一个神经网络模块,输入 x x x 后能同时输出 sin ⁡ 2 ( x ) + cos ⁡ 2 ( x ) \sin ^2(x) + \cos ^2(x) sin2(x)+cos2(x),代码很简单,如上所示,运行这段代码时,PyTorch 生成的计算图结构大致会是右侧图所示的样子

这样就会触发一连串的 CUDA 核函数调用,首先会启动一个 CUDA 核函数来计算 sin ⁡ ( x ) \sin (x) sin(x),接着会依次启动核函数计算 cos ⁡ ( x ) \cos (x) cos(x) sin ⁡ 2 ( x ) \sin ^2(x) sin2(x) cos ⁡ 2 ( x ) \cos ^2(x) cos2(x),最后计算 sin ⁡ 2 ( x ) + cos ⁡ 2 ( x ) \sin ^2(x) + \cos ^2(x) sin2(x)+cos2(x)。整个计算过程需要在 GPU 和主机内存之间频繁交换数据,这正是我们之前展示的左侧示意图所描述的场景

但如果你更聪明一些,无论是自己编写 CUDA 核函数还是使用类似 torch compile 这样的工具,就会立刻意识到这五个运算其实存在优化空间,这些运算仅需占用极少量的显存,因此可以将它们融合成单一运算

在这里插入图片描述

在 GPU 单线程内完成所有计算,完全无需回写全局显存,这类简单的运算融合完全可以通过编译器自动实现,正如我们刚才提到的 torch compile 工具。如果你尚未采用这种优化方式,现在正是最佳时机,强烈建议全面考虑在所有场景下使用 torch compile,我们也会在作用中展示 torch compile 的使用方式,这个功能相当实用

4.4 Trick3: recomputation

OK,关于计算精度和操作融合的部分我们已经讲完了,下面我们来看另一种优化手段,它叫做 重计算,其核心理念是通过增加计算量来减少内存访问需求。还记得最初的反向传播的内容吗?这个知识点其实源自 CS221 课程

在这里插入图片描述

我们从最底层的输入开始处理,也就是图中黄色的部分,然后将激活值向上传播,这些同样是树状结构中的黄色数值,接着我们反向计算雅可比矩阵,这些就是边上的绿色数值。接下来,我们将通过反向传播来计算梯度进行乘法运算

因此,我们将雅可比矩阵与激活值相乘,反向传播梯度,仔细想想,前向传播生成的这些黄色数值必须被保留下来,这些数值随后会被存储起来,随后需要将这些数据从全局内存中取出(即之前存储的位置)并加载到计算单元中。

从实现机制上来说,这个过程必不可少,但这可能导致海量的内存读写操作,实际上,我们或许能设法规避这个问题,让我们通过一个具体案例来说明如何通过计算优化来提升性能,下面再举个简单的函数示例:

在这里插入图片描述

我们将三个 sigmoid 函数进行堆叠处理,左侧示意图是前向计算图,这正符合三个 sigmoid 函数逐层堆叠的思维模型。现在来看这个计算图:我们将计算这些 sigmoid 函数的值,并存储中间激活值 s1 和 s2(即各层 sigmoid 的输出结果),最终得到输出结果,至此完成前向传播过程

然而,这个反向传播过程实在不太理想,在构建反向计算图时,需要调取存储的 s1 和 s2 激活值,将输出端回传的梯度值接入,并推进这个反向计算流程,最终就能求得输入 x 的梯度。

因此完成反向传播需要执行三次内存读取和一次内存写入操作,而前向传播仅需对输入 x 执行一次内存读取,同时需要为中间变量 s1、s2 和最终输出 out 执行三次内存写入,这样的内存读写操作量着实不小,总共需要执行八次这样的内存操作

由于完全没有矩阵乘法运算,当前算法的算术强度非常低,重计算的核心思路就是:我们根本不需要存储这些中间激活值,我们完全可以避免将这些数据写入内存,我们可以在反向传播过程中实时重新计算这些中间结果

在这里插入图片描述

因此在新的前向传播过程中,我们不再存储 s1 和 s2 这两个中间结果,输入 x 经过 sigmoid 函数计算后直接得到输出结果。这样只需要对 x 进行一次内存读取,对输出结果进行一次内存写入,现在进行反向传播时,我们就不再需要保存激活值了

在反向传播过程中,我们将同时接收来自上层的梯度信号 dout 以及前向传播时的输入 x,因此需要进行两次内存读取操作,随后在流式多处理器 SM 的本地内存中实时计算各个 sigmoid 函数的梯度,并将计算结果写入反向计算图

我们将在本地内存中实时重新计算 s1、s2 和输出值,由于采用这种计算方式,此出完全避免了全局内存读取操作,最后只需执行一次内存写入操作来存储 dx

现在对比两种方案,在完成完全相同的计算任务时,当前方案仅需 5/8 的内存访问量,我们为此付出的代价是需要重新计算这三个 sigmoid 函数,但若原本就因内存带宽受限而处于闲置状态,这种取舍就非常划算,用过剩的计算资源换取紧缺的内存带宽,简直是稳赚不赔的买卖,这就是资源置换的经典策略—用富余资源换取紧缺资源。

当然,这种情况与之前存在本质区别,这与梯度检查点技术和通过重计算激活值来节省内存的手法如出一辙,但二者的应用动机截然不同,这么做是为了提升执行速度,而不仅仅是为了解决内存不足的问题,这是同一种技术,但目标不同。

4.5 Trick (?) 4: Memory coalescing and DRAM

而接下来这个优化点非常有意思,在深入研究 GPU 硬件模型和 DRAM 工作原理之前,甚至可能都不知道它的存在。

GPU 中的全局内存(即 DRAM)速度其实极其缓慢,为了提升访问速度,硬件层面会实施特定的优化机制,DRAM 在硬件层面的一个关键优化是:当你读取某个内存地址时,实际返回的并非单一数据值,系统会一次性返回整块连续的内存数据,这种机制被称为突发传输模式(burst mode)

在这里插入图片描述

假设我们现在要读取这个大内存块的第一个数据,内存系统不会只返回单个数值 0,而是会将 0、1、2、3 这一组数据全部返回,系统会一次性返回四个连续数值,就像这样:“数据都在这儿,拿去吧”。

系统预判你很快就会用到 1、2、3 这些相邻数据,因此每个地址空间都被划分为称为 “突发传输区块” 的单元,系统会直接返回整个突发传输区块,而不仅是你请求的那部分数据。

这种机制乍看可能令人费解,为什么内存会在你只请求 1 字节时,免费多给你 3 个字节呢?这背后有个非常有趣的硬件原理:当内存寻址并准备发送信号时,这些字节必须经过放大器传输,这才是最耗时的环节。而一旦完成这个步骤,后续就能免费获取大量字节数据,因此就产生了这种突发传输机制

这种机制本质上是为了掩盖数据从存储位置移动到放大器这个更耗时的关键步骤,但无论如何,这意味着只要内存访问模式得当,我们就能大幅提升内存访问速度。

在这里插入图片描述

因此,如果我们想要读取这里的整个数据块,但采用随机访问顺序,那么基本上需要进行大约与查询长度等量次数的内存请求,但如果我们直接访问第一个值,就能一次性获取整个突发传输区间的所有数据。

接着如果访问第 4 个值,就能立即获取第二个突发传输区间的完整数据块。因此只要精心设计内存访问模式,仅提取每个突发数据块中的必要部分,理论上就能实现四倍的吞吐量提升,这就是所谓的 内存合并访问(Memory coalescing) 技术。

当同一个线程束(warp)中的所有线程落在同一个突发传输区间内时,硬件智能架构与编程模型会自动将这些内存请求合并处理。系统不会分别请求 0、1、2、3 号数据,而是将它们智能合并为单次请求—只需获取 0 号数据,这样就能通过 DRAM 的突发传输模式一次性读取 0、1、2、3 号全部数据

值得注意的是,一个 warp 由 32 个有序排列的线程组成,因此,线程束内的内存访问操作是同步执行的,但这些线程束读取突发传输区段时,通过优化可以实现一次性获取全部 4 字节数据,而非逐字节单独提取,这样就能将内存吞吐量提升至原来的 4 倍,这些看似简单的优化手段,实则至关重要。

假设我们要进行矩阵乘法运算:

在这里插入图片描述

设想下现在我们要用两种方式读取矩阵数据,可以按行遍历读取,这样每个线程将逐行遍历,或者也可以按列顺序读取,此时每个线程将沿列方向遍历。

结果表明,左边这种跨列遍历的方式运行速度会相当慢,因为内存读取操作未能实现合并访问。而如果采用右侧这种线程按行递增的访问方式(即线程逐行向下读取),就能实现内存读取的合并操作,大家可以思考一下为何如此

在这里插入图片描述

在上面的示意图中,有一组线程正试图从左到右进行数据访问,因此每个线程都会尝试加载第一个元素,然后再下一个时间步,线程会依次加载第二列、第三列、第四列的元素,以此类推。

那么当第一步执行时会发生什么呢?第一步执行时,第一个线程会在此处加载数据,紧接着第二个线程会在相邻位置加载数据,这些数据根本无法实现内存合并,它们读取的是不同的突发内存区块。

这意味着我们必须读取整块内存才能执行任何操作,反之,如果按列方向读取,所有线程都将从同一个突发内存区块中读取数据,这样只需执行一次内存读取操作就能一次性获取全部所需数据,这是非常底层的优化,但至关重要,如果内存遍历顺序完全错误,实际获得的内存访问速度会远低于预期。

4.6 Trick 5: tiling

OK,那么这就引出了最后一个关键要点,这就是分块处理(tiling)的核心思想,分块处理的核心理念是通过将内存访问集中分组,从而最大限度减少全局内存的访问次数

为了说明这个概念,我们将通过矩阵乘法的例子来逐步解析,希望通过这个例子,能向你说明为什么原始的矩阵乘法算法会存在严重缺陷,随后我们将展示采用分块处理的改进方案,你将直观地看到这种优化如何显著减少所需的全局内存读取次数

让我们从下面这个最简单的矩阵乘法算法开始讲起:

在这里插入图片描述

假设我们有两个矩阵,左侧是矩阵 M,上方是矩阵 N,现在要计算矩阵 M 和矩阵 N 的乘积,我们需要遍历矩阵 M 的行和矩阵 N 的列,然后计算内积并将结果存入矩阵 P 中

这里我们还为每个线程编写了对应的计算逻辑:

在这里插入图片描述

线程 0 对应着它们存储输出的位置以及访问每个独立元素的顺序,请注意,这里出现的问题是内存访问没有实现合并操作,比如 M 中的行矩阵,它们的访问顺序会导致无法实现内存合并。

而且还出现了重复的内存访问,可以看到 M 0 , 0 \text{M}_{0,0} M0,0 在第一个线程中被访问,后面又再次访问 M 0 , 0 \text{M}_{0,0} M0,0,而 N 1 , 0 \text{N}_{1,0} N1,0 则在两个不同的线程中被重复访问,这些数值会从全局内存被反复读取到多个不同的线程中,这样的操作很可能会导致严重的性能瓶颈

这里的关键问题是:我们能否避免过多的全局内存读写操作?这才是我们最希望实现的优化目标。让我们先说明下理想的优化效果,再详细解释具体算法

理想情况下,我们希望用一段集中的时间将数据块从全局内存加载到更快的共享内存中,我们希望在共享内存中完成大量计算后,就能彻底处理完这部分数据,这就是最理想的优化效果,这样就最大限度地减少了全局内存访问。

那么在矩阵乘法运算中该如何实现呢?

在这里插入图片描述

现在,我们将对 M 矩阵和 N 矩阵进行分块处理,我们要将它们精确地分割成若干分块,如上图所示,我们已将其分割为 2x2 的分块矩阵,这样就得到了一个 2x2 的 M 分块矩阵和一个 2x2 的 N 分块矩阵,这样每个大矩阵内部就形成了若干较小的子矩阵

假设我们的共享内存容量足够大,能够将这些子矩阵完整地载入每个流式多处理器 SM 中,如此一来,我们就获得了一个极其简洁的计算实现方案。

接下来,我们将首先加载位于左上角这个 M 0 , 0 M_{0,0} M0,0 分块矩阵,同时,我们会把 N 0 , 0 N_{0,0} N0,0 分块矩阵也加载到共享内存中。现在,我们就可以计算这些部分和了,我们可以计算 M 0 , 0 \text{M}_{0,0} M0,0 M 0 , 1 \text{M}_{0,1} M0,1 N 0 , 0 \text{N}_{0,0} N0,0 N 1 , 0 \text{N}_{1,0} N1,0 的行乘积,然后将结果累加到 P0 中

对于所有能填充到这里的子矩阵,我们都可以采用同样的处理方式,当这两个分块完全处理完毕后,我们就可以在这里加载新的分块数据。随后,当新的 M 分块和 N 分块数据加载到共享内存后,我们就能重复这个计算过程,这样就能将部分和结果逐步累加到 P 矩阵中,如此一来,我们确实大幅减少了需要进行的全局内存访问量。

我们会一次性将尽可能多的数据加载到共享内存中,在这个分块上完成所有可能的子矩阵运算,接着再处理下一个分块。更妙的是,由于加载的是完整的分块数据,我们可以按任意顺序遍历这些子矩阵,无论是列优先还是行优先存储方式,因此在将分块数据从全局内存加载到共享内存时,就能实现所有内存访问的合并操作,采用分块访问策略可谓一举多得

我们可以进行一些分块计算的数学推导:

在这里插入图片描述

假设我们有矩阵 A、矩阵 B 和矩阵 C,假设这些矩阵都是方阵,且维度均为 NxN,假设分块尺寸为 TxT,我们可以通过简单的分块计算来分析这个过程

假设我们要进行一个 NxN 的矩阵乘法运算,如果采用非分块矩阵乘法,即直接逐行逐列计算,那么每次处理时每个输入元素都需要重新加载,这些数据都必须从全局内存中读取,因此每个输入元素都需要从全局内存中读取 N 次。

如果采用分块矩阵乘法,那么全局内存的读取操作将以分块为单位进行,每个输入元素只需从全局内存中读取 N/T 次,其中 T 为分块大小,而在每个分块内部,每个元素会被读取 T 次

当然,我们正在进行的是矩阵与矩阵的乘法运算,因此无法减少总的读取次数,但可以通过将这些读取操作转移到高速的共享内存中来实现优化,因此需要执行 T 次共享内存读取以及 N/T 次全局内存读取

这种设计非常高效,当共享内存足够大以存储大型分块时,就能将全局内存的数据传输总量降低至原来的 1/T,分块策略在处理矩阵运算时能展现出惊人的威力,你可以将数据转移到共享内存中。

分块策略的实现相当复杂,这正是导致 GPU 和矩阵乘法性能表现存在诸多困惑的根源所在。可能出现这样一种情况,当我们开始采用分块策略时,就不得不考虑离散化的问题

在这里插入图片描述

假设我们采用 128 的分块尺寸,这个分块尺寸看起来既规整又理想,当处理 256 尺寸的完整矩阵时,这个分块方案确实很完美,这样正好形成 2x2 的分块布局,数据加载非常顺畅

现在假设列方向存在 257 的尺寸,此时情况就变得棘手了,为了完整覆盖这个矩阵,我们需要使用六个分块,而右侧的两个分块将存在严重的稀疏问题,这些分块中的数据量少得可怜。问题在于,每个分块都会被分配到一个流式多处理器 SM 上执行,也就是说,每个分块都对应一个线程块,每个线程将在各自的分块内执行运算

这样一来,右侧那两个分块几乎派不上什么用场,这些 SM 基本上就处于闲置状态,如果计算资源接近瓶颈,就需要在 SM 之间更均衡地分配负载,因此必须优化分块尺寸,避免出现这类负载不均的情况。

但实际上,确定分块尺寸需要考虑诸多复杂因素,切记,必须确保内存访问是合并进行的,因此必须慎重考虑这个问题;必须确保不超过共享内存的容量限制,因此分块尺寸不能过大;同时需要合理划分矩阵维度,尽可能做到均匀或接近均匀分配,这样才能避免出现 SM 在最后阶段利用率不足的情况

另一个极其复杂的问题是 分块策略与突发访问区间的交互影响

在这里插入图片描述

假设我们有一个这样的矩阵布局:其中每个规整的突发访问区块都与分块区域完美对齐,要读取这个分块区域,只需获取四个不同的突发访问区块即可,这样就能完整获取整个分块区域。

现在假设我们增加一个元素,矩阵的存储布局就会导致分块起始位置和突发访问区块开始错位,此时加载分块区域时,首先会加载第一行这部分数据,这部分加载非常理想,整个首行数据都能通过突发访问区块一次性获取。

但在第二行时,数据实际上跨越了两个不同的突发访问区块,因此必须执行两次读取操作才能获取完整的第二行数据,后续行依此类推,这就导致了内存访问次数实际上翻倍了,仅仅因为在矩阵末端增加了一个元素。

这相当于打乱了我们原本设计的突发访问区块与内存对齐布局,因此本质上,当分块尺寸或矩阵大小不是突发访问区块的整数倍时,很容易出现这种行数据与突发区块错位的情况,最终导致必须执行双倍的内存访问操作

解决这个问题的关键是 通过填充(padding)来调整矩阵尺寸,使其成为规整的整数倍,从而确保突发访问区块与分块尺寸完美对齐,这已经深入到相当底层的技术细节了,但若想真正压榨出矩阵乘法的每一分性能,这类底层优化正是必须考虑的关键所在。

确实,像 torch compile 这类工具以及 CUDA 针对矩阵乘法的各种优化,本质上就是在实现我们刚才讨论的这些底层优化策略,这正是提升性能的关键所在。这种矩阵计算的复杂性最终会体现在实际场景中:

在这里插入图片描述

就像 Andrej 的这条推文所说:对 nanoGPT 最显著的优化,就是将词表大小从 50257 调整到 50304(即最接近的 64 的倍数),这直接带来了大幅提升的计算核心占用率,大约有 25% 的速度提升。

这到底是怎么实现的呢?这样一来,我们又回到了最初的那个谜题:

在这里插入图片描述

现在我们将要解读这张性能曲线的形成原理,相信看完之后,你会觉得矩阵乘法的性能表现不再神秘莫测。

最基础的部分其实非常简单,我们理解了计算强度这个概念,这正是我们最开始提到的屋顶线模型(Roofline Model),在计算强度达到 1536 之前,矩阵乘法运算量还不足以充分利用计算资源,仅仅是加载矩阵数据并进行最基本的输入输出操作。

在低于这个计算强度阈值时,数据搬运操作就会成为性能瓶颈,因此一旦超过这个临界点,计算吞吐量就会急剧下降,内存带宽已无法满足计算单元的需求。

现在看右侧区域,理论上说,若绘制上包络线,顶部区域代表的是理论可达的最高性能极限,因此在这个区域,确实有可能让所有计算单元满载运行,从而实现极高的性能表现。但如果矩阵尺寸设置不当,就可能陷入底部这些性能异常的低效区域,而在每个性能区间内,都可能意外跌入这种诡异的性能低谷,因此我们需要思考,为什么会出现这些不同的性能落点?

在这里插入图片描述

首先,第一条分界线(也就是上图中的第一条线),反映的是分块对齐问题,观察这些倍数关系时,此时图中已经根据矩阵尺寸的可整除性为每条分界线标注了不同的颜色,若矩阵尺寸能被 32 整除,则处于理想性能区间,此时数据点位于图表上方的紫色区域,若能被 16 整除,数据点仍会保持在上方区域,依此类推

若数值无法被任何数整除即 K=1 时,数据点将处于最下方位置,这种情况下,矩阵乘法运算吞吐量会很不理想。关键问题在于当 K 降为 2 甚至 1 时,你将被迫面临无法以对齐方式执行分块读取的困境,突发读取的规整模式被彻底破坏,这会导致严重的性能问题,这确实是个棘手的问题

不过,这还只是谜题的一部分,这个谜题还有另一部分尚未解开,另一个问题应该出现在图中橙色标记的范围内,当你放大观察这段曲线时,会看到性能从某一点开始断崖式下降

在这里插入图片描述

这种断崖式下跌简直让人匪夷所思,仅仅将矩阵维度增加一点,性能怎么会暴跌到这种程度呢?这是个有趣的谜题,下面我们来逐步解析这个谜题

这种情况通常发生在矩阵维度从 1792 增加到 1793 时,为什么会出现这种情况呢?假设我们使用的分块尺寸是 256x128,这个尺寸选择相当合理,有趣的是,这些 GPU 中的矩阵乘法单元天生就适合处理尺寸在 128 左右的矩阵运算,因此 256x128 是一个非常理想的分块尺寸

这意味着维度大小是 1792 的矩阵需要划分多少个分块呢?总共有 1792 256 × 1792 128 = 7 × 14 = 98 \frac{1792}{256} \times \frac{1792}{128} = 7 \times 14 = 98 2561792×1281792=7×14=98 个分块,如果我们把这个值增加 1,那么每个坐标值都需要向上取整,这样一来分块数量会大幅增加,将达到 8 × 15 = 120 8 \times 15 = 120 8×15=120 个,可见分块数量出现了显著增长。

问题在于,我们不仅大幅增加了分块数量(其中部分分块的利用率会降低,这本身就很糟糕),更严重的是 A100 显卡仅仅拥有 108 个流式多处理器(SM),回顾 GPU 的执行模型,SM 作为核心执行单元能够实现并行运算。当需要 98 个 SM 时,它们会全部投入并行运算,所有 SM 单元都能被同时调度执行,这样就能实现极高的硬件利用率

当任务分块数达到 120 个时,分块数量就超过了 SM 的总数,因此其中 108 个分块会并行执行,然后系统会重新调度,确认还有可用的 SM 资源,在极低的硬件利用率状态下,系统会执行剩余的 12 个分块任务,并等待其全部完成。

这种情况会导致严重的性能劣化,从硬件利用率曲线来看,系统会在前期保持较高的利用率,然后利用率会断崖式下降,最终在低效状态下完成剩余任务,这种现象被称为 波阵量化效应(wave quantization)

因此最理想的情况是,分块尺寸要么远大于 SM 数量,要么就完全不采用分块策略。这种情况就像勉强超过 SM 数量阈值反而会引发额外的量化误差问题

4.7 Recap of part 2: making ML workloads go fast

虽然这些都是底层细节,但正如我们在课程中反复强调的:语言模型与深度学习的精髓,恰恰在于对细节的极致把握,正是这种对细节的极致专注,才使得研究者能够将大语言模型扩展到超大规模,同时获得卓越性能,即便你并非系统工程专家,这些知识也值得掌握。

那么关键技巧究竟有哪些?核心思路如下:首要原则是必须减少内存访问次数,实现这一目标的方法多种多样,可以采用合并访问技术,这样就能复用那些 “免费” 获取的读取数据,可以运用算子融合技术,将多个操作合并执行,从而避免冗余的读写操作

可以将数据从全局内存迁移至共享内存,这样即使仍需读取数据,也能从速度更快的存储介质中获取,这就要用到 分块处理 的优化技巧了。最后,你还可以通过牺牲部分内存来换取其他可用计算资源,这正是 重计算策略 的核心思想,或者也可用通过牺牲内存来换取数值精度或稳定性,这正是 量化技术 的应用场景

因此我们拥有丰富的优化技巧工具箱来充分提升性能,优化手段可谓层出不穷,必须深刻理解内存在 GPU 性能中扮演的关键角色,这正是实现性能最大化的关键所在

5. Part 3: Using what we know to understand Flash Attention

那么现在我们要把这些内容整合起来讲,众所周知,FlashAttention 机制能显著加速注意力计算,相信大家都知道,这主要是通过 CUDA 内核的巧妙设计实现的,但其中的具体实现细节可能并非人尽皆知

论文 [Dao+ 2022] 指出,在未经优化的 PyTorch Transformer 实现中,注意力计算存在一个明显的性能瓶颈,通过内核融合等优化手段能实现显著的加速效果

在这里插入图片描述

论文中明确指出:我们采用了两项成熟技术—分块计算和重计算 以解决精确注意力计算和实现线性高带宽存储器访问的技术难题,这并非平方级计算复杂度,因为那在理论上是不可能实现的。

虽然注意力计算本身无法避免,但通过优化可以实现对高带宽内存(全局内存)的次平方级访问,这正是其核心突破所在。当内存成为瓶颈时,关键就在于将内存访问复杂度从平方级降下来,这样至少能让计算而非内存承担平方级的性能开销

简单回顾一下,到目前为止,大家应该反复实现过注意力机制了,整个计算过程将拆解为三个独立的矩阵乘法运算

在这里插入图片描述

整个流程包含 Q、K、V 三个矩阵,中间通过 softmax 函数进行连接,矩阵乘法运算本身其实相当直观,这完全可以通过分块技术来实现,我们之前演示过类似的案例,那么注意力机制的特殊之处在哪里呢?关键在于其中的 softmax 运算环节,这将成为真正的难点所在

在这里插入图片描述

正如前面所说,这里的矩阵乘法运算完全遵循我们讲过的优化方法,如果你查看 FlashAttention 论文中的图 1(上图所示),这本质上就是一个简单的分块矩阵乘法运算。可以看到 K 矩阵和 Q 矩阵都被分割成了若干小块,这些小块数据被复制到 SRAM 中进行乘法运算,运算结果先进行累加,随后被送入 HBM 执行 softmax 运算,最后再与 V 矩阵相乘

因此从 K、Q、V 矩阵乘法的角度来看,整个过程其实非常简单,但现在我们需要重点考虑 softmax 的实现,那么 softmax 运算究竟存在什么问题呢?这是一个全局操作,注意力机制中的 softmax 是逐行进行运算的,必须计算整行的总和才能得到 softmax 的归一化系数,这确实是个棘手的问题

如果采用分块计算,理想情况下所有运算都应该在分块内完成,我们最不希望看到的就是需要将数据写回大矩阵,因此需要一种能在每个分块内部实时计算的 softmax 方法,我们期望在每个分块内完成尽可能多的计算,这里的关键在于采用所谓的 online softmax 算法 [Milakov+ 2018]

那么这种算法究竟是什么呢?我们先来看传统的 softmax 流程:

在这里插入图片描述

当处理数据流时,传统的批量 softmax 操作需要先对 x 1 x_1 x1 x n x_n xn 进行指数运算,求和后再进行归一化处理,这就是标准 softmax 的计算方式。通常我们会先计算最大值,然后减去该值以确保数值计算的稳定性,上图展示的就是标准的数值稳定型 softmax 计算流程

而 online softmax 算法通过递推展开的方法,我们可以推导出某种论证关系

在这里插入图片描述

本质上,当前运行的归一化项与当前最大项 e x i − max ⁡ k = 1 V x k e^{x_i-\max_{k=1}^Vx_k} eximaxk=1Vxk 之间存在关联,因此,我们需要维护当前已遍历元素 x i x_i xi x j x_j xj(即当前迭代位置)的最大值。同时,我们还需要维护 e m j − 1 − m j e^{m_{j-1}-m_j} emj1mj 这个修正项,这个修正项本质上会校正我们的最大值,然后我们还会在后面添加新的项。

因此,这里的 d j d_j dj 将实时更新,最终可以计算出归一化因子,从而得到所需的归一化结果 y i y_i yi,其中 d V d_V dV 本身就是我们所需的归一化项。关键在于这个过程可以实时完成,我们不需要预先获取 x i x_i xi x n x_n xn 的所有数据,只需持续接收从 x i x_i xi x n x_n xn 的数据流即可,这正是关键所在

现在我们可以分块计算 softmax 了,在每个分块内部运算该算法即可计算出该分块的局部 softmax 值,随后,如有需要,可以将跟踪的所有组件写回存储,这正是完成该计算所需的全部条件。因此,在计算 softmax 时,我们完全不需要实例化整个 n 2 n^2 n2 规模的矩阵,核心原理就是如此

只要掌握这些要点,将它们整合起来,就能实现 FlashAttention 的前向传播过程。如果你去研读 FlashAttention 的第二篇论文 [Dao 2023],这正是作业 2 中要求实现的算法,你将按照以下步骤逐一实现:

在这里插入图片描述

首先,你需要完成 KQ 矩阵乘法运算,这个运算过程将采用分块处理,然后对这些分块执行乘法运算,那么该如何计算其中的 softmax 呢?我们将通过维护这些指数求和项的累加值来计算,然后通过增量式更新并针对最大值项进行修正,通过这种方式可以分块计算所有必要量值,逐块处理完毕后,最终只需与包含 V 的分块再进行一次乘法运算即可,这样就能得到完整的 softmax 输出结果

反向传播部分这里就不展开讲了,你可以通过逐块重计算的方式,避免存储 softmax 中间结果,切记,我们必须始终避免存储任何规模为 n 2 n^2 n2 的中间结果。因此我们在上面计算 softmax 时采用了巧妙的分块策略,使得在计算时完全不需要存储任何 n 2 n^2 n2 规模的中间数据

但在反向传播过程中,若存储激活值,这本身就会产生 n 2 n^2 n2 规模的数据,因此我们必须避免存储规模为 n 2 n^2 n2 的激活值。在反向传播过程中,我们将不得不采用逐块实时重计算的策略,这正是反向传播的关键技巧所在。除此之外,其他部分都遵循常规做法,这与梯度计算原理相同,只是采用了分块计算的方式逐块处理

OK,以上就是本次内容的全部要点

6. Recap for the whole lecture

相信大家已经理解了上面所讲解了各个关键技术点,包括分块计算、内存合并访问以及重计算策略,这些方法有机结合造就了 FlashAttention 这样的突破性成果。它们的共同作用使得 Transformer 模型的运行效率获得显著提升

让我们回顾本次课程的核心内容,硬件技术正是驱动当今所有语言模型发展的核心动力,因此,若想充分发挥硬件性能就必须深入理解底层实现细节,当前所有系统级的突破本质上都运用了上面我们探讨的这些核心概念

而当前 GPU 的扩展趋势图,实际上强烈促使我们优先考虑内存传输优化问题,内存传输才是整个系统的性能瓶颈。因此不能仅关注如何减少浮点运算次数这种单一维度,关键在于必须深入思考如何实现更高效的内存传输。

最终,当必须执行特定计算量时,真正的优化之道在于,通过优化数据流动路径,尽可能减少高带宽内存与全局内存之间的数据传输,将数据尽可能保留在超高速的共享内存中,从而最大限度减少这类传输,这正是实现类似 FlashAttention 等算法优异性能的关键所在

OK,以上就是本次讲座的全部内容了

结语

本讲我们主要讲解了 GPU 的体系架构、执行与内存模型、性能优化技巧以及 FlashAttention 的底层实现原理。

在 GPU 架构原理小节中,我们首先从 CPU 与 GPU 的设计哲学差异出发,指出 GPU 以极致的并行化和数据吞吐为核心目标,它是多个流式多处理器(SM)组成的 SIMT(单指令多线程) 模型。接着进一步深入讲解了 GPU 的执行层级,包括线程块(block)、线程束(warp)与线程(thread)的层次结构,并阐述了分层内存体系的访问特征与速度差异,强调了内存访问延迟才是 GPU 性能瓶颈的根源

在性能优化技巧小节,我们系统总结了影响 GPU 性能的关键因素及优化思路。首先强调避免 warp 内分支发散,以防线程间执行路径不一致造成的效率损失;其次讲解了混合精度计算的重要性,通过 FP16 输入与 FP32 累加可同时兼顾吞吐与稳定性;第三是算子融合(Operator Fusion),通过将多次核函数调用合并为单次执行,显著减少显存读写;第四是重计算(Recomputation)策略,以额外计算换取内存访问减少;最后讲解了内存合并访问(Memory Coalescing)与分块计算(Tiling)两项关键技术,通过顺序访问与共享内存分块提升带宽利用率。

我们还对性能曲线图做了详细分析,揭示了矩阵乘法性能出现周期性波动与断崖式下降的底层原因——包括分块尺寸与流式多处理器数量的不整除导致的 “波阵量化效应(wave quantization)”,以及内存突发访问区块错位引发的带宽浪费问题

最后小节对 FlashAttention 的具体实现进行了解析,我们将所有优化理念融会贯通,讲解了其如何结合分块计算与重计算以降低全局内存访问复杂度,并采用 online softmax 实现流式归一化,从而将注意力计算的内存访问开销从平方级降至次平方级,实现了真正意义上的显著加速。

整个讲解非常通俗易懂,大家感兴趣的可以看看

下节课我们将详细讨论如何为 GPU 编写高性能代码,敬请期待🤗

参考

  • https://stanford-cs336.github.io/spring2025
  • https://www.youtube.com/playlist?list=PLoROMvodv4rOY23Y0BoGoBGgQ1zmU_MT_
  • https://github.com/stanford-cs336/spring2025-lectures
  • https://github.com/stanford-cs336/spring2025-lectures/tree/main/nonexecutable/5-GPUs.pdf
http://www.dtcms.com/a/503638.html

相关文章:

  • 做淘宝需要的网站手机网站建设平台
  • 密码学和分布式账本
  • Web后端登录认证(会话技术)
  • 网络安全 | SSL/TLS 证书文件格式详解:PEM、CRT、CER、DER、PKI、PKCS12
  • uploads-labs靶场通关(2)
  • wordpress 企业建站小程序模板源码免费
  • Linux中页表缓存初始化pgtable_cache_init函数的实现
  • 量子计算机会普及个人使用吗?
  • 嵌入式入门:APP+BSP+HAL 三层分级架构浅析
  • 使用 Python 语言 从 0 到 1 搭建完整 Web UI自动化测试学习系列 19--测试框架Pytest基础 3--前后置操作应用
  • 面试面试面试
  • 北京响应式的网站下载了模板如何做网站
  • 中山企业营销型网站制作wordpress亲你迷路了
  • 个人做电影网站有什么风险南山最专业的网站建设
  • 「用Python来学微积分」4. 极坐标方程与参数方程
  • 第六章 路由基础
  • P1049 装箱问题 题解(四种方法)附DP和DFS的对比
  • Windows下Vscode连接到WSL的方法
  • R语言系列入门教程:什么是R语言?与传统编程语言有什么区别?
  • 商务网站建设的主流程网页设计排版作品分析
  • Altium Designer(AD24)原理图菜单栏详细介绍
  • 【JavaWeb学习】关于mysql-connector-j版本过高引起的问题
  • Eudemon1000E-F_V600R024C00SPC100
  • 建设工程资质录入是在那个网站机械类网站模板
  • 手机网站建站用哪个软件好字体样式 网站
  • ESMO中国之声丨徐兵河教授:芦康沙妥珠单抗再奏ADC中国之声,HR阳性HER2阴性晚期乳腺癌迎来CDK4/6抑制剂治疗后新希望
  • 模板网站禁止右键wordpress描述代码
  • pyhton(大厂笔试/面试)最长子序列(哈希-回溯-中等)含源码(二十三)
  • 做淘宝浏览单的网站菏泽外贸网站建设公司
  • Linux:理解操作系统和进程