当前位置: 首页 > news >正文

从零打造大语言模型2--编码注意力机制

从零到会用:多头注意力(超通俗 + 逐步实战 + 注释齐全)

读完并跑完本文,你将能:

  1. 说人话解释注意力/自注意力/多头注意力;2) 从零实现单头→多头(两种工程写法);3) 正确处理因果遮挡padding 遮挡;4) 把 MHA 放进一个最小 Transformer Block 并完成一次语言建模训练步;5) 用等价性测试参数量统计给自己的实现做体检。

所有代码均为原创实现,面向小白、循序渐进。


目录

  1. 把问题说人话(注意力到底在干嘛)
  2. 第 0 步:无参数的小玩具(直觉建立)
  3. 第 1 步:可训练的单头自注意力
  4. 第 2 步:因果遮挡padding 遮挡(两种 mask 一次讲清)
  5. 第 3 步:多头注意力·写法 A(Wrapper:复制多份单头再拼接)
  6. 第 4 步:多头注意力·写法 B(工程高效:一次性投影再“拆头”)
  7. 第 5 步:把 MHA 塞进最小 Transformer Block(Pre-LN)
  8. 第 6 步:最小语言模型前向与训练一步(含 label 右移)
  9. 第 7 步:等价性验证输出维度=2参数量统计
  10. 第 8 步:性能与工程选项(SDPA/Flash、qkv_bias、dropout 等)
  11. 常见坑位排查清单(超全)
  12. 发布到 CSDN 的小贴士(配图&排版)

把问题说人话:注意力到底在干嘛?

把每个 token 想成“演员”,它要从全剧里挑出“对我台词最有用的演员”。

  • Q(query):我是谁、我在找什么信息;
  • K(key):别人身上的关键词;
  • V(value):别人提供的真正信息;
  • 打分score = Q · K^T(点积越大,相关性越强);
  • 归一softmax 把每行分数变成权重(每行求和=1);
  • 聚合context = weight @ V 得到“我需要的上下文”。

数值稳定的原因:分母加了 √d_k,让分数在合理范围内,softmax 不会过早饱和。

小抄:“Q 去找 K,向 V 要信息”

Attention整体计算流程

在这里插入图片描述


第 0 步:无参数的小玩具(直觉建立)

下面用随机嵌入演示“看—打分—归一—加权”的三步,不涉及学习参数。

import torch
import torch.nn.functional as F# 一段序列的“嵌入”,B=1 方便看形状
x = torch.tensor([[ [0.4, 0.2, 0.9],[0.6, 0.8, 0.7],[0.2, 0.5, 0.3] ]], dtype=torch.float)  # (B=1, N=3, d=3)Q, K, V = x, x, x  # 无参数版:直接用自身当 Q/K/V
scores = torch.matmul(Q, K.transpose(1,2))            # (1,3,3)
weights = F.softmax(scores / (x.size(-1) ** 0.5), -1) # (1,3,3) 每行和=1
ctx = torch.matmul(weights, V)                         # (1,3,3)
print('scores:', scores.squeeze(0))
print('weights row-sum:', weights.squeeze(0).sum(-1))
print('context:', ctx.shape)

第 1 步:可训练的单头自注意力(含因果遮挡)

一步到位版:三组线性层得到 Q/K/V → 缩放点积 → softmax 前做因果遮挡 → 对 V 加权求和。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass CausalSelfAttention(nn.Module):"""单头因果自注意力(可直接复用)输入: x (B, N, d_in);输出: (B, N, d_out)"""def __init__(self, d_in, d_out, context_len, dropout=0.1, qkv_bias=False):super().__init__()self.d_out = d_outself.q = nn.Linear(d_in, d_out, bias=qkv_bias)self.k = nn.Linear(d_in, d_out, bias=qkv_bias)self.v = nn.Linear(d_in, d_out, bias=qkv_bias)self.dropout = nn.Dropout(dropout)          # 放在注意力权重上# 上三角 mask(对未来位置置 1);注册为 buffer,随模型走但不参与训练self.register_buffer('causal', torch.triu(torch.ones(context_len, context_len), diagonal=1).bool())def forward(self, x):B, N, _ = x.shapeQ = self.q(x)  # (B, N, d_out)K = self.k(x)  # (B, N, d_out)V = self.v(x)  # (B, N, d_out)# (B, N, N)scores = torch.matmul(Q, K.transpose(1, 2)) / (K.size(-1) ** 0.5)# **关键**:softmax 之前做因果遮挡(只取到 N×N)scores = scores.masked_fill(self.causal[:N, :N], float('-inf'))attn = F.softmax(scores, dim=-1)attn = self.dropout(attn)ctx = torch.matmul(attn, V)  # (B, N, d_out)return ctx

