self attention, masked self attention, cross attention
1. Plain Self-Attention (scaled dot-product)
- Q, K, V are all computed from the same input X
- No constraint: every position can attend to every other position (a minimal sketch follows below)
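A minimal sketch of that computation in PyTorch (shapes and the name X are placeholders of mine; the learned projections W_q, W_k, W_v that normally produce Q, K, V from X are omitted for brevity):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d_model = 2, 4, 8            # batch size, sequence length, model dim
X = torch.randn(B, T, d_model)

# Q, K, V all come from the same input X (projections omitted in this sketch)
Q, K, V = X, X, X

scores = Q @ K.transpose(-2, -1) / d_model ** 0.5   # (B, T, T): every position vs. every position
attn = F.softmax(scores, dim=-1)                    # rows sum to 1, no positions are hidden
out = attn @ V                                      # (B, T, d_model)
print(out.shape)                                    # torch.Size([2, 4, 8])
```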
2. Masked Self-Attention
The scores are the usual scaled dot products plus a mask term:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$

where $M$ is the mask matrix, defined as:

$$M_{ij} = \begin{cases} 0, & j \le i \\ -\infty, & j > i \end{cases}$$

This means position $i$ can only attend to positions $1$ through $i$ (including itself) and cannot see the future tokens $i+1, \ldots, T$.
Suppose the sequence length is 4, with True meaning masked out (not visible) and False meaning visible, as printed in the sketch below.
This is called an upper triangular mask, also known as a causal mask.
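A minimal sketch of building that 4×4 causal mask in PyTorch, using the True-means-masked convention from the text above:

```python
import torch

T = 4
# strictly upper-triangular part = future positions = masked out (True)
causal_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
print(causal_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

# applied to attention scores: masked positions become -inf before the softmax
scores = torch.randn(T, T)
scores = scores.masked_fill(causal_mask, float('-inf'))
```

Note that the multi-head attention code further down uses the opposite convention (positions where `mask == 0` are blocked), so there you would pass a lower-triangular mask of ones instead, e.g. `torch.tril(torch.ones(T, T))`.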
3. Cross Attention
Q comes from one source (the target side), while K and V come from another source (the source side, e.g. the encoder output).
The sequence lengths of Q and of K/V may differ, but d_model must be the same, as the shape sketch below shows.
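A minimal shape sketch in PyTorch; the lengths T_q and T_kv and the variable names are toy values of mine:

```python
import torch
import torch.nn.functional as F

B, d_model = 2, 8
T_q, T_kv = 3, 7                       # target length and source length can differ

Q = torch.randn(B, T_q, d_model)       # from the target side (e.g. decoder states)
K = torch.randn(B, T_kv, d_model)      # from the source side (e.g. encoder output)
V = torch.randn(B, T_kv, d_model)

scores = Q @ K.transpose(-2, -1) / d_model ** 0.5   # (B, T_q, T_kv)
out = F.softmax(scores, dim=-1) @ V                 # (B, T_q, d_model): one output per query position
print(scores.shape, out.shape)
```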
Code implementation of multi-head attention:

Note that the code differs a little from the paper: the paper splits into heads first and gives each head its own W_q, W_k, W_v, whereas the code applies shared W_q, W_k, W_v first and only then splits into heads, does the per-head QKV matrix multiplications, and concatenates the results.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V: (batch_size, num_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V)

    def split_heads(self, x, batch_size):
        # x: (batch_size, seq_len, d_model)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)  # (batch_size, seq_len, num_heads, d_k)
        return x.transpose(1, 2)  # (batch_size, num_heads, seq_len, d_k)

    def combine_heads(self, x, batch_size):
        # x: (batch_size, num_heads, seq_len, d_k)
        x = x.transpose(1, 2).contiguous()  # (batch_size, seq_len, num_heads, d_k)
        return x.view(batch_size, -1, self.d_model)  # (batch_size, seq_len, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        Q = self.W_q(Q)  # (batch_size, seq_len, d_model)
        K = self.W_k(K)
        V = self.W_v(V)

        Q = self.split_heads(Q, batch_size)  # (batch_size, num_heads, seq_len, d_k)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.combine_heads(attn_output, batch_size)
        return self.W_o(output)  # Final linear projection
```
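A quick usage sketch covering the three cases from this note with the class above (toy shapes of my own choosing; note that inside `scaled_dot_product_attention` the check is `mask == 0`, so here 1 means visible and 0 means blocked):

```python
torch.manual_seed(0)
d_model, num_heads = 64, 8
B, T, T_src = 2, 5, 9
mha = MultiHeadAttention(d_model, num_heads)

x = torch.randn(B, T, d_model)            # "target" sequence
memory = torch.randn(B, T_src, d_model)   # "source" sequence of a different length

# 1. plain self-attention: Q, K, V all from the same input, no mask
out_self = mha(x, x, x)                                  # (B, T, d_model)

# 2. masked self-attention: lower-triangular mask, 1 = visible, 0 = blocked
causal = torch.tril(torch.ones(T, T)).view(1, 1, T, T)   # broadcasts over batch and heads
out_masked = mha(x, x, x, mask=causal)                   # (B, T, d_model)

# 3. cross attention: Q from x, K and V from the other sequence
out_cross = mha(x, memory, memory)                       # (B, T, d_model)

print(out_self.shape, out_masked.shape, out_cross.shape)
```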
Still need to compare the two approaches and check whether the outputs come out the same (a quick check is sketched below).
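One way to do that comparison, sketched with the class above: head i of the shared projection is just rows [i·d_k, (i+1)·d_k) of W_q.weight, so slicing the shared weight into per-head d_model→d_k projections should reproduce exactly what split_heads gives per head (and likewise for W_k, W_v):

```python
torch.manual_seed(0)
d_model, num_heads, B, T = 64, 8, 2, 5
d_k = d_model // num_heads
mha = MultiHeadAttention(d_model, num_heads)
X = torch.randn(B, T, d_model)

# shared projection first, then split into heads (what the code above does)
Q_shared = mha.split_heads(mha.W_q(X), B)               # (B, num_heads, T, d_k)

# per-head projections as in the paper: each head gets its own slice of the weight
for i in range(num_heads):
    W_i = mha.W_q.weight[i * d_k:(i + 1) * d_k, :]      # (d_k, d_model) slice for head i
    b_i = mha.W_q.bias[i * d_k:(i + 1) * d_k]
    Q_head_i = X @ W_i.T + b_i                          # (B, T, d_k)
    assert torch.allclose(Q_shared[:, i], Q_head_i, atol=1e-6)

print("shared projection + split_heads == per-head projections")
```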