DeepSeek Sparse Attention(DSA)快速洞察(DeepSeek-V3.2)
DeepSeek-V3.2 技术报告中提到的 DeepSeek Sparse Attention (DSA) 是一种旨在提升上下文处理效率的稀疏注意力机制。
一、核心思想
传统 Transformer 模型的自注意力机制具有 O(L2)O(L^2)O(L2) 的计算复杂度(其中 LLL 是序列长度),这在处理长文本(如 128K tokens)时计算成本非常高。DSA的核心思想是:对于序列中的每一个 query token,并不需要关注所有之前的上文 token,而只需关注其中最关键的一小部分(k个)。通过这种方式,将核心注意力计算复杂度从 O(L2)O(L^2)O(L2) 降低到了 O(Lk)O(Lk)O(Lk)。
二、Multi-Head Latent Attention
1. Multi-Head Attention
在标准多头注意力(MHA)中,通过投影矩阵 WQ,WK,WV∈Rdhnh×dW^Q,W^K,W^V\in\R^{d_hn_h\times d}WQ,WK,WV∈Rdhnh×d 将输入的第 ttt 个 token ht∈Rd\mathbf{h}_t\in\R^dht∈Rd 转化为 query,key 和 value(qt=WQht,kt=WKht,vt=WVht,qt,kt,vt∈Rdhnh\mathbf{q}_t=W^Q\mathbf{h}_t,\mathbf{k}_t=W^K\mathbf{h}_t,\mathbf{v}_t=W^V\mathbf{h}_t,\mathbf{q}_t,\mathbf{k}_t,\mathbf{v}_t\in \R^{d_hn_h}qt=WQht,kt=WKht,vt=WVht,qt,kt,vt∈Rdhnh)。其中 ddd 是输入 token 的嵌入维度,维度 dhnhd_hn_hdhnh 表示在多头注意力机制中,qt,kt,vt\mathbf{q}_t,\mathbf{k}_t,\mathbf{v}_tqt,kt,vt 如何被切分为 nhn_hnh 个头,每个头的维度为 dhd_hdh。
[qt,1;qt,2;⋯ ;qt,nh]=qt,[kt,1;kt,2;⋯ ;kt,nh]=kt,[vt,1;vt,2;⋯ ;vt,nh]=vt,ot,i=∑j=1tSoftmaxj(qt,iTkj,idh)vj,i,u=WO[ot,1;ot,2;⋯ ;ot,nh]
\begin{align}
[\mathbf{q}_{t,1};\mathbf{q}_{t,2};\cdots;\mathbf{q}_{t,n_h}]=\mathbf{q}_t,\\
[\mathbf{k}_{t,1};\mathbf{k}_{t,2};\cdots;\mathbf{k}_{t,n_h}]=\mathbf{k}_t,\\
[\mathbf{v}_{t,1};\mathbf{v}_{t,2};\cdots;\mathbf{v}_{t,n_h}]=\mathbf{v}_t,\\
\mathbf{o}_{t,i}=\sum_{j=1}^t\text{Softmax}_j(\frac{\mathbf{q}_{t,i}^T\mathbf{k}_{j,i}}{\sqrt{d_h}})\mathbf{v}_{j,i},\\
\mathbf{u}=W^O[\mathbf{o}_{t,1};\mathbf{o}_{t,2};\cdots;\mathbf{o}_{t,n_h}]
\end{align}
[qt,1;qt,2;⋯;qt,nh]=qt,[kt,1;kt,2;⋯;kt,nh]=kt,[vt,1;vt,2;⋯;vt,nh]=vt,ot,i=j=1∑tSoftmaxj(dhqt,iTkj,i)vj,i,u=WO[ot,1;ot,2;⋯;ot,nh]
其中,qt,i,kt,i,vt,i∈Rdh\mathbf{q}_{t,i},\mathbf{k}_{t,i},\mathbf{v}_{t,i}\in\R^{d_h}qt,i,kt,i,vt,i∈Rdh 表示第 iii 个头的 QKV 值,WO∈Rd×dhnhW^O\in\R^{d\times d_hn_h}WO∈Rd×dhnh 是输出投影矩阵。在推理过程中,每一个 token 需要 KV cache 的大小为 2nhdhl2n_hd_hl2nhdhl,lll 注意力层数。
KV cache 是 Transformer 模型中 MHA 采用的一种推理加速技术,通过存储中间键值对避免重复计算。注意图中展示的多个头(query)都来源于同一个 token, 每一个头都对应有自己的 WiQ,WiK,WiVW_i^Q,W_i^K,W_i^VWiQ,WiK,WiV。MHA 会带来高昂的内存开销,成为系统瓶颈,减少 KV 的常见思路是将不同的查询对应相同的 KV。但这类方法在性能上始终与标准多头注意力存在差距。后来在 DeepSeek-V2 模型中提出的 MLA 实现了突破,该创新方案在显著减少 KV 缓存需求的同时,反而获得了更优越的性能表现。
2. Low-Rank Key-Value Joint Compression
MLA 的核心思想是将投影矩阵分解为两个低秩矩阵:W=WUWDKVW=W^UW^{DKV}W=WUWDKV,其中 WDKV∈Rdc×dW^{DKV}\in\R^{d_c\times d}WDKV∈Rdc×d 是键和值的下投影矩阵(down-projection matrix),WU∈Rdhnh×dcW^U\in\R^{d_hn_h\times d_c}WU∈Rdhnh×dc 是上投影矩阵(up-projection matrix),且 dc≪dhnhd_c\ll d_hn_hdc≪dhnh。下投影矩阵将键和值压缩进一个隐向量 ctKV=WDKVht,ctKV∈Rdc\mathbf{c}_t^{KV}=W^{DKV}\mathbf{h}_t,\mathbf{c}_t^{KV}\in\R^{d_c}ctKV=WDKVht,ctKV∈Rdc。因为对每一个 token 而言只需要储存 ctKV\mathbf{c}_t^{KV}ctKV 而不用储存 kt\mathbf{k}_tkt 和 vt\mathbf{v}_tvt,所以需要的内存开销从 2nhdhl2n_hd_hl2nhdhl 减少到 dcld_cldcl。而键和值通过隐向量 ctKV\mathbf{c}_t^{KV}ctKV 计算得到:
ktC=WUKctKV,vtC=WUVctKV,
\mathbf{k}_t^C=W^{UK}\mathbf{c}_t^{KV},\\
\mathbf{v}_t^C=W^{UV}\mathbf{c}_t^{KV},
ktC=WUKctKV,vtC=WUVctKV,
其中,WUK,WUV∈Rdhnh×dcW^{UK},W^{UV}\in\R^{d_hn_h\times d_c}WUK,WUV∈Rdhnh×dc 分别表示键和值的上投影矩阵。关键之处在于,在推理过程中 WUKW^{UK}WUK 被吸收进 WQW^QWQ,而 WUVW^{UV}WUV 被吸收进 WOW^OWO,因此我们无需显示地计算 ktC,vtC\mathbf{k}_t^C,\mathbf{v}_t^CktC,vtC。这里吸收的意思是:在推理前,预先将多个投影矩阵相乘合并,从而避免推理过程中显示生成中间变量,但这并不意味着训练时候只训练一个矩阵,这些投影矩阵在训练时依旧是分开训练的。
qiTkj=(WQhi)TWUKcjkv=hiT((WQ)TWUK)cjkv=hiTAcjkv
\begin{align}
\mathbf{q}_i^T\mathbf{k}_j&=(W^Q\mathbf{h}_i)^TW^{UK}\mathbf{c}_j^{kv}\\
&=\mathbf{h}_i^T((W^Q)^TW^{UK})\mathbf{c}_j^{kv}\\
&=\mathbf{h}_i^TA\mathbf{c}_j^{kv}
\end{align}
qiTkj=(WQhi)TWUKcjkv=hiT((WQ)TWUK)cjkv=hiTAcjkv
此外,在训练过程中还采用了查询的低秩压缩技术,以降低激活内存的使用。
ctQ=WDQht,qtC=WUQctQ.
\mathbf{c}_t^Q=W^{DQ}\mathbf{h}_t,\\
\mathbf{q}_t^C=W^{UQ}\mathbf{c}_t^Q.
ctQ=WDQht,qtC=WUQctQ.
3. 解耦 RoPE
DeepSeek-V2 采用旋转位置编码(RoPE):
qiTkj=(WQhi)TRoPEΘ,j−i(WKhj)=hiT(WQ)TRoPEΘ,j−i(WUKWDKVhj)
\begin{align}
\mathbf{q}_i^T\mathbf{k}_j&=(W^Q\mathbf{h}_i)^T\text{RoPE}_{\Theta,j-i}(W^K\mathbf{h}_j)\\
&=\mathbf{h}_i^T(W^Q)^T\text{RoPE}_{\Theta,j-i}(W^{UK}W^{DKV}\mathbf{h}_j)
\end{align}
qiTkj=(WQhi)TRoPEΘ,j−i(WKhj)=hiT(WQ)TRoPEΘ,j−i(WUKWDKVhj)
其中,RoPEΘ,j−i(⋅)\text{RoPE}_{\Theta,j-i}(\cdot)RoPEΘ,j−i(⋅) 表示应用 RoPE\text{RoPE}RoPE 的运算操作,Θ\ThetaΘ 是预定义的参数,而 i,ji,ji,j 分别表示第 iii 和第 jjj 个位置。它作用在 WUKWDKVhjW^{UK}W^{DKV}\mathbf{h}_jWUKWDKVhj 这个整体上,且依赖于 i,ji,ji,j,因此,WUKW^{UK}WUK 不会被吸收进 WQW^QWQ 中,这会导致推理过程中的计算成本显著增加。
为了解决这个问题,DeepSeek-V2 提出将 RoPE 解耦为一组独立的 query 和 key:多头查询 qt,iR∈RdhR\mathbf{q}_{t,i}^R\in\R^{d_h^R}qt,iR∈RdhR 和所有头共享的键 ktR∈RdhR\mathbf{k}_t^R\in\R^{d_h^R}ktR∈RdhR,其中 dhRd_h^RdhR 表示解耦后查询与键的每头维度。这种解耦策略本质上会计算两组独立的注意力权重,随后将其相加。完整的多头潜在注意力(MLA)计算流程如下:
[qt,1R;qt,2R;⋯ ;qt,nhR]=qtR=RoPE(WQRctQ),ktR=RoPE(WKRht),qt,i=[qt,iC;qt,iR],kt,i=[kt,iC;ktR],ot,i=∑j=1tSoftmaxj(qt,iTkj,idh+dhR)vj,iC,ut=WO[ot,1;ot,2;⋯ ;ot,nh],
\begin{align}
[\mathbf{q}_{t,1}^R;\mathbf{q}_{t,2}^R;\cdots;\mathbf{q}_{t,n_h}^R]=\mathbf{q}_t^R&=\text{RoPE}(W^{QR}\mathbf{c}_t^Q),\\
\mathbf{k}_t^R&=\text{RoPE}(W^{KR}\mathbf{h}_t),\\
\mathbf{q}_{t,i}&=[\mathbf{q}_{t,i}^C;\mathbf{q}_{t,i}^R],\\
\mathbf{k}_{t,i}&=[\mathbf{k}_{t,i}^C;\mathbf{k}_t^R],\\
\mathbf{o}_{t,i}&=\sum_{j=1}^t\text{Softmax}_j(\frac{\mathbf{q}_{t,i}^T\mathbf{k}_{j,i}}{\sqrt{d_h+d_h^R}})\mathbf{v}_{j,i}^C,\\
\mathbf{u}_t&=W^O[\mathbf{o}_{t,1};\mathbf{o}_{t,2};\cdots;\mathbf{o}_{t,n_h}],
\end{align}
[qt,1R;qt,2R;⋯;qt,nhR]=qtRktRqt,ikt,iot,iut=RoPE(WQRctQ),=RoPE(WKRht),=[qt,iC;qt,iR],=[kt,iC;ktR],=j=1∑tSoftmaxj(dh+dhRqt,iTkj,i)vj,iC,=WO[ot,1;ot,2;⋯;ot,nh],
其中 WQR∈RdhRnh×dc′,WKR∈RdhR×dW^{QR}\in\R^{d_h^Rn_h\times d_c'},W^{KR}\in\R^{d_h^R\times d}WQR∈RdhRnh×dc′,WKR∈RdhR×d 分别表示用于生成解耦查询和键的矩阵,[⋅;⋅][\cdot;\cdot][⋅;⋅] 表示矩阵拼接操作。在推理过程中,维度为 dhRd_h^RdhR 的解耦键 ktR\mathbf{k}_t^RktR 也会被缓存。因此,每个 token 总共需要缓存大小 (dc+dhR)l(d_c+d_h^R)l(dc+dhR)l 的空间。对于 DeepSeek-V2 模型,dc=4dh,dhR=dh2d_c=4d_h,d_h^R=\frac{d_h}{2}dc=4dh,dhR=2dh,因此每个 token 所需的 KV 缓存大小为 92dhl\frac{9}{2}d_hl29dhl。
二、DSA 的两大核心组件
- 快速索引器(Lightning Indexer)
- 细粒度 token 选择机制(Fine-grained Token Selection Mechanism)
其工作原理可以概括为:对于每一个查询 token(ht∈Rd\mathbf{h}_t \in \R^dht∈Rd),先用快速索引器迅速地为每一个它的前继 token(hs∈Rd\mathbf{h}_s\in\R^dhs∈Rd)计算一个索引分数(It,sI_{t,s}It,s),然后根据这个分数,通过细粒度 token 选择机制选出最重要的 kkk 个 token,最后只在这 kkk 个 token 上执行标准的注意力计算。
下面我们详细拆解这两个组件。
1. 快速索引器(Lightning Indexer)
这个组件的作用是快速、粗略地评估一个查询 token ht\mathbf{h}_tht 与它之前的每一个 token hs\mathbf{h}_shs 之间的相关性。计算公式如下:
It,s=∑j=1HIwt,jI⋅ReLU(qt,jI⋅ksI)
I_{t,s}=\sum_{j=1}^{H^I}w_{t,j}^I\cdot \text{ReLU}(\mathbf{q}_{t,j}^I\cdot\mathbf{k}_{s}^I)
It,s=j=1∑HIwt,jI⋅ReLU(qt,jI⋅ksI)
其中:
-
j=1,2⋯ ,HIj=1,2\cdots,H^Ij=1,2⋯,HI:jjj 表示索引器头的编号,HIH^IHI 表示索引器的头数,上标 III 用于区分属于索引器的参数和向量和属于主注意力模型(MLA)的参数和向量。
-
qtI∈RdIHI\mathbf{q}_t^I\in\R^{d^IH^I}qtI∈RdIHI 和 ksI∈RdI\mathbf{k}_s^I\in\R^{d^I}ksI∈RdI:索引器专用的查询向量和键向量,分别由 ht\mathbf{h}_tht 和 hs\mathbf{h}_shs 通过某种变换得到的,它们的维度 dId^IdI 不需要与主注意力模型中的查询/键向量维度相同。
ctQ=WDQht,qtI=WUQIctQ=[qt,1I;qt,2I;⋯ ;qt,HII]ksI=WKIhs \begin{align} \mathbf{c}_t^Q&=W^{DQ}\mathbf{h}_t,\\ \mathbf{q}_t^I&=W^{UQI}\mathbf{c}_t^Q=[\mathbf{q}_{t,1}^I;\mathbf{q}_{t,2}^I;\cdots;\mathbf{q}_{t,H^I}^I]\\ \mathbf{k}_s^I&=W^{KI}\mathbf{h}_s \end{align} ctQqtIksI=WDQht,=WUQIctQ=[qt,1I;qt,2I;⋯;qt,HII]=WKIhs
- wtI∈RHIw_t^I\in\R^{H^I}wtI∈RHI:权重向量,同样由 ht\mathbf{h}_tht 衍生而来,用于对不同索引器头的输出进行加权求和。
wtI=WwIht w_t^I=W^{wI}\mathbf{h}_t wtI=WwIht
出于对计算效率的考量,选用 ReLU\text{ReLU}ReLU 作为激活函数。鉴于快速索引器具有较少的头数且支持 FP8 精度实现,其运算效率表现尤为出色。
2. 细粒度 Token 选择机制(Fine-grained Token Selection Mechanism)
这个组件利用索引器的输出,执行真正的“稀疏”选择。
工作流程:
- 对于 ht\mathbf{h}_tht,快速索引器为其计算出了其与所有前继 token hs\mathbf{h}_shs 的索引分数集合 {It,s}\{I_{t,s}\}{It,s}。
- 从这个集合中,只选出分数最高的 kkk 个 token。即:{hs∣It,s∈Top-k(It,:)}\{\mathbf{h}_s|I_{t,s}\in \text{Top-k}(I_{t,:})\}{hs∣It,s∈Top-k(It,:)}。
- 从 Key-Value 缓存中,只取出这 kkk 个被选中的 token 对应的 Key-value 条目,记为 {csKV}\{c_{s}^{KV}\}{csKV}。
- 最终的注意力输出 ut=Attn(ht,{csKV∣It,s∈Top-k(It,:)})u_t=\text{Attn}(\mathbf{h}_t,\{c_s^{KV}|I_{t,s}\in \text{Top-k}(I_t,:)\})ut=Attn(ht,{csKV∣It,s∈Top-k(It,:)})。