【大语言模型】—— 自注意力机制及其变体(交叉注意力、因果注意力、多头注意力)的代码实现
【大语言模型】—— 注意力机制及其变体的代码实现
- 摘要
- Self-Attention
- 为什么 dim=-1
- CrossAttention
- 交叉注意力的作用
- CausalAttention
- unsqueeze(0)的作用
- MultiHeadSelfAttention
- 多头自注意力(Multi-Head Attention)的核心思想
摘要
本文介绍了注意力机制的几种变体及其PyTorch代码实现。主要包括:
- Self-Attention:基础自注意力机制,通过Q、K、V计算注意力权重,适用于序列内部建模。
- CrossAttention:让一个序列关注另一个序列,典型应用于Transformer解码器-编码器交互和多模态任务。
- CausalAttention:通过三角掩码实现因果性,确保只能关注当前位置之前的token,适用于自回归生成任务。
- MultiHeadSelfAttention:多头注意力机制,将输入分割到多个子空间并行计算注意力,最后合并结果。
代码实现中详细展示了各注意力的关键操作,包括线性变换、注意力分数计算、softmax归一化和掩码处理等。特别解释了dim=-1的作用、unsqueeze(0)的广播机制等实现细节。
Self-Attention
class SelfAttention(nn.module):
# torch.matmul 专用于批量矩阵乘法,适用于形状为 (batch_size, n, m) 和 (batch_size, m, p) 的 3D 张量。
# torch.matmul 支持更灵活的张量乘法运算def __init__(self, input_dim, dim_k,dim_v):super().__init__()self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)selv.v = nn.Linear(input_dim, dim_v)self.scale = np.sqrt(dim_k)def forward(self, x):Q = self.q(x)K = self.k(x)V = self.v(x)atten = torch.softmax(torch.matmul(Q, K.permute(0,2,1))/self.scale, dim=-1)return torch.matmul(atten, V)
为什么 dim=-1
在自注意力机制中,nn.Softmax(dim=-1)
的作用是对 注意力分数矩阵 进行归一化,使得每一行的权重之和为 1。这里 dim=-1
表示在最后一个维度(即 seq_len
维度)上进行 Softmax
计算。
在 Q K T QK^T QKT计算后,得到的注意力分数矩阵的形状是 [batch_size, seq_len, seq_len]
,其中:
第 1 个 seq_len(dim=1)
:代表 Q的序列长度(即当前 token 的位置)。
第 2 个 seq_len(dim=2)
:代表 K的序列长度(即被计算注意力的 token 的位置)。
dim=-1
(即 dim=2
)表示 对每个 token
计算它对所有 token
的注意力权重,即 对每一行进行 Softmax
,使得:
每一行的所有值加起来等于 1(概率分布)。
这样,每个 token 的注意力权重是独立计算的。
假设 Q K T QK^T QKT的结果是:
[[[1.0, 0.5, 0.2], # Token 0 对所有 token 的注意力分数[0.3, 1.2, 0.7], # Token 1 对所有 token 的注意力分数[0.1, 0.4, 1.5]] # Token 2 对所有 token 的注意力分数
]
应用 nn.Softmax(dim=-1)
后:
[[[0.55, 0.27, 0.18], # Token 0 的注意力权重(总和=1)[0.16, 0.58, 0.26], # Token 1 的注意力权重(总和=1)[0.07, 0.20, 0.73]] # Token 2 的注意力权重(总和=1)
]
这样,每个 token 的注意力权重是独立的,且所有 token 对它的影响权重之和为 1。
CrossAttention
# 查询通常来自解码器,键和值通常来自编码器
import torch
import torch.nn as nn
import numpy as npclass CrossAttention(nn.Module):def __init__(self, input_dim, dim_k, dim_v):super().__init__() # 必须调用父类初始化self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self.scale = np.sqrt(dim_k)def forward(self, x1, x2):Q1 = self.q(x1) # [batch_size, seq_len1, dim_k]K2 = self.k(x2) # [batch_size, seq_len2, dim_k]V2 = self.v(x2) # [batch_size, seq_len2, dim_v]# 计算注意力分数atten = torch.softmax(torch.matmul(Q1, K2.permute(0, 2, 1)) / self.scale, dim=-1) # [batch_size, seq_len1, seq_len2]# 加权求和return torch.matmul(atten, V2) # [batch_size, seq_len1, dim_v]
交叉注意力的作用
交叉注意力用于 让一个序列 x 1 x1 x1关注另一个序列 x 2 x2 x2,典型应用包括:
- Transformer 解码器:
- x1= 解码器的输入(当前生成的 token)
- x2= 编码器的输出(源序列的表示)
- 解码器通过交叉注意力关注编码器的信息。
- 多模态任务(如视觉-语言模型):
- x1= 文本序列
- x2= 图像特征
- 文本通过交叉注意力关注图像的关键区域。
CausalAttention
class CausalAttention(nn.Module):def __init__(self,input_dim, dim_k,dim_v):super().__init__()self.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self.scale = np.sqrt(dim_k)def forward(self, x):# x: [batch_size, seq_len, input_dim]Q = self.q(x) # [batch_size, seq_len, dim_k]K = self.k(x) # [batch_size, seq_len, dim_k]V = self.v(x) # [batch_size, seq_len, dim_v]# 注意力分数atten = torch.matmul(Q, K.permute(0, 2, 1)) / self.scale # [batch, seq, seq]# 下三角 mask,确保因果性(只能看到之前的token)seq_len = atten.size(-1)mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)atten = atten.masked_fill(mask == 0, float('-inf'))# softmax 归一化atten = nn.Softmax(dim=-1)(atten)# 输出加权和return torch.matmul(atten, V) # [batch, seq, dim_v]
unsqueeze(0)的作用
假设我们有一个 2D 张量 mask,形状是 [seq_len, seq_len]
mask = torch.tril(torch.ones(seq_len, seq_len)) # 形状 [seq_len, seq_len]
如果我们想让它变成 [1, seq_len, seq_len]
(即增加一个 batch 维度),可以使用:
mask = mask.unsqueeze(0) # 形状变为 [1, seq_len, seq_len]
这样做的目的是:
- 匹配注意力分数矩阵的形状(
atten
的形状是[batch_size, seq_len, seq_len]
)。 - 支持批量计算,因为
mask
需要广播到所有batch
样本。
这样:
mask的形状变成 [1, seq_len, seq_len]
。
PyTorch 会自动广播 mask到 [batch_size, seq_len, seq_len]
,使其与 atten形状匹配。
MultiHeadSelfAttention
class MultiHeadAttention(nn.Module):def __init__(self, heads, input_dim, dim_k, dim_v):super().__init__()self.heads = headsself.dim_k_per_head = dim_k // headsself.dim_v_per_head = dim_v // headsself.q = nn.Linear(input_dim, dim_k)self.k = nn.Linear(input_dim, dim_k)self.v = nn.Linear(input_dim, dim_v)self.scale = np.sqrt(self.dim_k_per_head)self.out = nn.Linear(dim_v, input_dim)def forward(self, x):batch_size = x.size(0)Q = self.q(x)#[batch_size, seq_len, dim_k]K = self.k(x)V = self.v(x)#[batch_size, seq_len, heads, dim_k_per_head] # --> [batch_size, heads, seq_len, dim_k_per_head]Q = Q.view(batch_size, -1, self.heads, self.dim_k_per_head).permute(0,2,1,3)#[batch_size, seq_len, heads, dim_k_per_head] # --> [batch_size, heads, seq_len, dim_k_per_head]K = K.view(batch_size, -1, self.heads, self.dim_k_per_head).permute(0,2,1,3)#[batch_size, seq_len, heads, dim_v_per_head] # --> [batch_size, heads, seq_len, dim_v_per_head]V = V.view(batch_size, -1, self.heads, self.dim_v_per_head).permute(0,2,1,3)#转置[batch_size, heads, seq_len, dim_k_per_head] # --> [batch_size, heads, dim_k_per_head, seq_len]K = K.permute(0, 1, 3, 2)# [batch_size, heads, seq_len, seq_len]atten = torch.softmax(torch.matmul(Q,K) / self.scale, dim = -1)# [batch_size, heads, seq_len, dim_v_per_head]out = torch.matmul(atten, V)# [batch_size, seq_len, heads, dim_v_per_head]out = out.permute(0, 2, 1, 3).contiguous()# [batch_size, seq_len, heads* dim_v_per_head]out = out.view(batch_size, -1, self.heads * self.dim_v_per_head)return self.out(out) # [batch_size, seq_len, input_dim]
多头自注意力(Multi-Head Attention)的核心思想
多头自注意力(Multi-Head Attention)的核心思想是将输入向量分别映射为查询 Q、键 K、值 V,再按照头数切分到多个子空间中;每个头独立计算注意力分数并得到加权表示,最后拼接各头的结果,通过线性层 out
映射回输入维度,从而捕捉序列中多角度的相关性。
在使用时需要注意以下几点:
- 维度整除:要确保
dim_k
和dim_v
能被heads
整除,否则view
时会报错。 - 缩放因子:缩放应该基于每个头的维度
sqrt(dim_k_per_head)
,而不是整体的dim_k
。 - Softmax 顺序:正确做法是
Softmax(QK^T / scale)
,不要写成Softmax(QK^T) / scale
。 - 张量连续性:
permute
之后用.contiguous().view()
,否则可能报错;或者用reshape
自动处理。 - 输出层作用:
self.out
的作用是把多头拼接后的结果重新映射回输入维度,保持层间维度一致。