当前位置: 首页 > news >正文

Transformer之多头注意力机制和位置编码(二)

Transformer之多头注意力机制和位置编码(二)

文章目录

  • Transformer之多头注意力机制和位置编码(二)
  • 一、 多头注意力(Multi-Head Attention)
    • 案例
  • 二、位置编码(Positional Encoding)
    • 2.1 固定正弦/余弦
    • 2.2 可学习编码


一、 多头注意力(Multi-Head Attention)

先把整句映射成 Q、K、V,再按列切分 → 多头并行计算

核心公式

headh=Attention(QWhQ,KWhK,VWhV)MultiHead(Q,K,V)=[head1;…;headH]WO\begin{aligned} \mathrm{head}_h &= \mathrm{Attention}(Q W^Q_h,\,K W^K_h,\,V W^V_h) \\[2pt] \mathrm{MultiHead}(Q,K,V) &= [\mathrm{head}_1;\dots;\mathrm{head}_H]W^O \end{aligned} headhMultiHead(Q,K,V)=Attention(QWhQ,KWhK,VWhV)=[head1;;headH]WO

步骤速览

  1. 分头映射dim=512, head_num=8 → 每头 d_k=64
    head_num, d_k = 8, dim // 8
    W_Q = nn.Linear(dim, dim)
    Q_h = W_Q(x).view(b, seq, head_num, d_k).transpose(1, 2)  # [b, 8, seq, 64]
    # K_h, V_h 同理
    
  2. 并行注意力(缩放点积 + softmax)
  3. 拼接 + 线性
    out = out.transpose(1, 2).contiguous().view(b, seq, dim)
    out = nn.Linear(dim, dim)(out)          # 最终输出 [b, seq, 512]
    
2 个头 = 2 组独立 (Q,K,V) 子空间

表达能力
不同头关注不同模式(句法、语义、指代…),组合后更灵活。

案例

下面给出可直接复用的“多头注意力”最小实现。

特点

  • nn.Module 封装,方便后续放进 nn.Sequential 或 Transformer;
  • 逐行中文注释,一眼看懂每一步在干什么;
  • 输出维度与原句向量一致 [batch, seq_len, dim],后续可继续堆叠。
import math
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):"""简化版多头注意力dim        : 模型总维度(你的例子里是 256)head_num   : 头数(你的例子里是 16)输出形状   : [batch, seq_len, dim]"""def __init__(self, dim: int = 256, head_num: int = 16):super().__init__()# assert :“必须保证 dim 能被 head_num 整除,否则就报错。”assert dim % head_num == 0 self.dim = dimself.head_num = head_numself.d_k = dim // head_num          # 每个头的维度# 3 个线性层一次性把 Q/K/V 投影出来(比 ModuleList 更简洁)self.W_q = nn.Linear(dim, dim)self.W_k = nn.Linear(dim, dim)self.W_v = nn.Linear(dim, dim)# 最后的输出线性变换self.W_o = nn.Linear(dim, dim)def forward(self, x):"""x : [batch, seq_len, dim]return : 与 x 形状相同"""batch, seq_len, _ = x.shape# 1) 线性投影 → [batch, seq_len, dim]Q = self.W_q(x)K = self.W_k(x)V = self.W_v(x)# 2) 拆成多头 → [batch, head_num, seq_len, d_k]def reshape(t):return t.view(batch, seq_len, self.head_num, self.d_k).transpose(1, 2)Q = reshape(Q)   # [batch, head_num, seq_len, d_k]K = reshape(K)V = reshape(V)# 3) 缩放点积注意力scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)  # [B,h,seq,seq]attn = F.softmax(scores, dim=-1)                                      # 归一化out = torch.matmul(attn, V)                                           # [B,h,seq,d_k]# 4) 合并多头 → [batch, seq_len, dim]out = out.transpose(1, 2).contiguous().view(batch, seq_len, self.dim)# 5) 最后线性变换return self.W_o(out)# ------------------ 测试 ------------------
if __name__ == "__main__":sentences = ["i am an NLPer"]vocab = sorted({w for sent in sentences for w in sent.split()})word2idx = {w: i for i, w in enumerate(vocab)}indices = torch.tensor([[word2idx[w] for w in sent.split()] for sent in sentences])dim = 256embedding = nn.Embedding(len(vocab), dim)x = embedding(indices)                  # [1, 4, 256]mha = MultiHeadAttention(dim=256, head_num=16)y = mha(x)                              # [1, 4, 256]print("输入:", x.shape)print("输出:", y.shape)

