当前位置: 首页 > news >正文

FlashInfer - SparseAttention(稀疏注意力)只计算部分有意义的注意力连接,而非全部 token 对

FlashInfer - SparseAttention(稀疏注意力)只计算部分有意义的注意力连接,而非全部 token 对

flyfish

SparseAttention 原理

SparseAttention(稀疏注意力)是针对标准自注意力机制的优化,核心目标是将 O(N²) 的计算复杂度降低到 O(N·k)(k 为稀疏度),特别适用于长序列(如 N>10k)。其核心思想是:只计算部分有意义的注意力连接,而非全部 token 对

在算法复杂度分析中,线性(Linear)亚线性(Sublinear) 是描述算法效率随输入规模增长的术语,与平方级(Quadratic)增长形成对比。

1. 核心定义

线性复杂度(O(N))
  • 定义:算法的计算量或内存开销与输入规模 N N N正比例关系
  • 特点:当 N N N 翻倍时,计算量/内存也翻倍。
  • 示例
    • 遍历数组求和:每个元素仅访问一次,总操作数为 N N N
    • 稀疏注意力的局部窗口模式(每个token仅关注前后 k k k 个token):总操作数为 N ⋅ k N \cdot k Nk,当 k k k 固定时,复杂度为 O ( N ) O(N) O(N)
亚线性复杂度(O(N^α),其中 α < 1)
  • 定义:算法的计算量或内存开销增长慢于线性
  • 特点:当 N N N 翻倍时,计算量/内存增长小于两倍。
  • 常见形式
    • O ( N ) O(\sqrt{N}) O(N ):如分块算法。
    • O ( N log ⁡ N ) O(N \log N) O(NlogN):如快速排序、某些稀疏注意力的动态选择策略。

2. 与平方级复杂度(O(N²))的对比

复杂度含义增长速度(N增大时)适用场景
O(N²)平方级:计算量与 N 2 N^2 N2 成正比极快(N翻倍时计算量×4)短序列(N<1000)
O(N)线性:计算量与 N N N 成正比中等(N翻倍时计算量×2)长序列(N>10000)
O(N^α)亚线性:α<1(如O(√N)、O(N logN))最慢(N翻倍时增长<2倍)超长序列(N>1M)或特殊场景

3. 稀疏注意力中的线性/亚线性优化

(1)局部窗口稀疏(Linear)
  • 模式:每个token仅关注前后 k k k 个token(如 k = 512 k=512 k=512)。
  • 复杂度:总操作数 N ⋅ k N \cdot k Nk,当 k k k 固定时为 O ( N ) O(N) O(N)
    计算量仅与序列长度 N N N 线性相关,而非 N 2 N^2 N2
(2)动态Top-K选择(Sublinear)
  • 模式:每个token仅关注注意力得分最高的 K K K 个token(如 K = 256 K=256 K=256)。
  • 复杂度:若使用高效选择算法(如堆排序),总操作数为 O ( N log ⁡ K ) ≈ O ( N log ⁡ N ) O(N \log K) \approx O(N \log N) O(NlogK)O(NlogN)(亚线性)。
    仅计算关键连接,避免冗余计算。

4. FlashInfer中的线性优化

FlashInfer通过以下方式实现线性或近似线性复杂度:

  1. 分页KV缓存(PageAttention)

    • 将KV缓存分成固定大小的页(如每页1024个token),仅激活当前相关的页。
    • 计算复杂度从 O ( N 2 ) O(N^2) O(N2) 降至 O ( N ⋅ p ) O(N \cdot p) O(Np) p p p 为活跃页数,通常 p ≪ N p \ll N pN)。
  2. 与稀疏注意力库集成

    • 结合 xformers 等库,支持局部窗口、扩张窗口等稀疏模式,直接将计算量降至 O ( N ) O(N) O(N)
  3. 内存访问优化

    • 通过连续内存布局和批量操作,减少每步计算的开销,进一步提升线性算法的效率。

5. 为什么线性/亚线性很重要?

在长序列场景(如 N = 100 K N=100K N=100K)中,平方级算法(如标准注意力)的计算和内存需求会爆炸式增长,而线性/亚线性算法仍能保持可扩展性:

序列长度 N N NO(N²) 计算量O(N) 计算量加速比
1,0001,000,0001,0001,000x
10,000100,000,00010,00010,000x
100,00010,000,000,000100,000100,000x

