Transformer多头注意力机制
多头注意力机制:多角度捕捉信息的直观解析
多头注意力机制是Transformer架构的核心创新,而多头注意力(Multi-Head Attention)则是对基础自注意力机制的重要扩展。它通过并行计算多个"注意力头"(head),让模型能够从不同角度理解序列数据中的关联关系,极大提升了模型的表达能力。
为什么需要多头注意力?
想象你正在阅读一篇文章,理解内容时你会同时关注多个方面:
- 词语之间的语法关系(主谓宾结构)
- 概念之间的语义关联(同义词、上下位词)
- 句子之间的逻辑衔接(因果、转折)
单一的注意力机制就像只用一种视角分析文本,而多头注意力则相当于让多个"专家"同时工作,每个专家专注于捕捉一种特定类型的关系,最后汇总所有专家的见解。
多头注意力的工作原理
1. 分头处理:多个视角并行计算
首先,我们将原始的查询(Q)、键(K)和值(V)通过不同的线性变换矩阵,投影到多个子空间中。每个子空间对应一个"注意力头",独立计算注意力。
对于第 iii 个头(i=1,2,...,hi = 1, 2, ..., hi=1,2,...,h):
headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)headi=Attention(QWiQ,KWiK,VWiV)
其中:
- WiQ,WiK,WiVW_i^Q, W_i^K, W_i^VWiQ,WiK,WiV 是第 iii 个头特有的可学习权重矩阵
- Attention\text{Attention}Attention 是基础的自注意力计算函数:Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V
实例说明:分析句子"猫 吃 鱼"
假设我们有3个注意力头,分别关注:
- 头1:语法关系(谁是动作的执行者)
- 头2:语义关联(动作与对象的匹配度)
- 头3:位置邻近关系(相邻词之间的联系)
原始输入矩阵(简化为2维向量):
Q=[q1q2q3]=[100100],K=[k1k2k3]=[100100],V=[v1v2v3]=[100100]Q = \begin{bmatrix} q_1 \\ q_2 \\ q_3 \end{bmatrix} = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{bmatrix}, \quad K = \begin{bmatrix} k_1 \\ k_2 \\ k_3 \end{bmatrix} = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{bmatrix}, \quad V = \begin{bmatrix} v_1 \\ v_2 \\ v_3 \end{bmatrix} = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{bmatrix}Q=q1q2q3=100010,K=k1k2k3=100010,V=v1v2v3=100010
每个头通过不同的投影矩阵处理:
- 头1投影矩阵:W1Q=W1K=W1V=[1000]W_1^Q = W_1^K = W_1^V = \begin{bmatrix} 1 & 0 \\ 0 & 0 \end{bmatrix}W1Q=W1K=W1V=[1000]
- 头2投影矩阵:W2Q=W2K=W2V=[0001]W_2^Q = W_2^K = W_2^V = \begin{bmatrix} 0 & 0 \\ 0 & 1 \end{bmatrix}W2Q=W2K=W2V=[0001]
- 头3投影矩阵:W3Q=W3K=W3V=[1001]W_3^Q = W_3^K = W_3^V = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}W3Q=W3K=W3V=[1001]
计算第1个头的注意力:
QW1Q=[100000],KW1K=[100000]Q W_1^Q = \begin{bmatrix} 1 & 0 \\ 0 & 0 \\ 0 & 0 \end{bmatrix}, \quad K W_1^K = \begin{bmatrix} 1 & 0 \\ 0 & 0 \\ 0 & 0 \end{bmatrix}QW1Q=100000,KW1K=100000
scores1=(QW1Q)(KW1K)Tdk=[100000000]\text{scores}_1 = \frac{(Q W_1^Q)(K W_1^K)^T}{\sqrt{d_k}} = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{bmatrix}scores1=dk(QW1Q)(KW1K)T=100000000
head1=softmax(scores1)⋅(VW1V)=[100000]\text{head}_1 = \text{softmax}(\text{scores}_1) \cdot (V W_1^V) = \begin{bmatrix} 1 & 0 \\ 0 & 0 \\ 0 & 0 \end{bmatrix}head1=softmax(scores1)⋅(VW1V)=100000
类似地,我们可以计算出第2个头和第3个头的结果,每个头会关注不同的关系模式。
2. 结果拼接:汇总多视角信息
将所有头的输出结果按列拼接,形成一个综合了多种视角的特征矩阵:
concat_heads=Concat(head1,head2,…,headh)\text{concat\_heads} = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)concat_heads=Concat(head1,head2,…,headh)
在我们的"猫 吃 鱼"例子中,3个头的输出拼接后:
concat_heads=[head1[1]head1[2]head2[1]head2[2]head3[1]head3[2]head1[3]head1[4]head2[3]head2[4]head3[3]head3[4]head1[5]head1[6]head2[5]head2[6]head3[5]head3[6]]\text{concat\_heads} = \begin{bmatrix}
\text{head}_1[1] & \text{head}_1[2] & \text{head}_2[1] & \text{head}_2[2] & \text{head}_3[1] & \text{head}_3[2] \\
\text{head}_1[3] & \text{head}_1[4] & \text{head}_2[3] & \text{head}_2[4] & \text{head}_3[3] & \text{head}_3[4] \\
\text{head}_1[5] & \text{head}_1[6] & \text{head}_2[5] & \text{head}_2[6] & \text{head}_3[5] & \text{head}_3[6]
\end{bmatrix}concat_heads=head1[1]head1[3]head1[5]head1[2]head1[4]head1[6]head2[1]head2[3]head2[5]head2[2]head2[4]head2[6]head3[1]head3[3]head3[5]head3[2]head3[4]head3[6]
3. 线性变换:映射回原始维度
最后,通过一个线性变换矩阵 WOW^OWO 将拼接后的结果映射回模型的原始维度:
output=concat_heads⋅WO\text{output} = \text{concat\_heads} \cdot W^Ooutput=concat_heads⋅WO
这个步骤将多个头的信息整合为一个统一的表示,既保留了多角度分析的优势,又保持了与模型其他部分的维度兼容性。
维度变换详解
假设我们有:
- 序列长度 n=3n = 3n=3(例如"猫 吃 鱼")
- 模型维度 dmodel=6d_{\text{model}} = 6dmodel=6
- 头数 h=3h = 3h=3
- 每个头的维度 dk=dv=6/3=2d_k = d_v = 6 / 3 = 2dk=dv=6/3=2
各步骤的维度变换如下:
-
初始输入:
Q,K,V∈R3×6Q, K, V \in \mathbb{R}^{3 \times 6}Q,K,V∈R3×6 -
每个头的投影(从6维到2维):
QWiQ,KWiK,VWiV∈R3×2Q W_i^Q, K W_i^K, V W_i^V \in \mathbb{R}^{3 \times 2}QWiQ,KWiK,VWiV∈R3×2 -
每个头的输出:
headi∈R3×2\text{head}_i \in \mathbb{R}^{3 \times 2}headi∈R3×2 -
拼接所有头(3个头 × 2维 = 6维):
concat_heads∈R3×(3×2)=R3×6\text{concat\_heads} \in \mathbb{R}^{3 \times (3 \times 2)} = \mathbb{R}^{3 \times 6}concat_heads∈R3×(3×2)=R3×6 -
最终输出(保持原始维度):
output∈R3×6\text{output} \in \mathbb{R}^{3 \times 6}output∈R3×6
这种设计确保了多头注意力可以无缝集成到Transformer架构中,同时通过多个子空间的并行计算提升了模型能力。
多头注意力的优势
-
多角度特征捕捉:不同头可以学习不同类型的关系模式
例如在处理句子"他用电脑写代码"时:
- 头1可能关注"他"与"写"的主谓关系
- 头2可能关注"写"与"代码"的动宾关系
- 头3可能关注"用"与"电脑"的工具关系
-
模型容量提升:在保持计算复杂度可控的情况下增加了模型表达能力
-
注意力分布合理化:避免单一注意力分布过于集中或分散
-
正则化效果:多个头的并行计算相当于引入了隐式正则化,提高了模型泛化能力
总结
多头注意力机制通过以下三个步骤实现了多角度信息捕捉:
- 分头处理:将Q、K、V投影到多个子空间,每个头独立计算注意力
- 结果拼接:汇总所有头的输出,整合多视角信息
- 线性变换:将拼接结果映射回原始维度,保持与模型其他部分的兼容性
这种机制让Transformer能够同时关注序列中不同类型的关系,大大提升了模型对复杂语言结构的理解能力,是现代Transformer架构不可或缺的核心组件。