transformer-注意力评分函数
目录
10.2 节使用了高斯核来对查询和键之间的关系建模,10.6中的高斯核指数部分可以视为注意力评分函数,简称评分函数,然后把这个函数的输出结果输入softmax 函数中进行运算,通过上述步骤,将得到与键对应的值的概率分布,最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。
评分函数 注意力权重 输出
键 softmax 值
查询
图10-4 计算注意力汇聚的输出为值的加权和
用数学语言描述,假设有一个查询q 属于 Rq和m个键值对(ki, ,,,v1),,,,(km,Vm) 其中ki属于Rk
Vi属于Rv,注意力汇聚函数f就被表示成值的加权和
f(q,(k1, V1)) = Sigma a(q,ki) Vi属于Rv
其中,查询q和键Ki的注意力权重是通过注意力评分函数a将两个向量映射成标量再经过softmax运算得到的。
正如图10-4所示,选择不同的注意力评分函数a会导致不同的注意力汇聚操作,本节将介绍两个流行的评分函数,稍后将用它来实现更复杂的注意力机制。
import math
import torch
from torch import nn
from d2l import torch as d2l
10.3.1 掩蔽softmax操作
上面提到的,softmax操作用于输出一个概率分布为注意力权重,在某些情况下,并非所有的值都应该被纳入注意力汇聚中。例如,为了在9.5节中高校处理小批量数据集,某些文本序列被填充了没有意义的特殊词元。为了仅将有意义的词元作为值来获取注意力汇聚,可以指定一个有效序列长度,以便在计算softmax时过滤掉超出指定范围的位置,下面的masked_softmax函数实现了这样的掩蔽softmax操作,其中任何超出有效长度的位置都被掩蔽并设置为0.
def masked_softmax(X, valid_lens):
通过在最后一个轴上掩蔽元素来执行softmax操作
X:3D张量,valid_lens:1D或2D张量
if valid_lens is None:
return nn.functional.softmax(X, dim = -1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
最后一个轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
X=d2l.sequence_mask(X.reshape(-1, shape[-1], valid_lens, value=-1e6))
return nn.functional_softmax(X.reshape(shape), dim=-1)
为了掩饰此函数时如何工作的,考虑由两个2x4矩阵表示的样本,这两个样本有效长度分别为2和3,经过掩蔽softmax操作,超出有效长度的值都被掩蔽为0
masked_softmax(torch.rand(2,2,4), torch.tensor([2,3]))
同样,也可以使用二维张量,为矩阵样本中的每一行指定有效长度。
masked_softmax(torch.rand(2,2,4), torch.tensor([1,3],[2,4]))
10.3.2 加性注意力
当查询和键不同长度的向量时,可以使用加性注意力作为评分函数,给定查询q属于Rq和键k属于Rk,加性注意力的评分函数为
可学习的参数是Wq属于Rkxq,Wk属于Rhxk和Wt属于Rh。 查询和键连接起来后输入一个多层感知机MLP中,感知机包含一个隐藏层,其隐藏单元数一个超参数h,通过使用tanh作为激活函数,并且禁用偏置项。
下面来实现加性注意力
class AdditiveAttention(nn.Module)
加性注意力
def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
self.W_k = nn.Linear(key_size, num_hiddens, bias = False)
self.W_q = nn.Linear(query_size, num_hiddens, bias = False)
self.W_v = nn.Linear(num_hiddens, 1, bias = False)
def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
#在维度扩展后
#queries的形状为(batch_size, 查询数,1,num_hidden)
#key 的形状为(batch_size, 1, 键-值对数,num_hiddens)
#使用广播方式求和
features = queries.unsqueeze(2) + keys.unqueeze(1)
features = torch.tanh(features)
#self.w_v仅有一个输出,因此从形状中移除最后的维度
scores的形状为(batch_size, 查询数,键-值对数)
scores = self.w_v(features).squeeze(-1)
self.attention_weights = masked_softmax(scores, valid_lens)
values 的形状为batch_size, 键-值对数,值的维度
return torch.bmm(self.dropout(self.attention_weights), values)
用一个小例子演示上面的additiveAttention类,其中查询,键和值的形状为量大小,步数或词元序列长度,实际输出为2,1,20,注意力汇聚输出的形状为 批量大小,查询的步数,值的维度。
queries, keys = torch.normal(0,1,(2,1,20)), torch.ones(2, 10, 2)
#values的小批量,两个值矩阵是相同的
values = torch.arange(40, dtype = torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])
attention = AdditiveAttention(key_size = 2, query_size=20, num_hiddens=8, dropout = 0.1)
attention.eval()
attention(queries, keys, values, valid_lens)
尽管加性注意力包含了可学习的参数,由于本例中每个键都是相同的,因此注意力权重是均匀的,由指定的有效长度决定。
d2l.show_heatmaps(attention.attention_weights.reshape(1,1,2,10)):
xlabel='keys', ylabel='Queries'
10.3.3 缩放点积注意力
使用点积可以得到计算效率更高的评分函数,但是点积操作要求查询和键具有相同的长度d,假设查询和键的所有元素都是独立的随机变量,并且都满足零均值和单位方差,那么两个向量的点积的均值为0,方差为d,为确保无论向量长度如何,点积的方差在不考虑向量长度的情况下都是1,我们再将点积除以 根号d,则缩放点积注意力评分函数为
在实践中,我们通常从小批量的角度来i考虑提高效率,例如基于n个查询的m个键-值对计算注意力,其中查询和键的长度为d, 值的长度为v,查询Q属于Rnxd。
下面的缩放点积注意力的实现使用了暂退法进行模型正则化。
class DotProdductAttention(nn.Module):
缩放点积注意力
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
#queries 的形状为batch_szie,查询数,d
#keys的形状为batch_size, 键-值对数,d
values的形状为batch_size 键-值对数,值的维度
valid_lens的形状为batch_size, 或者batch_size, valid_lens = None
def forward(self, queries, keys, values, valid_lens = None):
d = queries.shape[-1]
设置transpose_b=True是为了交换keys的最后两个维度。
scores = torch.bmm(queries, keys.transpose(1,2))/math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
为了演示上述的DotProductAttention类,我们使用的先前加性注意力例子中相同的键,值和有效长度。对于点积操作,我们令查询的特征维度与键的特征维度大小相同
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)
与加性注意力演示相同,由于键包含的是相同元素,而这些元素无法通过任何查询进行区分,因此获得了均匀的注意力权重。
d2l.show_heatmaps(attention.attention_weights.reshape(1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries'
小结:注意力汇聚的输出计算可以作为值的加权平均,选择不同的注意力评分函数会带来不同的注意力汇聚操作
当查询和键是不同长度的向量时,可以使用加性注意力评分函数,当他们的长度相同的,使用缩放点积注意力评分函数的计算效率更高。