可见,当 N N N 增大时,线性算法的优势愈发明显。这也是 FlashInfer 等框架支持稀疏注意力的核心原因——突破平方级瓶颈,让模型处理超长文本成为可能

计算复杂度和内存开销呈平方级增长

在自注意力机制(Self-Attention)中,当序列长度 N N N 增大时,计算复杂度内存开销呈平方级增长( O ( N 2 ) O(N^2) O(N2)),根本原因在于 注意力矩阵的规模是 N × N N \times N N×N

一、自注意力的核心计算步骤

自注意力的计算分为三步(以单头注意力为例):

  1. 计算查询-键矩阵(QK矩阵)
    Score = Q K T 其中 Q , K , V ∈ R N × d \text{Score} = QK^T \quad \text{其中} \quad Q, K, V \in \mathbb{R}^{N \times d} Score=QKT其中Q,K,VRN×d
    Q Q Q(查询)和 K K K(键)的形状均为 N × d N \times d N×d d d d 是特征维度),它们的矩阵乘积得到 注意力分数矩阵 Score ∈ R N × N \text{Score} \in \mathbb{R}^{N \times N} ScoreRN×N

    • 计算复杂度:每个元素的计算需要 d d d 次乘法,总共有 N 2 N^2 N2 个元素,复杂度为 O ( N 2 d ) O(N^2 d) O(N2d)
    • 内存开销:存储 Score \text{Score} Score 需要 N 2 N^2 N2 个浮点数(如FP32时占 4 N 2 4N^2 4N2 字节)。
  2. Softmax归一化和值加权(QKV加权)
    Attention = Softmax ( Score ) ⋅ V \text{Attention} = \text{Softmax}(\text{Score}) \cdot V Attention=Softmax(Score)V

    • N × N N \times N N×N 的矩阵做Softmax,复杂度为 O ( N 2 ) O(N^2) O(N2)
    • 矩阵与 V ∈ R N × d V \in \mathbb{R}^{N \times d} VRN×d 相乘,得到 N × d N \times d N×d 的输出,复杂度为 O ( N 2 d ) O(N^2 d) O(N2d)
  3. 多头注意力扩展
    若有 h h h 个头,每个头独立计算,总复杂度变为 O ( h N 2 d ) O(hN^2 d) O(hN2d),但 h h h d d d 通常是固定维度(如 h = 16 , d = 64 h=16, d=64 h=16,d=64),因此主导项仍是 O ( N 2 ) O(N^2) O(N2)

二、平方级增长的本质原因

1. 计算复杂度的平方级来源
  • 核心操作是矩阵乘法 Q K T QK^T QKT
    两个 N × d N \times d N×d 矩阵相乘得到 N × N N \times N N×N 矩阵,计算量为 N × d × N = N 2 d N \times d \times N = N^2 d N×d×N=N2d
    N N N 增大时, N 2 N^2 N2 是主导项(即使 d d d 固定,如 d = 1024 d=1024 d=1024 N = 4096 N=4096 N=4096 N 2 = 16 , 777 , 216 N^2 = 16,777,216 N2=16,777,216,远大于 N N N)。

  • 每对token都需计算一次关联
    自注意力要求每个token(共 N N N 个)与所有其他token(包括自身,共 N N N 个)计算注意力分数,总共有 N × N = N 2 N \times N = N^2 N×N=N2 次两两交互。

2. 内存开销的平方级来源
  • 存储注意力分数矩阵( N × N N \times N N×N):
    假设使用FP32(4字节/数),存储该矩阵需要 4 N 2 4N^2 4N2 字节。

    • N = 1024 N=1024 N=1024 时,需要约4MB;
    • N = 16 K N=16K N=16K 时,需要约1GB;
    • N = 32 K N=32K N=32K 时,需要约4GB(仅单个注意力头)。
      这还不包括中间变量(如Softmax结果、梯度等),实际内存占用会更高。
  • 隐含的内存放大效应
    矩阵运算(如PyTorch/TensorFlow的底层实现)需要额外的临时存储空间,进一步加剧内存压力。

三、举例说明

假设 d = 1024 d=1024 d=1024(固定维度),比较不同 N N N 时的计算量和内存:

序列长度 N N N注意力矩阵大小计算量(乘加操作数)内存占用(FP32,单头)
10241024×1024~1e9次~4MB
20482048×2048~4e9次(×4)~16MB(×4)
40964096×4096~16e9次(×16)~64MB(×16)
16K16K×16K~262e9次(×256)~1GB(×256)
32K32K×32K~1e12次(×1024)~4GB(×1024)