第 2 步:两种遮挡一次讲清(因果 vs padding

  • 因果遮挡:防止位置 t 看未来(语言模型必须)。
  • padding 遮挡:当一批序列长度不等,补齐的 pad 位置不应被任何人“看到”。

我们在上面的单头里加入 padding mask 支持:

class CausalSelfAttentionWithPad(CausalSelfAttention):"""在因果遮挡基础上,额外支持 padding mask。pad_mask: (B, N) -> True 表示该位置是 PAD,需要被遮挡。"""def forward(self, x, pad_mask=None):B, N, _ = x.shapeQ = self.q(x); K = self.k(x); V = self.v(x)scores = torch.matmul(Q, K.transpose(1,2)) / (K.size(-1) ** 0.5)  # (B,N,N)scores = scores.masked_fill(self.causal[:N, :N], float('-inf'))if pad_mask is not None:# 把被 padding 的列(被别人看到)屏蔽:在列维度扩展到 (B,1,N)scores = scores.masked_fill(pad_mask[:, None, :].bool(), float('-inf'))attn = F.softmax(scores, dim=-1)attn = self.dropout(attn)return torch.matmul(attn, V)

记忆点:因果屏蔽“未来列”,padding屏蔽“补齐列”。实战里两者常常同时用。


第 3 步:多头注意力 · 写法 A(Wrapper,直观)

思路:复用单头,复制 H 份并行计算,最后在最后一维 cat 再做一次线性融合。

class MultiHeadAttention_Wrapper(nn.Module):"""把多个单头的输出拼接起来,再线性融合。d_head: 每个头的维度;总输出维度 = d_head * H"""def __init__(self, d_in, d_head, context_len, num_heads, dropout=0.1, qkv_bias=False):super().__init__()self.heads = nn.ModuleList([CausalSelfAttention(d_in, d_head, context_len, dropout, qkv_bias)for _ in range(num_heads)])self.out_proj = nn.Linear(d_head * num_heads, d_head * num_heads)def forward(self, x, pad_mask=None):outs = [h(x) for h in self.heads]            # [(B,N,d_head), ...]hcat = torch.cat(outs, dim=-1)               # (B,N,d_head*H)return self.out_proj(hcat)

优点:直观、好调试;缺点:Q/K/V 线性层被复制 H 次,算力/显存更贵。


第 4 步:多头注意力 · 写法 B(工程高效,一次性投影再“拆头”)

思路:Q/K/V 各只投影 一次d_out,随后把最后一维 reshape 成 (H, d_head),在 (B,H,N,d_head) 维度并行计算注意力,最后合并各头并过 out_proj

class MultiHeadAttention(nn.Module):def __init__(self, d_in, d_out, context_len, num_heads, dropout=0.1, qkv_bias=False):super().__init__()assert d_out % num_heads == 0, 'd_out 必须能被 num_heads 整除'self.H = num_headsself.d_out = d_outself.d_head = d_out // num_headsself.q = nn.Linear(d_in, d_out, bias=qkv_bias)self.k = nn.Linear(d_in, d_out, bias=qkv_bias)self.v = nn.Linear(d_in, d_out, bias=qkv_bias)self.out_proj = nn.Linear(d_out, d_out)self.drop = nn.Dropout(dropout)self.register_buffer('causal', torch.triu(torch.ones(context_len, context_len), 1).bool())def forward(self, x, pad_mask=None):B, N, _ = x.shape# 1) 一次性投影Q = self.q(x); K = self.k(x); V = self.v(x)          # (B,N,d_out)# 2) 拆头: (B,N,d_out)->(B,N,H,d_head)->(B,H,N,d_head)def split_heads(t):return t.view(B, N, self.H, self.d_head).transpose(1, 2).contiguous()Q, K, V = map(split_heads, (Q, K, V))                 # (B,H,N,Dh)# 3) 注意力分数 (B,H,N,N)scores = torch.matmul(Q, K.transpose(2,3)) / (self.d_head ** 0.5)# 因果遮挡 + 可选的 padding 遮挡scores = scores.masked_fill(self.causal[:N, :N], float('-inf'))if pad_mask is not None:scores = scores.masked_fill(pad_mask[:, None, None, :].bool(), float('-inf'))attn = torch.softmax(scores, dim=-1)attn = self.drop(attn)# 4) 聚合 (B,H,N,N)@(B,H,N,Dh)->(B,H,N,Dh)ctx = torch.matmul(attn, V)# 5) 合并各头 (B,N,d_out) 并做 out_projctx = ctx.transpose(1, 2).contiguous().view(B, N, self.d_out)return self.out_proj(ctx)

结论:写法 B 更省算/省显存,实际项目优先选它;写法 A 更适合教学与调试。


第 5 步:把 MHA 塞进最小 Transformer Block(Pre-LN 更稳)

class TinyBlock(nn.Module):def __init__(self, d_model, num_heads, context_len, dropout=0.1):super().__init__()self.ln1 = nn.LayerNorm(d_model)self.mha = MultiHeadAttention(d_model, d_model, context_len, num_heads, dropout)self.ln2 = nn.LayerNorm(d_model)self.mlp = nn.Sequential(nn.Linear(d_model, 4*d_model),nn.GELU(),nn.Linear(4*d_model, d_model),nn.Dropout(dropout))def forward(self, x, pad_mask=None):x = x + self.mha(self.ln1(x), pad_mask=pad_mask)  # 残差 1(Pre-LN)x = x + self.mlp(self.ln2(x))                     # 残差 2return x

备注:Pre-LN(先归一化再子层)在深网络里更易训练稳定。


第 6 步:最小语言模型前向与训练一步(含 label 右移)

class TinyLM(nn.Module):def __init__(self, vocab_size=50257, d_model=256, num_heads=8, context_len=256, n_layer=2, dropout=0.1):super().__init__()self.tok_emb = nn.Embedding(vocab_size, d_model)self.pos_emb = nn.Embedding(context_len, d_model)self.blocks  = nn.ModuleList([TinyBlock(d_model, num_heads, context_len, dropout)for _ in range(n_layer)])self.ln_f   = nn.LayerNorm(d_model)self.lm_head = nn.Linear(d_model, vocab_size, bias=False)  # tied 可进一步共享权重self.context_len = context_lendef forward(self, x):  # x: (B,N) token idsB, N = x.size()pos = torch.arange(N, device=x.device)h = self.tok_emb(x) + self.pos_emb(pos)           # (B,N,d_model)for blk in self.blocks:h = blk(h)                                    # (B,N,d_model)h = self.ln_f(h)logits = self.lm_head(h)                          # (B,N,vocab)return logits# === 一次训练步 ===
import torch.optim as optimdef one_training_step(model, x, optimizer):model.train()logits = model(x)                         # (B,N,V)# 语言建模的“右移标签”:预测 x[:, t] 的是原句的下一 token x[:, t+1]target = x[:, 1:].contiguous()            # (B,N-1)pred   = logits[:, :-1, :].contiguous()   # (B,N-1,V)loss = nn.CrossEntropyLoss()(pred.view(-1, pred.size(-1)), target.view(-1))optimizer.zero_grad(set_to_none=True)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)optimizer.step()return loss.item()# 伪造一批输入(实际应来自 DataLoader)
B, N = 2, 32
x_batch = torch.randint(0, 1000, (B, N))
model = TinyLM(vocab_size=1000, d_model=128, num_heads=4, context_len=N, n_layer=2)
opt = optim.AdamW(model.parameters(), lr=3e-4)
print('loss:', one_training_step(model, x_batch, opt))

小贴士:真实语料可用滑窗切片(见文末“数据加载最小版”)。


第 7 步:等价性验证 / 输出维度=2 / 参数量统计

7.1 Linear 版 vs 手写矩阵版:输出可逐元素一致

import torch, torch.nn as nn, torch.nn.functional as Fd_in, d_out = 3, 2
class SA_v1(nn.Module):  # 手写参数矩阵def __init__(self):super().__init__()self.Wq = nn.Parameter(torch.randn(d_in, d_out))self.Wk = nn.Parameter(torch.randn(d_in, d_out))self.Wv = nn.Parameter(torch.randn(d_in, d_out))def forward(self, x):Q, K, V = x@self.Wq, x@self.Wk, x@self.WvA = (Q @ K.t()) / (K.size(-1) ** 0.5)return F.softmax(A, -1) @ Vclass SA_v2(nn.Module):  # Linear 版本def __init__(self):super().__init__()self.q, self.k, self.v = nn.Linear(d_in,d_out,False), nn.Linear(d_in,d_out,False), nn.Linear(d_in,d_out,False)def forward(self, x):Q, K, V = self.q(x), self.k(x), self.v(x)A = (Q @ K.t()) / (K.size(-1) ** 0.5)return F.softmax(A, -1) @ Vx = torch.tensor([[0.43,0.15,0.89],[0.55,0.87,0.66],[0.22,0.58,0.33]], dtype=torch.float)
sa1, sa2 = SA_v1(), SA_v2()
# 把 sa2 的权重转置复制到 sa1(Linear 权重形状是 (out,in))
with torch.no_grad():sa1.Wq.copy_(sa2.q.weight.T)sa1.Wk.copy_(sa2.k.weight.T)sa1.Wv.copy_(sa2.v.weight.T)
print(torch.allclose(sa1(x), sa2(x)))  # True

7.2 我想要最后维度=2怎么办?

两头注意力、每头 d_head=1,拼接后自然=2:

B,N,d_in = 2,6,3
x = torch.randn(B,N,d_in)
mha2 = MultiHeadAttention_Wrapper(d_in, d_head=1, context_len=N, num_heads=2, dropout=0.0)
print(mha2(x).shape)  # torch.Size([2, 6, 2])

7.3 参数量统计(以 d_model=768, heads=12 为例)

def count_params(m):return sum(p.numel() for p in m.parameters() if p.requires_grad)mha = MultiHeadAttention(768, 768, context_len=1024, num_heads=12, dropout=0.0)
print('MHA params ~', count_params(mha))  # 约 2.36M(数量级把握够用)

说明:大模型的参数大头通常在 前馈 MLP 子层,MHA 并非主要来源。


第 8 步:性能与工程选项

  • 内置模块torch.nn.MultiheadAttention 可直接用(注意它默认是 (N,B,E) 维度顺序)。
  • SDPAtorch.nn.functional.scaled_dot_product_attention 可以少写很多样板,还能自动选用更优实现。
  • FlashAttention:当满足显卡/精度要求时可显著提速与省显存(长度较长时收益更明显)。
  • qkv_bias:开启能提高拟合灵活性,但轻微增参;很多开源配置里对 qv 开启 bias。
  • dropout 放哪:常见放在注意力权重上;有时也在 ctx 上加(择一即可)。
  • 位置编码:本文用可学习绝对位置;工业界常用 RoPE 等旋转位置编码以提升长程外推能力。
  • MHA 工厂函数:在项目里用工厂模式便于切换实现:
def make_mha(kind: str, *, d_in, d_out, num_heads, context_len, dropout=0.0, qkv_bias=False):kind = kind.lower()if kind in {'wrapper'}:return MultiHeadAttention_Wrapper(d_in, d_out//num_heads, context_len, num_heads, dropout, qkv_bias)if kind in {'split','mha','fast'}:return MultiHeadAttention(d_in, d_out, context_len, num_heads, dropout, qkv_bias)if kind in {'torch'}:# 适配 PyTorch 原生 MHA(注意维度和 mask 的差异,略)raise NotImplementedError('此分支留给 torch.nn.MultiheadAttention 的适配')raise ValueError(f'unknown kind: {kind}')

常见坑位排查清单(把坑踩在我身上)

  1. mask 放错时机:必须在 softmax 之前 masked_fill(-inf);否则每行不再和为 1。
  2. 形状不匹配:重点核对 (B,H,N,Dh) @ (B,H,Dh,N) -> (B,H,N,N)view/transpose 后记得 .contiguous()
  3. d_out % heads != 0:保证可整除;否则没法均分到每个头。
  4. 越界的 positionpos_emb 的表长要 ≥ context_len,否则索引越界。
  5. 精度/溢出:长序列时 scores 可能很大,记得除以 √d_k;必要时用 bfloat16/amp 并注意数值安全。
  6. padding 遮挡方向:遮挡“列”而不是“行”,即别人不能看到我的 pad
  7. 训练不收敛:检查学习率、梯度裁剪、权重衰减、dropout、是否忘了 LayerNorm,以及数据是否正确右移。
  8. 推理速度慢:用缓存 K/V(增量解码),或选择 SDPA/Flash;批量化与张量化比 Python 循环重要得多。

数据加载最小版(滑窗切片)

import tiktoken
from torch.utils.data import Dataset, DataLoader
import torchclass GPTDatasetV1(Dataset):def __init__(self, txt, tokenizer, max_length, stride):self.input_ids, self.target_ids = [], []tok = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})for i in range(0, len(tok) - max_length, stride):x = tok[i:i+max_length]y = tok[i+1:i+max_length+1]self.input_ids.append(torch.tensor(x))self.target_ids.append(torch.tensor(y))def __len__(self):return len(self.input_ids)def __getitem__(self, idx):return self.input_ids[idx], self.target_ids[idx]def create_dataloader(txt, batch_size=4, max_length=128, stride=128, shuffle=True):tokenizer = tiktoken.get_encoding('gpt2')dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

