一文详解 transformer 中的 self-attention
无反馈,不学习
要完全理解清楚transformer结构,就必须理解self-attention是怎么运作的。在transformer中,其最核心的改进就是引入了self-attention网络结构。
一)self-attention网络结构
self-attention网络结构的作用如下图所示,其作用就是将输入的向量经过self-attention处理后输出相同数量,相同大小的向量。
但是需要注意的是,输出的,并不是类似于普通的神经网络的处理和输出。 对应输出的
是考虑了所有输入的全局信息得到的,可能是类似于语义信息等;暂且这里我们将其叫为是加了注意力的信息
。
1.1 如何给
加注意力
self-attention给其加注意力信息。其做的第一件事情是计算了他们之间的相关性。
第一步:计算他们之间的相关性,至于怎么计算,这个问题先留着。
第二步:得到他们之间的相关性后的,以得到
举例,
的值就是被加入了注意力信息的值,
的值是根据
,
,
,
之间的相关性(这里的相关性也被叫做注意力)去对应的
中抽取信息,若
之间的相关性(注意力)强,说明网络应该更加关注
,即会根据得到的相关性抽取信息。至于如何抽取,也先留着。
1.2 计算
他们之间的相关性
对于计算计算他们之间的相关性的方法有很多,这里讲解最常用的方式——Dot-product方式。
假设要计算之间的注意力,首先第一步,将
×上
矩阵 得到 q , 将
×上
矩阵 得到 k,q 和 k 再进行点积运行算得到的就是
之间的注意力。那
和
是什么呢?这里暂且默认
和
就是存在,通过上述运算就能得到二者的相关性。
1.3 如何根据注意力抽取重要信息
上诉只是得到了之间的注意力,我们继续来看下如何得到
。
如上图所示, 通过乘上
矩阵可以得到对应的
,同样的
通过乘上
矩阵可以得到对应的
,得到了q 和 k 的值,进而通过Dot-product方式就可以计算得到对应的相关性,即
,
,
,
。然后让
,
,
,
通过softmax的处理,得到最后的相关性
,
,
,
。
通过乘上
矩阵可以得到对应的
,这里暂且默认
就是存在,这个时候就可以根据注意力
,
,
,
去抽取重要信息
了。
所以 计算公式如下:
同理,就也可以得到对应的
。
至此,遗留的两个问题,计算他们之间的相关性和根据得到的相关性抽取重要信息已解决。现在我们唯一还不知道的是
到底是什么。其实还有一个小问题就是,在得到最后的注意力
,
,
,
的时候,我们是使用了一个softmax函数来进行处理,至于为什么要使用这个softmax函数来进行处理也是需要回答的,这里就不讲解了,只需要去深入了解一下softmax函数的作用即可理解。
二) 手动进行self-attention
以下三个输入向量,将其构成一个矩阵
X = [[1.0, 0.5, 2.0, 1.5], # 猫[0.5, 2.0, 1.0, 0.0], # 吃[1.5, 1.0, 0.5, 2.0]] # 鱼
进行 self-attention的第一步就是计算输入向量对应的 k q v 。假设
如下所示。
W_q (4x3):[[0.5, 1, 0],[0, 0.5, 1],[1, 0, 0.5],[0.5, 0, 1]]W_k (4x3):[[1, 0, 0.5],[0.5, 1, 0],[0, 0.5, 1],[1, 0, 0.5]]W_v (4x4):[[0, 1, 0.5, 1],[1, 0.5, 0, 1],[0.5, 0, 1, 1],[0, 1, 0.5, 1]]
查询向量 q1
(输入向量为猫 [1.0, 0.5, 2.0, 1.5])
q1 = x1 * W_q
= [1, 0.5, 2, 1.5] * [[0.5, 1, 0], [0, 0.5, 1], [1, 0, 0.5], [0.5, 0, 1]]
- 计算第一个元素:
1*0.5 + 0.5*0 + 2*1 + 1.5*0.5 = 0.5 + 0 + 2 + 0.75 = 3.25
- 计算第二个元素:
1*1 + 0.5*0.5 + 2*0 + 1.5*0 = 1 + 0.25 + 0 + 0 = 1.25
- 计算第三个元素:
1*0 + 0.5*1 + 2*0.5 + 1.5*1 = 0 + 0.5 + 1 + 1.5 = 3.0
q1 = [3.25, 1.25, 3.0]
键向量 k1
(输入向量为猫 [1.0, 0.5, 2.0, 1.5]):
k1 = x1 * W_k
= [1, 0.5, 2, 1.5] * [[1, 0, 0.5], [0.5, 1, 0], [0, 0.5, 1], [1, 0, 0.5]]
-
第一个元素:
1*1 + 0.5*0.5 + 2*0 + 1.5*1 = 1 + 0.25 + 0 + 1.5 = 2.75
-
第二个元素:
1*0 + 0.5*1 + 2*0.5 + 1.5*0 = 0 + 0.5 + 1 + 0 = 1.5
-
第三个元素:
1*0.5 + 0.5*0 + 2*1 + 1.5*0.5 = 0.5 + 0 + 2 + 0.75 = 3.25
k1 = [2.75, 1.5, 3.25]
值向量 v1
(输入向量为猫 [1.0, 0.5, 2.0, 1.5])
v1 = x1 * W_v
= [1, 0.5, 2, 1.5] * [[0, 1, 0.5, 1], [1, 0.5, 0, 1], [0.5, 0, 1, 1], [0, 1, 0.5, 1]]
-
第一个元素:
1*0 + 0.5*1 + 2*0.5 + 1.5*0 = 0 + 0.5 + 1 + 0 = 1.5
-
第二个元素:
1*1 + 0.5*0.5 + 2*0 + 1.5*1 = 1 + 0.25 + 0 + 1.5 = 2.75
-
第三个元素:
1*0.5 + 0.5*0 + 2*1 + 1.5*0.5 = 0.5 + 0 + 2 + 0.75 = 3.25
-
第四个元素:
1*1 + 0.5*1 + 2*1 + 1.5*1 = 1 + 0.5 + 2 + 1.5 = 5
v1 = [1.5, 2.75, 3.25, 5]
这里计算的注意力
=
= 20.5625
同样的方式计算得到,
,
的注意力,然后通过softmax的处理,得到最后的相关性
,
,
,
。 然后在根据得到的相关性抽取重要信息
。
在这里我们是假设了
矩阵的值,其实在真实的网络当中呢,其都是未知的参数,我们只定义它的大小,至于其内部的值都是作为未知参数待求的。
还有一个问题,为什么我们在假设 矩阵的时候,其大小是4×4的,而不是像
一样是4×3的呢?先说结论
由于输入向量是1×4的,为了能让矩阵能运算下去,其
必须是4行的,至于几列不影响能否计算,但
必须也是4列,因为要保证其输入向量经过self-attention处理输出后保持向量相同大小,至于为什么,是应为在transformer模型中,加入了残差连接,必须保证输入输出向量大小一致。
那为什么
可以是三列?其实对于
的列可以自由设定,但是
的列大小相同。
投影维度 (权重矩阵的列数)的影响
维度 | 作用 | 增大影响 | 减小影响 |
---|---|---|---|
注意力分数的区分度 | ✅ 提升细粒度关注能力 ❌ 增加 QKᵀ 计算成本 | ⚠️ 可能丢失语义细节 ✅ 降低内存占用 | |
值向量的信息携带量 | ✅ 增强上下文信息传递 ❌ 增加输出计算量 | ⚠️ 可能压缩关键信息 ✅ 加速输出生成 |