可见, N N N 翻倍时,计算量和内存占用均变为原来的4倍(平方级增长)

四、为什么稀疏注意力能缓解这个问题?

稀疏注意力(如SparseAttention、FlashInfer支持的模式)通过 减少有效交互的token对数量(从 N 2 N^2 N2 降至 O ( N ⋅ r ) O(N \cdot r) O(Nr) r r r 是每个token关注的平均token数),将复杂度优化为线性或亚线性(如 O ( N r ) O(Nr) O(Nr) r ≪ N r \ll N rN)。例如:

  • 局部窗口稀疏:每个token仅关注附近 r = 512 r=512 r=512 个token,总交互数为 N ⋅ 512 = O ( N ) N \cdot 512 = O(N) N512=O(N)
  • Top-K稀疏:每个token仅关注得分最高的 K = 256 K=256 K=256 个token,总交互数为 N ⋅ K = O ( N K ) N \cdot K = O(NK) NK=O(NK)

有意义的注意力连接

在注意力机制中,“有意义的注意力连接” 指的是 对当前任务或输入内容有实际影响的 token 对
“token 对”(Token Pair) 指的是 序列中任意两个 token 之间的交互关系。
“有意义的注意力连接”是一个与语义、句法、任务目标相关的概念,不同模型或场景可能采用不同的标准。在实际实现中,通常通过稀疏模式设计(如局部窗口、动态采样)或注意力得分过滤(如 Top-K、Top-P)来近似定义和计算这些有意义的连接,从而在保持模型性能的同时降低计算复杂度。

1. 基于语义关联的“有意义”

  • 定义:两个 token 在语义上存在强关联,关注它们的关系能帮助模型更好地理解上下文。
  • 示例
    • 输入句子:“济南是山东的省会,也是一座历史悠久的城市。”
    • 注意力机制应关注“济南”与“城市”之间的连接,因为它们在语义上高度相关。
    • 而“济南”与句首的“的”或句尾的句号之间的连接则可能被视为“无意义”。

2. 基于句法结构的“有意义”

  • 定义:两个 token 在句法上存在依赖关系(如主语-谓语、动词-宾语)。
  • 示例
    • 句子:“我喜欢苹果。”
    • “喜欢”与“苹果”之间的连接是有意义的,因为它们构成动宾关系。
    • 而“我”与“苹果”之间的直接连接可能较弱(需通过“喜欢”间接关联)。

3. 基于任务目标的“有意义”

  • 定义:根据具体任务,某些 token 对的关系对预测结果更重要。
  • 示例
    • 机器翻译:源语言与目标语言中对应的词对(如“apple”→“苹果”)。
    • 问答系统:问题中的关键词与上下文中的答案候选。
    • 命名实体识别:实体词与描述其属性的词(如“微软”与“公司”)。

4. 基于注意力得分的“有意义”

  • 定义:通过模型计算出的注意力得分(如 softmax 后的权重)筛选出重要连接。
  • 具体方法
    • Top-K 选择:保留得分最高的 K 个连接。
    • 阈值过滤:丢弃得分低于阈值的连接(如仅保留得分 > 0.1 的连接)。
    • 动态稀疏化:根据得分分布自适应确定稀疏度(如 Top-P 采样)。

5. 常见稀疏模式中的“有意义”设计

稀疏模式如何定义“有意义”
局部窗口认为当前 token 附近的 token 更相关(如前后 512 个 token)。
扩张窗口保留局部相关性的同时,通过跳跃式采样捕获长距离依赖(类似 CNN 中的空洞卷积)。
基于内容的稀疏使用额外的网络或启发式规则动态识别相关 token(如 Longformer 的 global attention)。
块稀疏将序列分块,块内全连接,块间仅保留关键连接(如 BigBird 的 block sparse)。

6. FlashInfer 中的“有意义”实现

FlashInfer 本身不直接定义“有意义”,而是通过以下方式支持用户自定义或第三方库定义的稀疏模式:

  1. 灵活的掩码接口:允许用户传入自定义掩码矩阵,显式指定哪些连接需要计算。
    # 示例:仅关注前 100 个 token
    mask = torch.zeros(seq_len, seq_len)
    mask[:, :100] = 1  # 每行仅关注前 100 列
    output = flashinfer.attention_with_mask(q, k, v, mask)
    
  2. 与第三方稀疏注意力库集成:如 xformerstorch-sparse,利用它们的算法识别有意义的连接。
  3. 长序列优化:通过分页 KV 缓存(PageAttention),仅激活当前可能相关的 token 页。

