当前位置: 首页 > news >正文

长上下文能力:FlashAttention vs. RingAttention

FlashAttention

FlashAttention-1

FlashAttention算法解决了什么问题?解决方法是什么?效果如何?

问题和瓶颈

标准注意力算法 Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dkQKT)V 的主要瓶颈在于两点:

  • 计算与空间复杂度高: 该算法的计算复杂度为 O(N2)O(N^2)O(N2)。这意味着随着输入序列长度 NNN 的增加,所需的计算量(FLOPs)和显存占用均呈二次方增长。这严重限制了大模型处理长上下文的能力,并显著增加了对计算资源(尤其是显存)的需求。

  • HBM 读写开销大: 算法需要计算并存储大小为 N×NN \times NN×N 的注意力分数矩阵。后续的 softmax 归一化以及与值矩阵 VVV 的相乘操作,都需要频繁地将这个庞大的中间矩阵从显存(HBM)中读取到计算单元并写回。当输入序列变长时,这种密集的 HBM 读写操作成为显著的计算效率瓶颈。
    在这里插入图片描述
    在这里插入图片描述

核心优化

针对上述瓶颈,FlashAttention 做了三方面优化:

  • 分块计算 (Tiling): 将大型矩阵计算分解为适配 SRAM 的小块并行处理。
  • 在线 Softmax (Online Softmax): 将传统的多轮 Softmax 计算优化为单轮循环完成。
  • 算子融合 (Kernel Fusion): 在 SRAM 中一次性完成矩阵乘、掩码、Softmax 和 Dropout 等连续操作。
    在这里插入图片描述

这些优化带来了两个『降低』:

  • 显著降低 HBM 访问次数: 避免了读写庞大的 N×NN \times NN×N 注意力矩阵,将 HBM 访问次数从标准实现的 O(N2)O(N^2)O(N2) 降低至 O(N2d/M)O(N^2 d / M)O(N2d/M)(其中 MMM 是 SRAM 大小)。由于 ddd(特征维度)远小于 NNN(序列长度)且 MMM 固定,访问次数实现了数量级的降低。
  • 大幅减少显存占用: 无需存储 N×NN \times NN×N 中间矩阵,峰值显存占用从 O(N2)O(N^2)O(N2) 降至 O(N)O(N)O(N)(主要用于存储最终输出和必要的 Softmax 统计量),显著提升了处理长序列的能力。

伪代码:outer loop是K-split scheme的并行模式,不是最优的,在FlashAttention2中会改用Q-split scheme实现更高效的并行度。
在这里插入图片描述

效果评测

  • 训练效率:在GPT-2的训练效率上,FlashAttention对比HuggingFace的实现提升3倍,对比Megatron的标准attention,提升1.8倍。
  • 质量上,FlashAttention不是做近似计算,所以没有损失,且因为能支持更长文本的训练而效果更高。

FlashAttention-2:结合硬件优化

问题和瓶颈

FlashAttention 的作者 Tri Dao 指出,尽管 FlashAttention 大幅提升了注意力计算效率,但其速度仍远低于优化后的矩阵乘法 (GEMM),仅能达到理论峰值 FLOP/s 的 25%-40%。造成这一差距的主要原因在于:

  • 并行策略欠佳: 原算法基于 K 维度切分 (K-split) 的外层循环并非最优选择,因为只对QKTQK^TQKT 的矩阵乘法做到了并行,后面的softmax与V的乘法都无法支持并行。
  • 非 GEMM 操作开销显著: Attention 计算过程中,Softmax、Rescaling、Masking 等非矩阵乘法 (non-GEMM) 操作的开销占比仍然过高,限制了整体吞吐。

GPU 执行介绍(优化背景):

理解 FlashAttention-2 的优化需要回顾 GPU 的执行模型:

  • GPU 通过大量并发线程(组织在 thread block 中)执行计算(kernel)。
  • 流式多处理器 (SM) 是实际的执行单元(如 A100 有 108 个 SM),负责调度和执行分配给它的线程块 (thread block)。
  • 每个 SM 将线程块进一步划分为 warp(通常 32 线程/warp)。warp 是 GPU 的基本调度和执行单位:同 warp 内的线程高度同步,可直接协作(如执行 Tensor Core GEMM);不同 warp 间通过共享内存 (Shared Memory/SRAM) 通信。
  • SM 的 Warp Scheduler 在每个周期选择就绪的 warp 发射指令(A100 每个 SM 支持 4 个并发 warp 调度器,意味着更高的指令级并行潜力)。
  • 典型的 kernel 流程:从 HBM 加载数据到寄存器/SRAM → 计算 → 写回结果到 HBM。

核心优化

