长上下文能力:FlashAttention vs. RingAttention
FlashAttention
FlashAttention-1
FlashAttention算法解决了什么问题?解决方法是什么?效果如何?
问题和瓶颈
标准注意力算法 Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dkQKT)V 的主要瓶颈在于两点:
-
计算与空间复杂度高: 该算法的计算复杂度为 O(N2)O(N^2)O(N2)。这意味着随着输入序列长度 NNN 的增加,所需的计算量(FLOPs)和显存占用均呈二次方增长。这严重限制了大模型处理长上下文的能力,并显著增加了对计算资源(尤其是显存)的需求。
-
HBM 读写开销大: 算法需要计算并存储大小为 N×NN \times NN×N 的注意力分数矩阵。后续的 softmax 归一化以及与值矩阵 VVV 的相乘操作,都需要频繁地将这个庞大的中间矩阵从显存(HBM)中读取到计算单元并写回。当输入序列变长时,这种密集的 HBM 读写操作成为显著的计算效率瓶颈。
核心优化
针对上述瓶颈,FlashAttention 做了三方面优化:
- 分块计算 (Tiling): 将大型矩阵计算分解为适配 SRAM 的小块并行处理。
- 在线 Softmax (Online Softmax): 将传统的多轮 Softmax 计算优化为单轮循环完成。
- 算子融合 (Kernel Fusion): 在 SRAM 中一次性完成矩阵乘、掩码、Softmax 和 Dropout 等连续操作。
这些优化带来了两个『降低』:
- 显著降低 HBM 访问次数: 避免了读写庞大的 N×NN \times NN×N 注意力矩阵,将 HBM 访问次数从标准实现的 O(N2)O(N^2)O(N2) 降低至 O(N2d/M)O(N^2 d / M)O(N2d/M)(其中 MMM 是 SRAM 大小)。由于 ddd(特征维度)远小于 NNN(序列长度)且 MMM 固定,访问次数实现了数量级的降低。
- 大幅减少显存占用: 无需存储 N×NN \times NN×N 中间矩阵,峰值显存占用从 O(N2)O(N^2)O(N2) 降至 O(N)O(N)O(N)(主要用于存储最终输出和必要的 Softmax 统计量),显著提升了处理长序列的能力。
伪代码:outer loop是K-split scheme的并行模式,不是最优的,在FlashAttention2中会改用Q-split scheme实现更高效的并行度。
效果评测
- 训练效率:在GPT-2的训练效率上,FlashAttention对比HuggingFace的实现提升3倍,对比Megatron的标准attention,提升1.8倍。
- 质量上,FlashAttention不是做近似计算,所以没有损失,且因为能支持更长文本的训练而效果更高。
FlashAttention-2:结合硬件优化
问题和瓶颈
FlashAttention 的作者 Tri Dao 指出,尽管 FlashAttention 大幅提升了注意力计算效率,但其速度仍远低于优化后的矩阵乘法 (GEMM),仅能达到理论峰值 FLOP/s 的 25%-40%。造成这一差距的主要原因在于:
- 并行策略欠佳: 原算法基于 K 维度切分 (K-split) 的外层循环并非最优选择,因为只对QKTQK^TQKT 的矩阵乘法做到了并行,后面的softmax与V的乘法都无法支持并行。
- 非 GEMM 操作开销显著: Attention 计算过程中,Softmax、Rescaling、Masking 等非矩阵乘法 (non-GEMM) 操作的开销占比仍然过高,限制了整体吞吐。
GPU 执行介绍(优化背景):
理解 FlashAttention-2 的优化需要回顾 GPU 的执行模型:
- GPU 通过大量并发线程(组织在 thread block 中)执行计算(kernel)。
- 流式多处理器 (SM) 是实际的执行单元(如 A100 有 108 个 SM),负责调度和执行分配给它的线程块 (thread block)。
- 每个 SM 将线程块进一步划分为 warp(通常 32 线程/warp)。warp 是 GPU 的基本调度和执行单位:同 warp 内的线程高度同步,可直接协作(如执行 Tensor Core GEMM);不同 warp 间通过共享内存 (Shared Memory/SRAM) 通信。
- SM 的 Warp Scheduler 在每个周期选择就绪的 warp 发射指令(A100 每个 SM 支持 4 个并发 warp 调度器,意味着更高的指令级并行潜力)。
- 典型的 kernel 流程:从 HBM 加载数据到寄存器/SRAM → 计算 → 写回结果到 HBM。
核心优化
针对FlashAttention-1上述瓶颈,FlashAttention-2 在算法和实现上进行了关键改进:
- 最大化 GEMM 占比,减少非GEMM操作:
- 延迟Rescaling: 将输出 (O) 的局部归一化 (rescaling) 操作推迟到最后一步全局执行一次,避免了在中间计算块上的重复 rescaling。
- 合并反向传播统计量: 在反向传播中,将需要保留的两个中间统计量(最大值 m 和指数和 l)合并存储为一个,减少了内存占用和计算量。
- 提高并行度: FlashAttention-1 中,一个线程块 (thread block) 负责处理一个注意力头 (head),整体并行度为 batch_size * num_heads。当此值大于可用 SM 数量(如 A100 的 108)时,资源利用率较高。然而,对于长序列输入,batch_size 往往较小,导致并行度不足,SM 利用率低下。所以FlashAttention-2 引入了在序列长度维度 (sequence length) 上的切分:
- 前向传播: 沿查询 (Q) 的行维度进行切分。不同块 (tile) 的计算相互独立,天然具有高并行度。
- 反向传播: 采用双重并行 (dual-pass parallelization) 策略,同时沿序列的行和列方向进行切分。这是因为梯度计算存在跨块依赖 (例如 dQi=∑jdSijKjTdQ_i = \sum_j dS_{ij} K_j^TdQi=∑jdSijKjT)。该策略精心设计分块计算顺序和通信,确保依赖关系在块内处理或高效同步,从而最大化利用 SM 资源,尤其在小 batch_size 场景下显著提升并行度。
- work partitioning (工作划分):目标是最大化 GPU SM 的利用率,实现更好的负载均衡 (load balancing),并减少线程块 (thread block) 内部以及线程块之间的同步开销。Q维度切块或双重维度切块,每块大小多少,任务怎么分配等,从多个角色做综合优化,使资源利用率最高,吞吐效率最高。
效果评测
- 速度提升:比FlashAttention快1.7-3.0倍,比Triton实现的FlashAttention快1.3-2.5倍,比标准注意力实现快3-10倍。
- 计算效率:在A100 GPU上达到230 TFLOPs/s,占理论峰值的73%。
- 端到端训练加速:在1.3B/2.7B参数的GPT类模型(序列长度2k/8k)训练中,相比FlashAttention提速1.3倍,相比无FlashAttention的基线提速2.8倍,单A100 GPU算力利用率达225 TFLOPs/s(72%)。
RingAttention
简介
FlashAttention 优化单卡注意力计算,RingAttention 则通过多机多卡分布式计算解决超长序列的显存与算力瓶颈。其核心融合序列并行、环形通信与分块计算,实现显存高效的超长上下文处理。
核心优化
下面简单介绍一下分块计算和ring attention (序列并行+环形通信):
- Blockwise Parallel Transformer: 受 Online Softmax 与 FlashAttention 工作的启发,作者提出了一种基于查询维度(Q)分块的自注意力并行计算方法。该方法将查询序列划分为块,并并行处理每个查询块。对于单个查询块,其注意力输出 Attention(Q_block, K, V) 通过遍历所有键值块(K, V blocks) 计算得到。关键创新在于:计算过程中,利用局部块归一化常数与全局 Softmax 归一化常数的差值,对每个查询块的局部注意力结果进行缩放校正,从而直接获得等效于全局注意力矩阵的结果。这一机制完全避免了存储庞大的 N×N 注意力矩阵,显著降低了显存需求。类似地,前馈网络(FFN)的计算亦可基于分块的查询和其对应的局部注意力输出直接进行。
- Ring Attention: 该方法沿查询维度(Q)对序列进行分块,并将每个查询块(Q_chunk)分配到不同的计算设备上。键值序列(KV)同样被分块,并以环状(Ring)拓扑结构在设备间循环传输。每个设备利用本地持有的查询块,依次处理循环流入的每个键值块(KV_chunk),逐步完成该查询块对应输出(O)的计算。设备内部的注意力计算可高效实现(如采用 FlashAttention 或其他优化方法)。关键优势在于其计算与通信的高度重叠:当一个设备在计算当前 KV 块时,下一个 KV 块已在通信传输中,从而显著提升了整体计算效率。
RingAttention的代码逻辑:
效果评测
从下图可见,效果不言而预,可处理的文本长度能够达到百万级tokens。例如,32张A100 GPU的资源下,训练7B模型, 可实现超过100万tokens的上下文长度,对比先前最佳水平提升了32倍。
References
- Flash attention papers
- Block Parallel Transformer paper
- RingAttention paper
- ring attention explained: https://coconut-mode.com/posts/ring-attention/