sageattention低比特量化注意力机制,比FlashAttention快5 倍
SageAttention 是 清华大学 陈键飞团队开发的一种低比特量化注意力机制,主要用于加速大模型推理和训练。以下是其核心技术和应用:
技术特点
- 低比特量化:通过 FP4量化 (4比特)实现推理加速,相比传统 FlashAttention 提速5倍,同时保持精度。
- 可训练性:首次支持8比特训练,在微调任务中保持与全精度注意力相同的结果。
- 模块化设计:支持即插即用,可轻松集成到 PyTorch 、 TensorFlow 等框架中。
应用场景
- 大模型优化:广泛应用于 视频生成 、 自然语言处理 、 推荐系统 等领域,如 HunyuanVideo 、 CogVideoX 等模型。
- 硬件适配:在 英伟达RTX 5090 等GPU上实现高效运算,例如RTX 5090上达到1040万亿次每秒运算(TOPS)。
- 最新进展
2025年6月发布的 SageAttention3 进一步优化了量化策略,采用两级量化方法(先归一化后微缩)和动态调整量化范围,避免块内异常值影响精度。该版本在推理中保持精度优势,同时支持训练加速。
实现了 5 倍相比于 FlashAttention 的即插即用的推理加速(此前的 SageAttention V1/V2/V2++ 分别达到了 2.1,3,3.9 倍的加速效果),比如在 RTX 5090 上,SageAttention3 达到了 1040 TOPS 的速度,甚至是比 RTX 5090 昂贵十几倍的 H100 上使用 Hopper 独有的 FlashAttention3 还要快 1.65 倍!SageAttention3 在多种视频和图像生成等大模型上(包括 HunyuanVideo,CogVideoX,Mochi 和各类图像生成模型)均保持了端到端的精度表现。同时还首次提出可训练的 8 比特注意力(SageBwd)用于大模型的训练加速(注:FlashAttention3 的 FP8 版本也只支持前向传播),在各项微调任务中均保持了与全精度注意力相同的结果。
基础介绍
SageAttention 是一种专门针对 Transformer 注意力机制进行低比特量化(如 8-bit、4-bit)优化的算法库,目的在于以更低的计算资源、更小的模型延迟,同时保持精度与 FlashAttention、xFormers 等高性能库相当或更优。
SageAttention(v1)
SAGEATTENTION: ACCURATE 8-BIT ATTENTION FOR PLUG-AND-PLAY INFERENCE ACCELERATION, ICLR2025,
- 对k进行平滑(减去均值),在提升精度的同时增加不到0.2%的效率开销。
- 提出了自适应量化。总共设置了4种不同的sageattn,
(1)对q,k使用per-block或者per-token量化、
(2)SAGEAttn-B对于鄋的模型已经足够准确,同时能够实现2x的加速。SAGEAttn-vB在模型的模型层同样也准确同时比SAGEAttn-B快4%。因此,我们使用各种输入来测试模型每层的SAGEAttn-vB的余弦相似度。然后,我们将选择余弦相似度大于99.8%(SAGEAttn-B的最低相似度)的那些层为SAGEAttn-vB,而其他层则留给SAGEAttn-B。
最终实现比FlashAttention2快2x的加速比。
SageAttention 2
SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization,
- 在线程级别对矩阵(Q,K)进行量化,并将(P,V)量化为FP8。
- 提出了一种平滑Q的方法,以提高QK⊤的准确性。(从图中可以看到q,k,v均有一个平滑操作)。
- 第三,提出了一种两级累积策略,以增强FP8-PV的准确性。
- SageAttention2的速度比FlashAttention2和xformers快大约3倍和4.5倍。此外,SageAttention2在Hopper GPU上与FlashAttention3(fp8)的速度相匹配,但提供显著更高的准确性。
SageAttention 2++
SageAttention2++: A More Efficient Implementation of SageAttention2,
在 v2 基础上将PV的累加修改为fp16,实现比FlashAttention2快3.9×的推理效率,同时精度几乎无损。
SageAttention 3
SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of 8-Bit Training,
设计了SageAttention3,这是第一个用于推理加速的mxFP4注意力,在RTX5090上达到了1038 TOPS,比RTX5090上最快的FlashAttention快5倍。实验表明,SageAttention3能够加速各种模型,而不会造成端到端质量指标的降低。
其次,引入了第一个可训练的8位注意力(SageBwd),用于训练加速,并探讨其在训练任务中的可行性。我们发现8位注意力在微调任务中能够实现无损性能,但在预训练任务中目前有一些限制。