针对FlashAttention-1上述瓶颈,FlashAttention-2 在算法和实现上进行了关键改进:

  • 最大化 GEMM 占比,减少非GEMM操作:
    • 延迟Rescaling: 将输出 (O) 的局部归一化 (rescaling) 操作推迟到最后一步全局执行一次,避免了在中间计算块上的重复 rescaling。
    • 合并反向传播统计量: 在反向传播中,将需要保留的两个中间统计量(最大值 m 和指数和 l)合并存储为一个,减少了内存占用和计算量。
  • 提高并行度: FlashAttention-1 中,一个线程块 (thread block) 负责处理一个注意力头 (head),整体并行度为 batch_size * num_heads。当此值大于可用 SM 数量(如 A100 的 108)时,资源利用率较高。然而,对于长序列输入,batch_size 往往较小,导致并行度不足,SM 利用率低下。所以FlashAttention-2 引入了在序列长度维度 (sequence length) 上的切分:
    • 前向传播: 沿查询 (Q) 的行维度进行切分。不同块 (tile) 的计算相互独立,天然具有高并行度。
    • 反向传播: 采用双重并行 (dual-pass parallelization) 策略,同时沿序列的行和列方向进行切分。这是因为梯度计算存在跨块依赖 (例如 dQi=∑jdSijKjTdQ_i = \sum_j dS_{ij} K_j^TdQi=jdSijKjT)。该策略精心设计分块计算顺序和通信,确保依赖关系在块内处理或高效同步,从而最大化利用 SM 资源,尤其在小 batch_size 场景下显著提升并行度。
  • work partitioning (工作划分):目标是最大化 GPU SM 的利用率,实现更好的负载均衡 (load balancing),并减少线程块 (thread block) 内部以及线程块之间的同步开销。Q维度切块或双重维度切块,每块大小多少,任务怎么分配等,从多个角色做综合优化,使资源利用率最高,吞吐效率最高。

效果评测

  • 速度提升:比FlashAttention快1.7-3.0倍,比Triton实现的FlashAttention快1.3-2.5倍,比标准注意力实现快3-10倍。
  • 计算效率:在A100 GPU上达到230 TFLOPs/s,占理论峰值的73%。
  • 端到端训练加速:在1.3B/2.7B参数的GPT类模型(序列长度2k/8k)训练中,相比FlashAttention提速1.3倍,相比无FlashAttention的基线提速2.8倍,单A100 GPU算力利用率达225 TFLOPs/s(72%)。

RingAttention

简介

FlashAttention 优化单卡注意力计算,RingAttention 则通过多机多卡分布式计算解决超长序列的显存与算力瓶颈。其核心融合序列并行、环形通信与分块计算,实现显存高效的超长上下文处理。

核心优化

下面简单介绍一下分块计算和ring attention (序列并行+环形通信):

  • Blockwise Parallel Transformer: 受 Online Softmax 与 FlashAttention 工作的启发,作者提出了一种基于查询维度(Q)分块的自注意力并行计算方法。该方法将查询序列划分为块,并并行处理每个查询块。对于单个查询块,其注意力输出 Attention(Q_block, K, V) 通过遍历所有键值块(K, V blocks) 计算得到。关键创新在于:计算过程中,利用局部块归一化常数与全局 Softmax 归一化常数的差值,对每个查询块的局部注意力结果进行缩放校正,从而直接获得等效于全局注意力矩阵的结果。这一机制完全避免了存储庞大的 N×N 注意力矩阵,显著降低了显存需求。类似地,前馈网络(FFN)的计算亦可基于分块的查询和其对应的局部注意力输出直接进行。
  • Ring Attention: 该方法沿查询维度(Q)对序列进行分块,并将每个查询块(Q_chunk)分配到不同的计算设备上。键值序列(KV)同样被分块,并以环状(Ring)拓扑结构在设备间循环传输。每个设备利用本地持有的查询块,依次处理循环流入的每个键值块(KV_chunk),逐步完成该查询块对应输出(O)的计算。设备内部的注意力计算可高效实现(如采用 FlashAttention 或其他优化方法)。关键优势在于其计算与通信的高度重叠:当一个设备在计算当前 KV 块时,下一个 KV 块已在通信传输中,从而显著提升了整体计算效率。
    在这里插入图片描述
    RingAttention的代码逻辑:
    在这里插入图片描述

效果评测

从下图可见,效果不言而预,可处理的文本长度能够达到百万级tokens。例如,32张A100 GPU的资源下,训练7B模型, 可实现超过100万tokens的上下文长度,对比先前最佳水平提升了32倍。
在这里插入图片描述

References

  • Flash attention papers
  • Block Parallel Transformer paper
  • RingAttention paper
  • ring attention explained: https://coconut-mode.com/posts/ring-attention/
http://www.dtcms.com/a/277921.html

相关文章:

  • 协程的 callbackFlow 函数的使用和原理
  • 认识数据分析
  • 第一,二次作业
  • LAN-401 linux操作系统的移植
  • DHS及HTTPS工作过程
  • 【Claude Code】 AI 编程指南
  • sql初学见解
  • 多线程死锁
  • 飞算Java AI开发助手:引领智能编程新风尚
  • Llama系列:Llama1, Llama2,Llama3内容概述
  • 【读书笔记】《C++ Software Design》第九章:The Decorator Design Pattern
  • HTML 基本骨架
  • [GWCTF 2019]我有一个数据库
  • SOMEIP协议与测试
  • LeetCode 2401.最长优雅子数组
  • C++数组指针与函数指针
  • 为什么要有延时回调?
  • 2024-2025-2 山东大学《软件工程与实践》期末(回忆版)
  • p4 大小写检查
  • C++高级编程,类模版成员函数类外实现
  • windows10如何安装vue开发环境
  • JAVA-springboot 整合Activemq
  • ECU(电子控制单元)是什么?
  • C++中顶层const与底层const
  • JSX 语法
  • 【前端知识】移动端APP原生应用与H5交互底层逻辑
  • Dubbo跨越分布式事务的最终一致性陷阱
  • 有效感受野(ERF)可视化工具
  • hash表的模拟--开放定址法
  • 如何将本地代码同步到远程Github仓库