手写self-attention的三重境界
引言
self-attention在实现过程中有很多细节,不同的面试对self-attention实现的要求也不一样。所以我们要学会多种self-attention实现的方式,以此来告诉面试官,我们是了解self-attention的细节的。
self-attention的公式
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}( \frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dkQKT)V
代码实现
第一重境界:简化版本
import math
import torch
import torch.nn as nnclass SelfAttentionV1(nn.Module):def __init__(self, hidden_dim: int = 728) -> None:super().__init__()self.hidden_dim = hidden_dim# 初始化三个不同的线性应用层self.query_proj = nn.Linear(hidden_dim, hidden_dim)self.key_proj = nn.Linear(hidden_dim, hidden_dim)self.value_proj = nn.Linear(hidden_dim, hidden_dim)def forward(self, x):# x shape is: (batch_size, seq_len, hidden_dim)# 获取不同的Q, K, VQ = self.query_proj(x)K = self.key_proj(x)V = self.value_proj(x)# Q, K, V shape: (batch_size, seq_len, hidden_dim)# (batch_size, seq_len, hidden_dim) * (batch_size, hidden_dim, seq_len) = (batch_size, seq_len, seq_len)attention_value = torch.matmul(Q, K.transpose(-1, -2))# 计算注意力分数attention_weight = torch.softmax(attention_value / math.sqrt(self.hidden_dim), dim=-1)# 计算结果 shape: (batch_size, seq_len, hidden_dim)output = torch.matmul(attention_weight, V)return output
第一重境界比较简单,完全对着公式实现就可以了。
第二重境界:效率优化
对QKV矩阵进行合并,然后再拆分
class SelfAttentionV2(nn.Module):def __init__(self, hidden_dim):super().__init__()self.hidden_dim = hidden_dimself.proj = nn.Linear(hidden_dim, hidden_dim * 3)def forward(self, x):# X shape: (batch_size, seq_len, hidden_dim)# QKV shape (batch_size, seq_len, hidden_dim*3)QKV = self.proj(X)Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)attention_weight = torch.softmax(torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.hidden_dim), dim=-1)output = attention_weight @ Vreturn output
第三重:加入一些细节(面试写法)
除了公式外,还有一些细节:
- 加入dropout
- 每个句子长度不一样,加入attention mask
- output矩阵映射
class SelfAttentionV3(nn.Module):def __init__(self, hidden_dim, dropout_rate=0.1):super().__init__()self.hidden_dim = hidden_dimself.proj = nn.Linear(hidden_dim, hidden_dim * 3)self.attention_dropout = nn.Dropout(dropout_rate)self.output_proj = nn.Linear(hidden_dim, hidden_dim)def forward(self, x, attention_mask=None):# x shape: (batch_size, seq_len, hidden_dim)QKV = self.proj(x)Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.hidden_dim)# 如果attention_mask不是None,那就要给那些被mask掉的词语一个非常非常小的值,这样做完softmax以后这些值就是0if attention_mask is not None:attention_weight = attention_weight.masked_fill(attention_mask == 0,float("-1e20"))attention_weight = torch.softmax(attention_weight, dim=-1)# 做dropoutattention_weight = self.attention_dropout(attention_weight)attention_result = attention_weight @ Voutput = self.output_proj(attention_result)return output
从 V1 到 V3 的核心优化脉络(迭代逻辑)
- 第一阶段:工程效率优化(V1 → V2)
- 优化点:将 3 个独立线性层合并为 1 个合并线性层,再 split 拆分 QKV。
- 核心逻辑:数学上完全等价(仅权重拼接),但减少了内核启动次数、内存碎片化,提升硬件并行效率(GPU 更易利用批量矩阵乘法算力)。
- 价值:从 “教学级冗余实现” 转向 “工程化高效实现”,无性能损失,仅提升效率。
- 第二阶段:功能完整性优化(V2 → V3)
- 优化点 1:新增 attention_mask 支持(修正笔误后用 masked_fill)。
- 解决问题:适配实际场景(NLP 批量 Padding、生成任务因果掩码),屏蔽无效位置干扰。
- 优化点 2:新增注意力权重 Dropout。
- 解决问题:正则化,防止模型过度依赖少数关键位置,缓解过拟合。
- 优化点 3:新增输出线性投影 output_proj。
- 解决问题:对注意力聚合后的特征做 “精炼”,增强模型表达能力,适配深层网络堆叠。
- 价值:从 “仅追求效率” 转向 “可落地工业级功能”,覆盖批量训练、泛化能力等核心需求。