SparseAttention 原理

SparseAttention(稀疏注意力)是针对标准自注意力机制的优化,核心目标是将 O(N²) 的计算复杂度降低到 O(N·k)(k 为稀疏度),特别适用于长序列(如 N>10k)。其核心思想是:只计算部分有意义的注意力连接,而非全部 token 对

稀疏矩阵的定义与判定

1. 定义与核心特征

稀疏矩阵(Sparse Matrix)是指矩阵中绝大多数元素为零,仅有少量非零元素的矩阵。其核心特征是:

  • 零元素占比极高,非零元素占比极低。
  • 没有严格的数学阈值(如“必须超过X%的零元素”),但实际应用中,通常认为零元素占比超过 50% 即可视为稀疏矩阵,具体标准因领域而异(如数值计算、机器学习中可能要求更高的稀疏度,如70%以上)。
2. 分析

矩阵:

  • 非零元素:9个

  • 零元素:26个

  • 总元素:9 + 26 = 35个

  • 稀疏度(零元素占比): 26 35 ≈ 74.29 % \frac{26}{35} \approx 74.29\% 352674.29%

  • 密度(非零元素占比): 9 35 ≈ 25.71 % \frac{9}{35} \approx 25.71\% 35925.71%

  • 属于稀疏矩阵。74%的零元素占比远超过“大部分”的基本要求(超过50%),符合稀疏矩阵的典型特征。

  • “绝大部分元素为零”中的“绝大部分”通常指超过一半(>50%),具体数值依场景而定。例如:

    • 若矩阵为 $5 \times 7 = 35 ) 维,26个零元素(占比74%)显然属于“绝大部分”;
    • 若矩阵规模更大(如 $100 \times 100 )),即使零元素占比60%,也可能被视为稀疏矩阵。
3. 关键补充
  • 无严格统一标准:稀疏矩阵的判定是相对的,取决于应用场景。例如:
    • 在科学计算中,零元素占比超过70%常被视为稀疏矩阵;
    • 在某些极端场景(如社交网络邻接矩阵),零元素占比可能高达99%以上。
  • 稀疏度与密度的计算
    • 稀疏度 = 零元素数量 总元素数量 × 100 % \frac{\text{零元素数量}}{\text{总元素数量}} \times 100\% 总元素数量零元素数量×100%
    • 密度 = 1 − 稀疏度 = 非零元素数量 总元素数量 × 100 % 1 - \text{稀疏度} = \frac{\text{非零元素数量}}{\text{总元素数量}} \times 100\% 1稀疏度=总元素数量非零元素数量×100%

稀疏度需结合矩阵规模和领域需求综合判断

1. 标准注意力的瓶颈

