Why does the Transformer use multi-head attention?
Reference video: 面试必刷:大模型为什么要使用多头注意力? (bilibili)
Reference article: Transformer内容详解(通透版) (CSDN blog)
Limitation of single-head attention: a single head can only "look at" the input sequence from one perspective, so the attention weights it computes reflect only one particular attention pattern.
Multi-head attention splits attention into several "heads". Each head computes attention independently and focuses on a different subspace of the input, i.e., different aspects of its features, so the model can capture multiple types of semantic relationships in parallel.
Concretely, the input is projected into several different low-dimensional subspaces, attention is computed separately in each, and the results are concatenated and fused through a linear projection. This enriches the model's representational capacity and lets the Transformer learn complex combinations of features. At the same time, each individual head has fewer parameters and lower computational complexity, which helps training stability and efficiency and aids convergence.
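A minimal arithmetic sketch of the parameter bookkeeping (the values dim = 512 and num_heads = 8 are illustrative assumptions, not taken from the references above): when each head uses dk = dv = dim / num_heads, the combined Q/K/V projections of all heads are exactly as large as a single full-width head, so multi-head attention partitions the computation across subspaces rather than multiplying it.

dim = 512                     # model dimension (illustrative assumption)
num_heads = 8                 # number of heads (illustrative assumption)
dk = dv = dim // num_heads    # per-head dimension: 512 / 8 = 64

# Weight counts of the Q, K, V projections (biases ignored)
single_head = 3 * dim * dim               # one full-width head: 786432
multi_head = 3 * dim * (dk * num_heads)   # 8 heads of width 64:  786432

print(dk, single_head, multi_head)        # 64 786432 786432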
Single-head attention:
import torch
import torch.nn as nn


class Self_Attention(nn.Module):
    def __init__(self, dim, dk, dv):
        super().__init__()
        self.scale = dk ** -0.5
        self.q = nn.Linear(dim, dk)
        self.k = nn.Linear(dim, dk)
        self.v = nn.Linear(dim, dv)

    def forward(self, x):
        # x: [batch, seq_len, dim]
        q = self.q(x)  # [batch, seq_len, dk]
        k = self.k(x)  # [batch, seq_len, dk]
        v = self.v(x)  # [batch, seq_len, dv]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [batch, seq_len, seq_len]
        attn = attn.softmax(dim=-1)
        out = attn @ v  # [batch, seq_len, dv]
        return out
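A quick usage sketch for the module above (batch size, sequence length, and dimensions are illustrative assumptions). Note that without an output projection the result lives in the dv-dimensional value space rather than the original dim:

x = torch.randn(2, 10, 512)                    # [batch=2, seq_len=10, dim=512]
attn = Self_Attention(dim=512, dk=64, dv=64)
out = attn(x)
print(out.shape)                               # torch.Size([2, 10, 64])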
Multi-head attention:
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, dim, dk, dv, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.dk = dk
        self.dv = dv
        self.q_linear = nn.Linear(dim, dk * num_heads)
        self.k_linear = nn.Linear(dim, dk * num_heads)
        self.v_linear = nn.Linear(dim, dv * num_heads)
        self.out_linear = nn.Linear(dv * num_heads, dim)

    def forward(self, x):
        B, N, _ = x.shape  # batch, seq_len, dim
        Q = self.q_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2)  # [B, heads, N, dk]
        K = self.k_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2)
        V = self.v_linear(x).view(B, N, self.num_heads, self.dv).transpose(1, 2)
        # Attention
        attn = (Q @ K.transpose(-2, -1)) / (self.dk ** 0.5)  # [B, heads, N, N]
        attn = attn.softmax(dim=-1)
        out = attn @ V  # [B, heads, N, dv]
        out = out.transpose(1, 2).reshape(B, N, self.num_heads * self.dv)  # [B, N, heads*dv]
        out = self.out_linear(out)  # [B, N, dim]
        return out
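And a matching usage sketch for the multi-head version (same illustrative sizes as before). Because of out_linear, the concatenated heads (heads * dv = 512 here) are projected back to dim, so the output shape matches the input and can feed a residual connection:

x = torch.randn(2, 10, 512)                                   # [batch=2, seq_len=10, dim=512]
mha = MultiHeadAttention(dim=512, dk=64, dv=64, num_heads=8)
out = mha(x)
print(out.shape)                                               # torch.Size([2, 10, 512])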