NLP:Transformer之多头注意力(特别分享4)
本文目录:
- 一、先建立“分头-并行-拼接”直觉
- 二、符号与形状(batch-first 写法,易读)
- 三、公式(单头 → 多头)
- 四、Python 裸写(无库,带 mask 可跑)
- 五、例子:为什么“多头”比“一头”好?
- 六、与 Efficient 变体关系
- 七、总结
前言:今晚多分享一篇关于Transformer多头注意力的文。
一、先建立“分头-并行-拼接”直觉
Input X ──► Linear(Q,K,V) ──► Split heads ──► Scaled Dot-Product each head│▼Output Z ◄── Concat heads ◄── Parallel softmax(QK^T/sqrt(d_k))V
一句话:把原来的 (Q,K,V) 拆成 h 个“小通道”并行算注意力,最后再拼回来。
二、符号与形状(batch-first 写法,易读)
张量 | 形状 | 含义 |
---|---|---|
X | (B, L, d_model) | 输入序列,L 个 token |
W_Q | (d_model, d_k·h) | 查询映射矩阵 |
W_K | (d_model, d_k·h) | 键映射矩阵 |
W_V | (d_model, d_v·h) | 值映射矩阵 |
Q/K/V | (B, h, L, d_k) | 拆头后,每个头维度 d_k = d_model / h |
O | (B, h, L, d_v) | 单头注意力输出 |
W_O | (d_v·h, d_model) | 输出投影,把拼接结果压回 d_model |
Z | (B, L, d_model) | 最终输出 |
三、公式(单头 → 多头)
1. 单头 Scaled Dot-Product
Attention(Q,K,V) = softmax( (QK^T) / √d_k ) V
2. 多头
head_i = Attention( XW_Q[:,i], XW_K[:,i], XW_V[:,i] )
MultiHead(X) = Concat(head_1,…,head_h) W_O
四、Python 裸写(无库,带 mask 可跑)
import torch, mathdef multi_head_attention(x, W_q, W_k, W_v, W_o, h, mask=None):B, L, d = x.shaped_k = d // h# 1. 线性投影 + 拆头q = (x @ W_q).view(B, L, h, d_k).transpose(1, 2) # (B,h,L,d_k)k = (x @ W_k).view(B, L, h, d_k).transpose(1, 2)v = (x @ W_v).view(B, L, h, d_k).transpose(1, 2)# 2. 缩放点积scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k) # (B,h,L,L)if mask is not None: # 下三角 or padding maskscores = scores.masked_fill(mask==0, -1e9)attn = torch.softmax(scores, dim=-1)out = attn @ v # (B,h,L,d_k)# 3. 拼接 + 输出投影out = out.transpose(1, 2).contiguous().view(B, L, d)return out @ W_o
五、例子:为什么“多头”比“一头”好?
任务:翻译 “The animal didn’t cross the street because it was too tired.”
要让 it
指 animal
而不是 street
。
- 单头:只能抓一种相似度,可能把
it
和street
的“位置接近”当成高权重。 - 多头:
head-1 专注“语法位置” → 发现it
与animal
主语对齐;
head-2 专注“语义相似” → 发现it
与animal
embeddings 更接近;
拼接后投票,错误概率显著下降。
→ 相当于ensemble of attention mechanisms,每个头学不同的子空间表示。
六、与 Efficient 变体关系
版本 | 改动 | 效果 |
---|---|---|
Multi-Query Attention (MQA) | 所有头共享同一 K/V | 推理显存↓30–50%,速度↑ |
Grouped-Query Attention (GQA) | 分组共享 K/V | 平衡质量与速度,Llama-2/3 用 |
FlashAttention | 分块+重计算,O(N²)→O(N) 显存 | 长序列神器,训练提速 2–4× |
七、总结
Multi-Head Attention 就是把单头注意力复制 h 次,让每一路专注不同子空间,再 ensemble 结果;
实现上只是** reshape + 并行矩阵乘 + concat**,却成为 Transformer 表达能力的核心来源。
本文分享到此结束。