深入理解Self-Attention - 原理与等价表示
概述
很多文章或论文已经很好的解释了神经网络中self-attention的原理,但是个人觉得还是有其他可解释的方面,主要原因是很多解释都是面向过程的,只解释了它是什么样的,这篇文章主要从其等价形式解释其原理。
Self-Attention原理
普适的self-attention的公式为:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
1
d
k
Q
K
T
)
V
Attention(Q,K,V) = softmax(\frac{1}{\sqrt{d_k} } QK^\mathsf{T})V
Attention(Q,K,V)=softmax(dk1QKT)V其中Q,K,V分别表示Query,Key,Value,且
Q
=
x
W
q
Q = xW_q
Q=xWq
K
=
x
W
k
K = xW_k
K=xWk
V
=
x
W
v
V = xW_v
V=xWv即Q,K,V都是
x
x
x 经过线性变换后的数据,其中
x
∈
R
m
×
n
;
W
q
,
W
k
,
W
v
∈
R
n
×
j
x\in R^{m\times n};W_q,W_k,W_v\in R^{n\times j}
x∈Rm×n;Wq,Wk,Wv∈Rn×j,即每一个token为
n
n
n 维向量,
x
x
x是由
m
m
m 个单词构成的句子。而
1
d
k
\frac{1}{\sqrt{d_k}}
dk1是避免矩阵乘值过大的超参数,本质上说self-attention的核心内容为:
A
t
t
e
n
t
i
o
n
(
x
)
=
s
o
f
t
m
a
x
(
x
W
q
(
x
W
k
)
T
)
V
Attention(x) = softmax(xW_q(xW_k)^\mathsf{T})V
Attention(x)=softmax(xWq(xWk)T)V
Self-Attention的等价表示
这里为了简便,我们先解释在softmax内的矩阵的等价形式即
M
=
x
W
q
(
x
W
k
)
T
M=xW_q(xW_k)^\mathsf{T}
M=xWq(xWk)T假设存在
W
~
\tilde{W}
W~其中
W
~
∈
R
m
×
m
\tilde{W}\in R^{m\times m}
W~∈Rm×m使得以下等式成立
M
=
x
W
q
(
x
W
k
)
T
=
x
x
T
⊙
W
~
M=xW_q(xW_k)^\mathsf{T}=xx^\mathsf{T}\odot\tilde{W}
M=xWq(xWk)T=xxT⊙W~其中
⊙
\odot
⊙表示哈达玛积(Hadamard product,即逐点乘积),我们可以计算出等式左边和右边得到都是
m
×
m
m\times m
m×m的矩阵,根据哈达玛逆的性质:一个方阵的逆乘以其自身为单位矩阵,所以我们可以两边左乘
(
x
x
T
)
⊙
−
1
(xx^\mathsf{T})^{\odot-1}
(xxT)⊙−1,得到
W
~
=
(
x
x
T
)
⊙
−
1
⊙
(
x
W
q
(
x
W
k
)
T
)
\tilde{W} = (xx^\mathsf{T})^{\odot-1}\odot(xW_q(xW_k)^\mathsf{T})
W~=(xxT)⊙−1⊙(xWq(xWk)T) 这里需要矩阵
x
x
T
xx^\mathsf{T}
xxT的任意元素不为0,即
[
x
x
T
]
i
j
≠
0
[xx^\mathsf{T}]_{ij}\ne0
[xxT]ij=0,由于参数是可学习的所以原始矩阵
M
M
M中的参数可以使用
W
~
\tilde{W}
W~来代替
W
q
,
W
k
W_q,W_k
Wq,Wk。
当然我们也可以使用低秩矩阵来表示
W
~
\tilde{W}
W~,即
W
~
=
W
a
W
b
T
\tilde{W}=W_aW_b^{\mathsf{T}}
W~=WaWbT 其中
W
a
,
W
b
∈
R
m
×
j
W_a,W_b\in R^{m\times j}
Wa,Wb∈Rm×j,所以我们可以得到另一种表示
M
=
x
x
T
⊙
(
W
a
W
b
)
T
M = xx^\mathsf{T}\odot(W_aW_b)^{\mathsf{T}}
M=xxT⊙(WaWb)T 为了简单,这里我们不写成低秩的形式,所以self-attention可以写为
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
1
d
k
x
x
T
⊙
W
~
)
V
Attention(Q,K,V) = softmax(\frac{1}{\sqrt{d_k} } xx^\mathsf{T}\odot\tilde{W})V
Attention(Q,K,V)=softmax(dk1xxT⊙W~)V 我们可以看到,其实Attention中本质上关键的是方阵
x
x
T
xx^\mathsf{T}
xxT,右边的参数
W
~
\tilde{W}
W~只是对该方阵的线性变换。
import torch
k,m,n = 2,3,4
x = torch.randn(m,n)
wq = torch.randn(n,k)
wk = torch.randn(n,k)
Q = x@wq
K = x@wk
print("Q@K.T:",Q@K.T)
M = (1.0/(x@x.T))*(Q@K.T)
print("x@x.T*M:",x@x.T*M)
对 x x T ⊙ W ~ xx^\mathsf{T}\odot\tilde{W} xxT⊙W~ 的解释
这里我们用简单的例子来阐述
x
x
T
xx^\mathsf{T}
xxT ,假设存在一个以
m
=
3
m=3
m=3个单词组成的句子
x
=
[
a
,
b
,
c
]
T
x = [a,b,c]^{\mathsf{T}}
x=[a,b,c]T 其中
a
,
b
,
c
∈
R
n
a,b,c\in R^n
a,b,c∈Rn ,那么自然
x
∈
R
3
×
n
x \in R^{3 \times n}
x∈R3×n,我们可以得到
x
x
T
=
[
a
a
a
b
a
c
b
a
b
b
b
c
c
a
c
b
c
c
]
xx^\mathsf{T} = \begin{bmatrix} aa& ab&ac \\ ba& bb&bc \\ ca& cb&cc \end{bmatrix}
xxT=
aabacaabbbcbacbccc
这里向量
a
a
,
a
b
,
.
.
.
,
c
c
aa,ab,...,cc
aa,ab,...,cc 为两个向量的点积,而点积也可以表示为如下形式
a
⋅
b
=
∥
a
∥
∥
b
∥
c
o
s
(
θ
)
a\cdot b = \left \| a \right \| \left \| b \right \| cos(\theta)
a⋅b=∥a∥∥b∥cos(θ) 反过来
a
⋅
b
∥
a
∥
∥
b
∥
=
c
o
s
(
θ
)
\frac{a\cdot b}{\left \| a \right \| \left \| b \right \|} = cos(\theta)
∥a∥∥b∥a⋅b=cos(θ)所以点积可以表示非归一化的相似性,或者准确来说是表示的相关性。一般在神经网络中不用除以范数来表示相关性,经验性的结果来说,一是收敛比较困难、二是计算量增加了不少。而直接用点积也有其局限性(静态的相关性),所以我们可以带参数的形式,即可学习的相关性:
a
⋅
b
⋅
w
≈
c
o
s
(
θ
)
a\cdot b\cdot w \approx cos(\theta)
a⋅b⋅w≈cos(θ) 这里
≈
\approx
≈只是表示一个近似操作而非约等于,所以本质上Attention中
x
x
T
⊙
W
~
xx^\mathsf{T}\odot\tilde{W}
xxT⊙W~表示的是学习后句子中各单词的相关关系程度矩阵,需要说明的是这种相关性是学习到的,而非像
c
o
s
i
n
e
cosine
cosine静态的关系(比如这句话:“我喜欢小狗,它们很可爱”,token“小狗”与“小狗”的cosine相似度一定是最高的,而学习到的这种相关性并不一定是最高的,而可能是“小狗”与“他们”有更高的相关性)。
我们也要注意到 当句子很长时 W ~ \tilde{W} W~ 矩阵是 m × m m\times m m×m的,所以参数量将非常大,容易过拟合,而且计算也大了很多,所以等价和等计算是完全不同的,原始的attention中参数 W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv 本质是(token级别)共享的变换矩阵,类似于CNN中的卷积。
对 1 d k \frac{1}{\sqrt{d_k} } dk1的解释
这个 1 d k \frac{1}{\sqrt{d_k} } dk1其实是在解决softmax归一化存在的问题,或者说是指数函数存在的问题。
即假设一个向量 x = [ a , b , c ] x=[a,b,c] x=[a,b,c],其中各元素不等,我们将 x x x放大 s s s倍,这时随着放大倍数的增大, y = s o f t m a x ( x ) y=softmax(x) y=softmax(x)中向量中某一变量将趋向于1,换句话说,softmax不再是soft max而是hard max。而如果与Attention中的 V V V作点积,那么Attention就变为了简单的取某一个token,而不是token的加权混合。
import torch
x = torch.randn(3)
print(torch.softmax(x,dim=0))
print(torch.softmax(x*100,dim=0))
从连接的图结构看attention
实际上不管是MLP中的全连接层,还是CNN中的卷积操作,本身表示的是变量(或token)与变量聚合的操作,如:
y
=
w
1
x
1
+
w
2
x
2
+
b
y = w_1x_1+w_2x_2+b
y=w1x1+w2x2+b而attention不同,它首先所表示的是变量与变量间的关系,以及根据变量间关系的特征聚合的操作,某种程度来说它与全连接或卷积的连接方式有根本的不同。
总结
我们通过self-attetion的等价形式
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
1
d
k
x
x
T
⊙
W
~
)
V
Attention(Q,K,V) = softmax(\frac{1}{\sqrt{d_k} } xx^\mathsf{T}\odot\tilde{W})V
Attention(Q,K,V)=softmax(dk1xxT⊙W~)V 可以更容易解释attention的作用,即
1
d
k
x
x
T
\frac{1}{\sqrt{d_k}} xx^\mathsf{T}
dk1xxT 表示句子内各单词的相关关系程度矩阵,而
W
~
\tilde{W}
W~表示对相关关系的线性变换,softmax则是对变换后的相关关系在句子内的归一化,乘以
V
V
V则是对线性变换后的token作加权特征混合(feature mixing)。
参考
- Attention Is All You Need