FlashInfer - SparseAttention(稀疏注意力)只计算部分有意义的注意力连接,而非全部 token 对
FlashInfer - SparseAttention(稀疏注意力)只计算部分有意义的注意力连接,而非全部 token 对
flyfish
SparseAttention 原理
SparseAttention(稀疏注意力)是针对标准自注意力机制的优化,核心目标是将 O(N²) 的计算复杂度降低到 O(N·k)(k 为稀疏度),特别适用于长序列(如 N>10k)。其核心思想是:只计算部分有意义的注意力连接,而非全部 token 对。
在算法复杂度分析中,线性(Linear) 和 亚线性(Sublinear) 是描述算法效率随输入规模增长的术语,与平方级(Quadratic)增长形成对比。
1. 核心定义
线性复杂度(O(N))
- 定义:算法的计算量或内存开销与输入规模 N N N 成正比例关系。
- 特点:当 N N N 翻倍时,计算量/内存也翻倍。
- 示例:
- 遍历数组求和:每个元素仅访问一次,总操作数为 N N N。
- 稀疏注意力的局部窗口模式(每个token仅关注前后 k k k 个token):总操作数为 N ⋅ k N \cdot k N⋅k,当 k k k 固定时,复杂度为 O ( N ) O(N) O(N)。
亚线性复杂度(O(N^α),其中 α < 1)
- 定义:算法的计算量或内存开销增长慢于线性。
- 特点:当 N N N 翻倍时,计算量/内存增长小于两倍。
- 常见形式:
- O ( N ) O(\sqrt{N}) O(N):如分块算法。
- O ( N log N ) O(N \log N) O(NlogN):如快速排序、某些稀疏注意力的动态选择策略。
2. 与平方级复杂度(O(N²))的对比
复杂度 | 含义 | 增长速度(N增大时) | 适用场景 |
---|---|---|---|
O(N²) | 平方级:计算量与 N 2 N^2 N2 成正比 | 极快(N翻倍时计算量×4) | 短序列(N<1000) |
O(N) | 线性:计算量与 N N N 成正比 | 中等(N翻倍时计算量×2) | 长序列(N>10000) |
O(N^α) | 亚线性:α<1(如O(√N)、O(N logN)) | 最慢(N翻倍时增长<2倍) | 超长序列(N>1M)或特殊场景 |
3. 稀疏注意力中的线性/亚线性优化
(1)局部窗口稀疏(Linear)
- 模式:每个token仅关注前后 k k k 个token(如 k = 512 k=512 k=512)。
- 复杂度:总操作数 N ⋅ k N \cdot k N⋅k,当 k k k 固定时为 O ( N ) O(N) O(N)。
计算量仅与序列长度 N N N 线性相关,而非 N 2 N^2 N2。
(2)动态Top-K选择(Sublinear)
- 模式:每个token仅关注注意力得分最高的 K K K 个token(如 K = 256 K=256 K=256)。
- 复杂度:若使用高效选择算法(如堆排序),总操作数为 O ( N log K ) ≈ O ( N log N ) O(N \log K) \approx O(N \log N) O(NlogK)≈O(NlogN)(亚线性)。
仅计算关键连接,避免冗余计算。
4. FlashInfer中的线性优化
FlashInfer通过以下方式实现线性或近似线性复杂度:
-
分页KV缓存(PageAttention)
- 将KV缓存分成固定大小的页(如每页1024个token),仅激活当前相关的页。
- 计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降至 O ( N ⋅ p ) O(N \cdot p) O(N⋅p)( p p p 为活跃页数,通常 p ≪ N p \ll N p≪N)。
-
与稀疏注意力库集成
- 结合
xformers
等库,支持局部窗口、扩张窗口等稀疏模式,直接将计算量降至 O ( N ) O(N) O(N)。
- 结合
-
内存访问优化
- 通过连续内存布局和批量操作,减少每步计算的开销,进一步提升线性算法的效率。
5. 为什么线性/亚线性很重要?
在长序列场景(如 N = 100 K N=100K N=100K)中,平方级算法(如标准注意力)的计算和内存需求会爆炸式增长,而线性/亚线性算法仍能保持可扩展性:
序列长度 N N N | O(N²) 计算量 | O(N) 计算量 | 加速比 |
---|---|---|---|
1,000 | 1,000,000 | 1,000 | 1,000x |
10,000 | 100,000,000 | 10,000 | 10,000x |
100,000 | 10,000,000,000 | 100,000 | 100,000x |
可见,当 N N N 增大时,线性算法的优势愈发明显。这也是 FlashInfer 等框架支持稀疏注意力的核心原因——突破平方级瓶颈,让模型处理超长文本成为可能。
计算复杂度和内存开销呈平方级增长
在自注意力机制(Self-Attention)中,当序列长度 N N N 增大时,计算复杂度和内存开销呈平方级增长( O ( N 2 ) O(N^2) O(N2)),根本原因在于 注意力矩阵的规模是 N × N N \times N N×N
一、自注意力的核心计算步骤
自注意力的计算分为三步(以单头注意力为例):
-
计算查询-键矩阵(QK矩阵):
Score = Q K T 其中 Q , K , V ∈ R N × d \text{Score} = QK^T \quad \text{其中} \quad Q, K, V \in \mathbb{R}^{N \times d} Score=QKT其中Q,K,V∈RN×d
Q Q Q(查询)和 K K K(键)的形状均为 N × d N \times d N×d( d d d 是特征维度),它们的矩阵乘积得到 注意力分数矩阵 Score ∈ R N × N \text{Score} \in \mathbb{R}^{N \times N} Score∈RN×N。- 计算复杂度:每个元素的计算需要 d d d 次乘法,总共有 N 2 N^2 N2 个元素,复杂度为 O ( N 2 d ) O(N^2 d) O(N2d)。
- 内存开销:存储 Score \text{Score} Score 需要 N 2 N^2 N2 个浮点数(如FP32时占 4 N 2 4N^2 4N2 字节)。
-
Softmax归一化和值加权(QKV加权):
Attention = Softmax ( Score ) ⋅ V \text{Attention} = \text{Softmax}(\text{Score}) \cdot V Attention=Softmax(Score)⋅V- 对 N × N N \times N N×N 的矩阵做Softmax,复杂度为 O ( N 2 ) O(N^2) O(N2)。
- 矩阵与 V ∈ R N × d V \in \mathbb{R}^{N \times d} V∈RN×d 相乘,得到 N × d N \times d N×d 的输出,复杂度为 O ( N 2 d ) O(N^2 d) O(N2d)。
-
多头注意力扩展:
若有 h h h 个头,每个头独立计算,总复杂度变为 O ( h N 2 d ) O(hN^2 d) O(hN2d),但 h h h 和 d d d 通常是固定维度(如 h = 16 , d = 64 h=16, d=64 h=16,d=64),因此主导项仍是 O ( N 2 ) O(N^2) O(N2)。
二、平方级增长的本质原因
1. 计算复杂度的平方级来源
-
核心操作是矩阵乘法 Q K T QK^T QKT:
两个 N × d N \times d N×d 矩阵相乘得到 N × N N \times N N×N 矩阵,计算量为 N × d × N = N 2 d N \times d \times N = N^2 d N×d×N=N2d。
当 N N N 增大时, N 2 N^2 N2 是主导项(即使 d d d 固定,如 d = 1024 d=1024 d=1024, N = 4096 N=4096 N=4096 时 N 2 = 16 , 777 , 216 N^2 = 16,777,216 N2=16,777,216,远大于 N N N)。 -
每对token都需计算一次关联:
自注意力要求每个token(共 N N N 个)与所有其他token(包括自身,共 N N N 个)计算注意力分数,总共有 N × N = N 2 N \times N = N^2 N×N=N2 次两两交互。
2. 内存开销的平方级来源
-
存储注意力分数矩阵( N × N N \times N N×N):
假设使用FP32(4字节/数),存储该矩阵需要 4 N 2 4N^2 4N2 字节。- 当 N = 1024 N=1024 N=1024 时,需要约4MB;
- 当 N = 16 K N=16K N=16K 时,需要约1GB;
- 当 N = 32 K N=32K N=32K 时,需要约4GB(仅单个注意力头)。
这还不包括中间变量(如Softmax结果、梯度等),实际内存占用会更高。
-
隐含的内存放大效应:
矩阵运算(如PyTorch/TensorFlow的底层实现)需要额外的临时存储空间,进一步加剧内存压力。
三、举例说明
假设 d = 1024 d=1024 d=1024(固定维度),比较不同 N N N 时的计算量和内存:
序列长度 N N N | 注意力矩阵大小 | 计算量(乘加操作数) | 内存占用(FP32,单头) |
---|---|---|---|
1024 | 1024×1024 | ~1e9次 | ~4MB |
2048 | 2048×2048 | ~4e9次(×4) | ~16MB(×4) |
4096 | 4096×4096 | ~16e9次(×16) | ~64MB(×16) |
16K | 16K×16K | ~262e9次(×256) | ~1GB(×256) |
32K | 32K×32K | ~1e12次(×1024) | ~4GB(×1024) |
可见,当 N N N 翻倍时,计算量和内存占用均变为原来的4倍(平方级增长)。
四、为什么稀疏注意力能缓解这个问题?
稀疏注意力(如SparseAttention、FlashInfer支持的模式)通过 减少有效交互的token对数量(从 N 2 N^2 N2 降至 O ( N ⋅ r ) O(N \cdot r) O(N⋅r), r r r 是每个token关注的平均token数),将复杂度优化为线性或亚线性(如 O ( N r ) O(Nr) O(Nr), r ≪ N r \ll N r≪N)。例如:
- 局部窗口稀疏:每个token仅关注附近 r = 512 r=512 r=512 个token,总交互数为 N ⋅ 512 = O ( N ) N \cdot 512 = O(N) N⋅512=O(N)。
- Top-K稀疏:每个token仅关注得分最高的 K = 256 K=256 K=256 个token,总交互数为 N ⋅ K = O ( N K ) N \cdot K = O(NK) N⋅K=O(NK)。
有意义的注意力连接
在注意力机制中,“有意义的注意力连接” 指的是 对当前任务或输入内容有实际影响的 token 对。
“token 对”(Token Pair) 指的是 序列中任意两个 token 之间的交互关系。
“有意义的注意力连接”是一个与语义、句法、任务目标相关的概念,不同模型或场景可能采用不同的标准。在实际实现中,通常通过稀疏模式设计(如局部窗口、动态采样)或注意力得分过滤(如 Top-K、Top-P)来近似定义和计算这些有意义的连接,从而在保持模型性能的同时降低计算复杂度。
1. 基于语义关联的“有意义”
- 定义:两个 token 在语义上存在强关联,关注它们的关系能帮助模型更好地理解上下文。
- 示例:
- 输入句子:“济南是山东的省会,也是一座历史悠久的城市。”
- 注意力机制应关注“济南”与“城市”之间的连接,因为它们在语义上高度相关。
- 而“济南”与句首的“的”或句尾的句号之间的连接则可能被视为“无意义”。
2. 基于句法结构的“有意义”
- 定义:两个 token 在句法上存在依赖关系(如主语-谓语、动词-宾语)。
- 示例:
- 句子:“我喜欢吃苹果。”
- “喜欢”与“苹果”之间的连接是有意义的,因为它们构成动宾关系。
- 而“我”与“苹果”之间的直接连接可能较弱(需通过“喜欢”间接关联)。
3. 基于任务目标的“有意义”
- 定义:根据具体任务,某些 token 对的关系对预测结果更重要。
- 示例:
- 机器翻译:源语言与目标语言中对应的词对(如“apple”→“苹果”)。
- 问答系统:问题中的关键词与上下文中的答案候选。
- 命名实体识别:实体词与描述其属性的词(如“微软”与“公司”)。
4. 基于注意力得分的“有意义”
- 定义:通过模型计算出的注意力得分(如 softmax 后的权重)筛选出重要连接。
- 具体方法:
- Top-K 选择:保留得分最高的 K 个连接。
- 阈值过滤:丢弃得分低于阈值的连接(如仅保留得分 > 0.1 的连接)。
- 动态稀疏化:根据得分分布自适应确定稀疏度(如 Top-P 采样)。
5. 常见稀疏模式中的“有意义”设计
稀疏模式 | 如何定义“有意义” |
---|---|
局部窗口 | 认为当前 token 附近的 token 更相关(如前后 512 个 token)。 |
扩张窗口 | 保留局部相关性的同时,通过跳跃式采样捕获长距离依赖(类似 CNN 中的空洞卷积)。 |
基于内容的稀疏 | 使用额外的网络或启发式规则动态识别相关 token(如 Longformer 的 global attention)。 |
块稀疏 | 将序列分块,块内全连接,块间仅保留关键连接(如 BigBird 的 block sparse)。 |
6. FlashInfer 中的“有意义”实现
FlashInfer 本身不直接定义“有意义”,而是通过以下方式支持用户自定义或第三方库定义的稀疏模式:
- 灵活的掩码接口:允许用户传入自定义掩码矩阵,显式指定哪些连接需要计算。
# 示例:仅关注前 100 个 token mask = torch.zeros(seq_len, seq_len) mask[:, :100] = 1 # 每行仅关注前 100 列 output = flashinfer.attention_with_mask(q, k, v, mask)
- 与第三方稀疏注意力库集成:如
xformers
、torch-sparse
,利用它们的算法识别有意义的连接。 - 长序列优化:通过分页 KV 缓存(PageAttention),仅激活当前可能相关的 token 页。
SparseAttention 原理
SparseAttention(稀疏注意力)是针对标准自注意力机制的优化,核心目标是将 O(N²) 的计算复杂度降低到 O(N·k)(k 为稀疏度),特别适用于长序列(如 N>10k)。其核心思想是:只计算部分有意义的注意力连接,而非全部 token 对。
稀疏矩阵的定义与判定
1. 定义与核心特征
稀疏矩阵(Sparse Matrix)是指矩阵中绝大多数元素为零,仅有少量非零元素的矩阵。其核心特征是:
- 零元素占比极高,非零元素占比极低。
- 没有严格的数学阈值(如“必须超过X%的零元素”),但实际应用中,通常认为零元素占比超过 50% 即可视为稀疏矩阵,具体标准因领域而异(如数值计算、机器学习中可能要求更高的稀疏度,如70%以上)。
2. 分析
矩阵:
-
非零元素:9个
-
零元素:26个
-
总元素:9 + 26 = 35个
-
稀疏度(零元素占比): 26 35 ≈ 74.29 % \frac{26}{35} \approx 74.29\% 3526≈74.29%
-
密度(非零元素占比): 9 35 ≈ 25.71 % \frac{9}{35} \approx 25.71\% 359≈25.71%
-
属于稀疏矩阵。74%的零元素占比远超过“大部分”的基本要求(超过50%),符合稀疏矩阵的典型特征。
-
“绝大部分元素为零”中的“绝大部分”通常指超过一半(>50%),具体数值依场景而定。例如:
- 若矩阵为 $5 \times 7 = 35 ) 维,26个零元素(占比74%)显然属于“绝大部分”;
- 若矩阵规模更大(如 $100 \times 100 )),即使零元素占比60%,也可能被视为稀疏矩阵。
3. 关键补充
- 无严格统一标准:稀疏矩阵的判定是相对的,取决于应用场景。例如:
- 在科学计算中,零元素占比超过70%常被视为稀疏矩阵;
- 在某些极端场景(如社交网络邻接矩阵),零元素占比可能高达99%以上。
- 稀疏度与密度的计算:
- 稀疏度 = 零元素数量 总元素数量 × 100 % \frac{\text{零元素数量}}{\text{总元素数量}} \times 100\% 总元素数量零元素数量×100%
- 密度 = 1 − 稀疏度 = 非零元素数量 总元素数量 × 100 % 1 - \text{稀疏度} = \frac{\text{非零元素数量}}{\text{总元素数量}} \times 100\% 1−稀疏度=总元素数量非零元素数量×100%
稀疏度需结合矩阵规模和领域需求综合判断
1. 标准注意力的瓶颈
标准自注意力计算 Attention ( Q , K , V ) = softmax ( Q K T d ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V Attention(Q,K,V)=softmax(dQKT)V 需要计算所有 token 对之间的注意力得分,当序列长度 N N N 增大时,计算和内存开销呈平方级增长。
2. SparseAttention 的核心思路
通过设计稀疏模式(Sparsity Pattern),选择性地忽略部分 token 对之间的注意力计算:
- 固定模式:如局部窗口(只关注当前 token 前后的 k 个 token)、扩张窗口(类似 CNN 中的空洞卷积)。
- 动态模式:根据输入内容动态选择重要的 token 对(如基于注意力得分阈值)。
- 结构化稀疏:利用矩阵分解或低秩近似减少计算量。
3. 数学表达
标准注意力的得分矩阵 S = Q K T S = QK^T S=QKT 是稠密的,而稀疏注意力只计算其中非零元素:
S i , j = { q i T k j if ( i , j ) ∈ 稀疏模式 0 otherwise S_{i,j} = \begin{cases} q_i^T k_j & \text{if } (i,j) \in \text{稀疏模式} \\ 0 & \text{otherwise} \end{cases} Si,j={qiTkj0if (i,j)∈稀疏模式otherwise
其中,稀疏模式可通过掩码矩阵 M M M 表示:
S sparse = S ⊙ M S_{\text{sparse}} = S \odot M Ssparse=S⊙M
( ⊙ \odot ⊙ 表示逐元素乘法)。
4. 典型稀疏模式
模式 | 描述 | 复杂度 |
---|---|---|
局部窗口 | 每个 token 只关注前后固定窗口内的 token(如窗口大小=512)。 | O(N·k) |
扩张窗口 | 窗口内每隔一定步长采样 token,扩大感受野(类似 CNN 中的空洞卷积)。 | O(N·k/d) |
随机稀疏 | 随机选择固定比例的 token 对进行计算。 | O(N²·p) |
基于内容的稀疏 | 根据注意力得分动态选择 top-k 个连接(如 Longformer 的 global attention)。 | O(N·logN) |
FlashInfer 对 SparseAttention 的实现
FlashInfer 并未直接实现完整的 SparseAttention 算法(如 Longformer 或 BigBird),而是通过以下方式支持稀疏模式:
1. 与第三方库集成
FlashInfer 可与现有 SparseAttention 库(如 xformers
、torch-sparse
)结合使用,提供高效的稀疏矩阵计算内核:
import torch
import flashinfer
import xformers.ops as xops# 使用 xformers 的稀疏注意力 + FlashInfer 的 KV 缓存
def sparse_attention_with_flashinfer(q, k_cache, v_cache, mask):# q: [batch, seq_len, heads, dim]# k_cache, v_cache: FlashInfer 的 KV 缓存# 从 FlashInfer 缓存中获取 K, Vk, v = flashinfer.get_kv_from_cache(k_cache, v_cache, seq_lens)# 使用 xformers 的稀疏注意力计算attn_output = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask() if causal else None,p=0.0 # dropout)return attn_output
2. 优化稀疏矩阵计算
FlashInfer 通过以下方式加速稀疏注意力的计算:
- GPU 内核优化:针对稀疏矩阵乘法(SpMM)生成专用 CUDA 内核,利用 GPU 并行计算能力。
- 内存访问优化:将稀疏模式预编码为高效的数据结构(如 CSR/CSC 格式),减少内存碎片。
- 与 FlashAttention 协同:对密集区域使用 FlashAttention 加速,稀疏区域使用优化的 SpMM。
3. 支持自定义稀疏模式
FlashInfer 允许用户通过 API 传入自定义掩码矩阵,灵活定义稀疏模式:
# 创建自定义稀疏掩码(示例:仅关注前 100 个 token)
sparse_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
sparse_mask[:, :100] = True # 每行仅关注前 100 列# 使用 FlashInfer 执行带掩码的注意力
output = flashinfer.attention_with_mask(q=query,k=key,v=value,mask=sparse_mask
)
4. 长序列优化
对于超长序列(如 16K+ tokens),FlashInfer 结合 分页 KV 缓存(PageAttention) 和稀疏模式:
- 将 KV 缓存按页存储,仅激活当前需要的页(类似内存分页机制)。
- 对活跃页内的 token 应用密集注意力,页间使用稀疏连接,大幅减少计算量。
性能对比
方法 | 计算复杂度 | 内存占用 | 适用场景 |
---|---|---|---|
标准注意力 | O(N²) | O(N²) | 短序列(N<2K) |
FlashAttention | O(N²) | O(N) | 中等序列(N<8K) |
SparseAttention + Flash | O(N·k) | O(N·k) | 长序列(N>10K) |