标准自注意力计算 Attention ( Q , K , V ) = softmax ( Q K T d ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V Attention(Q,K,V)=softmax(d QKT)V 需要计算所有 token 对之间的注意力得分,当序列长度 N N N 增大时,计算和内存开销呈平方级增长。

2. SparseAttention 的核心思路

通过设计稀疏模式(Sparsity Pattern),选择性地忽略部分 token 对之间的注意力计算:

  • 固定模式:如局部窗口(只关注当前 token 前后的 k 个 token)、扩张窗口(类似 CNN 中的空洞卷积)。
  • 动态模式:根据输入内容动态选择重要的 token 对(如基于注意力得分阈值)。
  • 结构化稀疏:利用矩阵分解或低秩近似减少计算量。
3. 数学表达

标准注意力的得分矩阵 S = Q K T S = QK^T S=QKT 是稠密的,而稀疏注意力只计算其中非零元素:
S i , j = { q i T k j if  ( i , j ) ∈ 稀疏模式 0 otherwise S_{i,j} = \begin{cases} q_i^T k_j & \text{if } (i,j) \in \text{稀疏模式} \\ 0 & \text{otherwise} \end{cases} Si,j={qiTkj0if (i,j)稀疏模式otherwise
其中,稀疏模式可通过掩码矩阵 M M M 表示:
S sparse = S ⊙ M S_{\text{sparse}} = S \odot M Ssparse=SM
⊙ \odot 表示逐元素乘法)。

4. 典型稀疏模式

模式描述复杂度
局部窗口每个 token 只关注前后固定窗口内的 token(如窗口大小=512)。O(N·k)
扩张窗口窗口内每隔一定步长采样 token,扩大感受野(类似 CNN 中的空洞卷积)。O(N·k/d)
随机稀疏随机选择固定比例的 token 对进行计算。O(N²·p)
基于内容的稀疏根据注意力得分动态选择 top-k 个连接(如 Longformer 的 global attention)。O(N·logN)

FlashInfer 对 SparseAttention 的实现

FlashInfer 并未直接实现完整的 SparseAttention 算法(如 Longformer 或 BigBird),而是通过以下方式支持稀疏模式

1. 与第三方库集成

FlashInfer 可与现有 SparseAttention 库(如 xformerstorch-sparse)结合使用,提供高效的稀疏矩阵计算内核:

import torch
import flashinfer
import xformers.ops as xops# 使用 xformers 的稀疏注意力 + FlashInfer 的 KV 缓存
def sparse_attention_with_flashinfer(q, k_cache, v_cache, mask):# q: [batch, seq_len, heads, dim]# k_cache, v_cache: FlashInfer 的 KV 缓存# 从 FlashInfer 缓存中获取 K, Vk, v = flashinfer.get_kv_from_cache(k_cache, v_cache, seq_lens)# 使用 xformers 的稀疏注意力计算attn_output = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask() if causal else None,p=0.0  # dropout)return attn_output
2. 优化稀疏矩阵计算

FlashInfer 通过以下方式加速稀疏注意力的计算:

  • GPU 内核优化:针对稀疏矩阵乘法(SpMM)生成专用 CUDA 内核,利用 GPU 并行计算能力。
  • 内存访问优化:将稀疏模式预编码为高效的数据结构(如 CSR/CSC 格式),减少内存碎片。
  • 与 FlashAttention 协同:对密集区域使用 FlashAttention 加速,稀疏区域使用优化的 SpMM。
3. 支持自定义稀疏模式

FlashInfer 允许用户通过 API 传入自定义掩码矩阵,灵活定义稀疏模式:

# 创建自定义稀疏掩码(示例:仅关注前 100 个 token)
sparse_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
sparse_mask[:, :100] = True  # 每行仅关注前 100 列# 使用 FlashInfer 执行带掩码的注意力
output = flashinfer.attention_with_mask(q=query,k=key,v=value,mask=sparse_mask
)
4. 长序列优化

对于超长序列(如 16K+ tokens),FlashInfer 结合 分页 KV 缓存(PageAttention) 和稀疏模式:

  • 将 KV 缓存按页存储,仅激活当前需要的页(类似内存分页机制)。
  • 对活跃页内的 token 应用密集注意力,页间使用稀疏连接,大幅减少计算量。

性能对比

方法计算复杂度内存占用适用场景
标准注意力O(N²)O(N²)短序列(N<2K)
FlashAttentionO(N²)O(N)中等序列(N<8K)
SparseAttention + FlashO(N·k)O(N·k)长序列(N>10K)

相关文章:

  • 文件(文件夹时间戳修改)最后修改时间变更
  • python打卡day25@浙大疏锦行
  • promise的说明
  • Minimum MPDU Start Spacing in A-MPDU
  • Spring Cloud:构建云原生微服务架构的最佳工具和实践
  • WhaleTunnel 信创数据库适配能力全景图:打通国产数据生态的最后一公里
  • 【Linux】shell内置命令fg,bg和jobs
  • 缺乏自动化测试,如何提高测试效率
  • 剖析提示词工程中的递归提示
  • Dockerfile实战:从零构建自定义CentOS镜像
  • UOS专业版上通过源码安装 Python 3.13 并保留系统默认版本
  • 关于并发编程AQS的学习
  • Python 之 Flask 入门学习
  • 计算机图形学之几何(Geometry)
  • Spring 事件监听机制的使用
  • Spring 中的 @Configuration @Bean注解
  • UE5 像素推流
  • 在UI 原型设计中,交互规则有哪些核心要素?
  • 数值积分知识
  • 【嵌入模型与向量数据库】
  • 俄方代表团抵达土耳其,俄乌直接谈判有望于当地时间上午重启
  • 知名猎头公司创始人兼首席执行官庄华因突发疾病逝世,享年62岁
  • 日本前卫艺术先驱群展上海:当具体派相遇古树古宅
  • 第十届影像上海博览会落幕后,留给中国摄影收藏的三个问题
  • 秦洪看盘|预期改善,或迎来新的增量资金
  • 普京提议无条件重启俄乌谈判,外交部:我们支持一切致力于和平的努力