结语

  • 多头注意力的精髓:点积打分 + 缩放 + softmax + 遮挡 + 多视角并行
  • 工程上:优先“一次性投影再拆头”,把 mask 放在 softmax 前,保证 形状/维度右移标签 正确。
  • 用等价性测试和参数统计,快速验证你的实现是否靠谱。
http://www.dtcms.com/a/334628.html

相关文章:

  • 【基础-判断】可以通过ohpm uninstall 指令下载指定的三方库
  • 中国教育信息化演进历程与发展趋势研究报告
  • Bash常用操作总结
  • 解决html-to-image在 ios 上dom里面的图片不显示出来
  • 《Python 单例模式(Singleton)深度解析:从实现技巧到争议与最佳实践》
  • 【自动化运维神器Ansible】Ansible逻辑运算符详解:构建复杂条件判断的核心工具
  • Manus AI与多语言手写识别的技术突破与行业变革
  • c#Blazor WebAssembly在网页中多线程计算1000万次求余
  • aws(学习笔记第五十一课) ECS集中练习(3)
  • 基于W55MH32Q-EVB 实现 HTTP 服务器配置 OLED 滚动显示信息
  • qsort实现数据排序
  • cuda编程笔记(15)--使用 CUB 和 atomicAdd 实现 histogram
  • PMP-项目管理-十大知识领域:进度管理-制定时间表、优化活动顺序、控制进度
  • 进程替换:从 “改头换面” 到程序加载的底层逻辑
  • 【深度学习计算性能】05:多GPU训练
  • TypeScript快速入门
  • MCP 大模型的扩展坞
  • 洛谷P1595讲解(加强版)+错排讲解
  • php版的FormCreate使用注意事项
  • 基于单片机的防酒驾系统设计
  • NY243NY253美光固态闪存NY257NY260
  • 24. async await 原理是什么,会编译成什么
  • 惠普声卡驱动win10装机完成检测不到声卡
  • Three.js 材质系统深度解析
  • 云原生俱乐部-RH124知识点总结(1)
  • 【CV 目标检测】Fast RCNN模型①——与R-CNN区别
  • 解锁 AI 音乐魔法,三款音乐生成工具
  • 《P4180 [BJWC2010] 严格次小生成树》
  • 服务器配置开机自启动服务
  • 基于深度强化学习的多用途无人机路径优化研究