运行结果

输入: torch.Size([1, 4, 256])
输出: torch.Size([1, 4, 256])

二、位置编码(Positional Encoding)

自注意力本身“无顺序”,需显式注入位置信号。

2.1 固定正弦/余弦

给定位置 i 与维度 2j / 2j+1
PE(i,2j)=sin⁡⁣(i100002j/d)PE(i,2j+1)=cos⁡⁣(i100002j/d)\begin{aligned} PE(i,2j) &= \sin\!\left(\dfrac{i}{10000^{2j/d}}\right) \\[4pt] PE(i,2j+1) &= \cos\!\left(\dfrac{i}{10000^{2j/d}}\right) \end{aligned} PE(i,2j)PE(i,2j+1)=sin(100002j/di)=cos(100002j/di)

  • 例子:句子 <BOS> 我 喜欢 自然语言 处理(N=5, d=512)
    计算得到的 5×512 位置矩阵与词嵌入逐位相加即可。
    说明:

  • iii 是序列中位置的索引(从 000 开始)。

  • jjj 是词向量的维度索引(从 000d/2−1d/2 - 1d/21)。

  • 100001000010000 是一个超参数,用于控制频率的衰减。

句子长度为 555,编码向量维数 D=4D=4D=4
  • 外推特性
    已知 PE(pos)PE(pos)PE(pos) 可线性组合得到 PE(pos+k)PE(pos+k)PE(pos+k),模型可处理比训练集更长的句子。
505050 个词嵌入,维度 512512512 的位置编码热力图

2.2 可学习编码

直接把位置当 token 训,表现好但依赖最大长度超参。


小结
多头 = “多组独立子空间”并行注意力;
位置编码 = “给并行计算加上顺序感”。二者配合让 Transformer 既能并行又能保持序列有序。

http://www.dtcms.com/a/331072.html

相关文章:

  • Canon PowerShot D30相机 CHDK 固件 V1.4.1
  • 5.Ansible-playbook-模块介绍(知识点补充)
  • 【Postgresql】实现 PostgreSQL 全量审计日志:记录所有 SQL 操作及来源
  • 【C++】细说继承(2w字详解)
  • ROS机器人云实践案例博客建议和范文-AI版本
  • imx6ull-驱动开发篇24——Linux 中断API函数
  • MATLAB绘制各种心形曲线
  • window显示驱动开发—在混合系统中使用跨适配器资源
  • nginx-集成prometheus监控(k8s)
  • GitHub 热榜项目 - 日榜(2025-08-14)
  • 一、linux内存管理学习(1):物理内存探测
  • 京东商品列表API开发指南
  • OpenCV对椒盐处理后的视频进行均值滤波处理
  • Opencv 边界填充 图像运算 阈值处理 和图像平滑处理
  • 文件上传接口接收不到文件入参
  • 题解:P4777 【模板】扩展中国剩余定理(EXCRT)
  • Qt项目查找依赖库打包
  • IDEA、Pycharm、DataGrip等激活破解冲突问题解决方案之一
  • Springboot项目重启后Session依旧存在
  • Python包性能优化与并发编程:构建高性能应用的核心技术(续)
  • 轻量级开源全文搜索引擎:Manticore Search 入门介绍
  • C++基础(①入门教程)
  • 本地jar导入到本地仓科和远程仓库
  • Maven学习笔记
  • 92、23种设计模式-单例模式
  • 项目日志框架与jar中日志框架冲突 解决
  • 《多级缓存架构设计与实现全解析》
  • 自动化测试|持续集成Git使用详解
  • label studio 服务器端打开+xshell端口转发设置
  • 01数据结构-最短路径Dijkstra