当前位置: 首页 > news >正文

手写multi-head Self-Attention,各个算子详细注释版

文章目录

    • MultiHeadAttentionFormal的实现
    • 操作详解
      • 1. 🔍 attention_mask
      • 2. 🔍 matmul
        • ✅ 其他实现方式
          • 1. 使用 `@` 运算符(推荐简洁写法)
          • 2. 使用 `torch.einsum()`(爱因斯坦求和约定)
          • 3. 使用 `torch.bmm()`(批量矩阵乘法)
          • 4. 使用 `unsqueeze` + `squeeze` 控制维度(兼容高维)
          • 5. 使用 `F.linear()` 实现投影(不常用)
        • 📌 对比总结表
        • 💡 示例对比(均等效)
      • 3. 🔍 transpose
        • 📌 定义
        • 🧠 在多头注意力中的典型应用场景
        • ✅ 其他实现方式
          • 1. 使用 `permute(*dims)` —— 更灵活的维度重排
          • 2. 使用 `swapaxes(dim0, dim1)` —— 与 transpose 等效
        • 📌 总结对比表
        • 💡 示例说明
        • 🛠 实际应用建议
      • 4. 🔍 view()
        • 🔄 其他等效实现方式
          • 1. `torch.reshape(tensor, shape)`
          • 2. 使用 `flatten(start_dim, end_dim)` 合并维度
          • 3. 使用 `einops.rearrange`(推荐用于可读性)
        • ✅ 总结对比
        • 💡 实际应用建议
      • 5. 🔍 masked_fill()
        • 🧠 函数定义
        • 示例解析
        • ✅ 实际案例演示
        • ⚠️ 注意事项
        • 💡 应用场景
        • ✅ 总结
        • 📌 最佳实践建议
      • 参考材料

MultiHeadAttentionFormal的实现

import torch
import torch.nn as nn
import mathclass MultiHeadAttentionFormal(nn.Module):def __init__(self, hidden_dim, head_num, attention_dropout=0.1):super().__init__()self.hidden_dim = hidden_dimself.head_num = head_numself.head_dim = hidden_dim // head_num  # head_num * head_dim = hidden_dimself.q_proj = nn.Linear(hidden_dim, hidden_dim)  # (hidden_dim, head_dim * head_num)self.k_proj = nn.Linear(hidden_dim, hidden_dim)self.v_proj = nn.Linear(hidden_dim, hidden_dim)self.output = nn.Linear(hidden_dim, hidden_dim)self.attention_dropout = nn.Dropout(attention_dropout)def forward(self, x, attention_mask=None):# X (batch_size, seq_len, hidden_dim)batch_size, seq_len, _ = x.shape# Q/K/V的shape: (batch_size, seq_len, hidden_dim)Q = self.q_proj(x)K = self.k_proj(x)V = self.v_proj(x)# (batch_size, seq_len, hidden_dim),其中 hidden_dim = head_num * head_dim# -> (batch_size, seq_len, head_num, head_dim)# -> (batch_size, head_num, seq_len, head_dim)q_state = Q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)k_state = K.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)v_state = V.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)# k_state的转置# (batch_size, head_num, seq_len, head_dim)# -> (batch_size, head_num, head_dim, seq_len)# 相乘的结果,shape为(batch_size, head_num, seq_len, seq_len)atten_weight = torch.matmul(q_state, k_state.transpose(-2, -1)) / math.sqrt(self.head_dim)print("stage1, atten_weight.shape: ", atten_weight.shape)if attention_mask is not None:atten_weight = atten_weight.masked_fill(attention_mask==0, float("-inf"))print("stage2, atten_weight.shape: ", atten_weight.shape)atten_weight = torch.softmax(atten_weight, dim=-1)print("stage3, atten_weight.shape: ", atten_weight.shape)atten_weight = self.attention_dropout(atten_weight)print("stage4, atten_weight.shape: ", atten_weight.shape)# atten_weight: (batch_size, head_num, seq_len, seq_len)# v_state: (batch_size, head_num, seq_len, head_dim)# => (batch_size, head_num, seq_len, head_dim)output_mid = torch.matmul(atten_weight, v_state)print("stage1, output_mid.shape: ", output_mid.shape, "v_state.shape: ", v_state.shape)# transpose后,张量的内存可能变得不连续,所以需要用contiguous把内存连续化;view()、reshape()、flatten()、torch.nn.Linear、torch.matmul 等操作对输入张量有连续性的要求。output_mid = output_mid.transpose(1, 2).contiguous()print("stage2, output_mid.shape: ", output_mid.shape)output_mid = output_mid.view(batch_size, seq_len, self.hidden_dim)print("stage3, output_mid.shape: ", output_mid.shape)output = self.output(output_mid)return outputattention_mask = torch.tensor([[1,1],[1,0],[1,0]]
).unsqueeze(1).unsqueeze(2).expand(3, 8, 2, 2)# batch_size, seq_len, hidden_dim
X = torch.rand(3, 2, 128)net = MultiHeadAttentionFormal(128, 8)  # hidden_dim = 128, head_num = 8 -> head_dim = 16
net(X, attention_mask)

操作详解

1. 🔍 attention_mask

首先是创建一个随机张量,shape为(batch_size, seq_len)

attention_mask = torch.tensor([[1, 1],[1, 0],[1, 0]
])这是一个形状为 (3, 2) 的张量。
每一行表示一个样本(batch)的 attention mask:
1 表示该位置是有效的;
0 表示该位置是 padding,需要被屏蔽

然后增加维度

attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)---------------------
tensor([[[[1, 1]]],[[[1, 0]]],[[[1, 0]]]])第一次 unsqueeze(1):增加第 1 维度(head_num),形状变为 (3, 1, 2)
第二次 unsqueeze(2):增加第 2 维度(seq_len),形状变为 (3, 1, 1, 2)
此时维度含义为:(batch_size, 1, 1, seq_len)
注意:此时还没有考虑 head_num,只是准备好了 mask 的基本结构

现在扩展到head_num

attention_mask = attention_mask.expand(3, 8, 2, 2)
-------------
tensor([[[[1, 1],          [1, 1]], # 头1[[1, 1],          [1, 1]], # 头2[[1, 1],          [1, 1]], # 头3[[1, 1],          [1, 1]], # 头4[[1, 1],          [1, 1]], # 头5[[1, 1],          [1, 1]], # 头6[[1, 1],          [1, 1]], # 头7[[1, 1],          [1, 1]], # 头8][[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]]],[[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]],[[1, 0],          [1, 0]]]])expand() 是 PyTorch 中用于广播张量的方法,不会复制数据,而是共享内存。
将 (3, 1, 1, 2) 扩展为 (3, 8, 2, 2)3:batch size
8:head_num,每个 head 都使用相同的 mask
2:query 的序列长度(seq_len)
2:key/value 的序列长度(seq_len)attention_mask.shape为(batch_size, head_num, seq_len, seq_len)

2. 🔍 matmul

以这行代码为例:

output_mid = torch.matmul(atten_weight, v_state)

其中:

  • atten_weight.shape = (batch_size, head_num, seq_len, seq_len),即注意力权重矩阵(通常是 softmax 后的结果)
  • v_state.shape = (batch_size, head_num, seq_len, head_dim),即 value 的状态

这个操作本质上是将 attention weight 与 value 进行矩阵乘法,得到加权后的输出。


✅ 其他实现方式
1. 使用 @ 运算符(推荐简洁写法)
output_mid = atten_weight @ v_state
  • 等价于 torch.matmul
  • 更加 Pythonic,代码更简洁
  • 支持广播机制

2. 使用 torch.einsum()(爱因斯坦求和约定)
output_mid = torch.einsum('bhij,bhjd->bhid', atten_weight, v_state)
  • 非常灵活,适用于多头注意力、交叉注意力等复杂结构
  • 显式控制每个维度的运算规则,可读性略差但表达能力更强
  • 在调试或构建复杂模型时非常有用

3. 使用 torch.bmm()(批量矩阵乘法)
# 将 batch 和 head 合并成一个大 batch 维度
batch_size, head_num, seq_len, _ = atten_weight.shape
atten_weight_flat = atten_weight.view(-1, seq_len, seq_len)  # shape: (B*H, T, T)
v_state_flat = v_state.view(-1, seq_len, head_dim)            # shape: (B*H, T, D)output_flat = torch.bmm(atten_weight_flat, v_state_flat)      # shape: (B*H, T, D)
output_mid = output_flat.view(batch_size, head_num, seq_len, head_dim)
  • 只支持 3D 张量,不支持自动广播
  • 性能接近 matmul,但需要手动处理维度变形

4. 使用 unsqueeze + squeeze 控制维度(兼容高维)
output_mid = torch.matmul(atten_weight.unsqueeze(-2), v_state.unsqueeze(-1)
).squeeze(-1)
  • 通过添加/删除维度来精确控制 matmul 操作维度
  • 适合在图像、视频等 attention 中使用

5. 使用 F.linear() 实现投影(不常用)

虽然不是标准做法,但如果 atten_weight 是某种投影权重矩阵,也可以用线性层模拟。但在 attention 中通常不适用。


📌 对比总结表
方法输入要求是否支持 batch是否支持 broadcasting推荐用于 Attention
torch.matmul任意维度✅✅✅
@任意维度✅✅✅(简洁)
torch.einsum需要指定索引✅✅✅(多头)
torch.bmm必须为 3D✅(简单 attention)
unsqueeze + matmul手动控制维度✅(特殊场景)

💡 示例对比(均等效)
# 原始写法
output_mid = torch.matmul(atten_weight, v_state)# 使用 @ 符号
output_mid = atten_weight @ v_state# 使用 einsum
output_mid = torch.einsum('bhij,bhjd->bhid', atten_weight, v_state)# 使用 bmm(需 flatten + reshape)
batch_size, head_num, seq_len, _ = atten_weight.shape
atten_weight_flat = atten_weight.view(-1, seq_len, seq_len)
v_state_flat = v_state.view(-1, seq_len, head_dim)
output_flat = torch.bmm(atten_weight_flat, v_state_flat)
output_mid = output_flat.view(batch_size, head_num, seq_len, -1)

3. 🔍 transpose

output_mid = output_mid.transpose(1, 2)

这行代码的作用是交换张量的第 1 维和第 2 维。用于处理多头注意力(Multi-Head Attention)中张量形状的调整。


📌 定义
torch.Tensor.transpose(dim0, dim1) -> Tensor
  • 功能:返回一个新的张量,其中指定的两个维度被交换。
  • 参数
    • dim0: 第一个维度
    • dim1: 第二个维度

⚠️ 注意:这个操作不会复制数据,而是返回原始张量的一个视图(view)。如果后续需要使用 view()reshape(),可能需要调用 .contiguous() 来确保内存连续。


🧠 在多头注意力中的典型应用场景
# 假设 input shape: (batch_size, head_num, seq_len, head_dim)
output_mid = output_mid.transpose(1, 2)

原始形状:

output_mid.shape = (batch_size, head_num, seq_len, head_dim)

转置后形状:

output_mid.shape = (batch_size, seq_len, head_num, head_dim)

然后一般会进行 view() 操作来合并 head_numhead_dim,得到最终输出:

output_mid = output_mid.contiguous().view(batch_size, seq_len, -1)
# 最终 shape: (batch_size, seq_len, hidden_dim)

这是将多头注意力结果重新拼接回原始隐藏层大小的关键步骤。


✅ 其他实现方式

除了使用 transpose(),还有以下几种方法可以实现类似效果:

1. 使用 permute(*dims) —— 更灵活的维度重排
output_mid = output_mid.permute(0, 2, 1, 3)
  • permute() 可以一次重排多个维度
  • 示例前后的 shape 对应关系:
    # 原 shape: (batch_size, head_num, seq_len, head_dim)
    # 新 shape: (batch_size, seq_len, head_num, head_dim)
    

✅ 推荐用于更复杂的维度变换场景


2. 使用 swapaxes(dim0, dim1) —— 与 transpose 等效
output_mid = output_mid.swapaxes(1, 2)
  • transpose() 功能相同
  • 更语义化,适合阅读时强调“交换”而非“转置”


📌 总结对比表
方法支持任意维是否返回 view是否支持链式操作推荐用途
transpose()❌ 仅限两个维度简单交换两个维度
permute()✅ 多维支持高阶张量维度重排(推荐)
swapaxes()强调“交换”,语义更强

💡 示例说明

假设输入为:

output_mid.shape = (3, 8, 2, 16)  # batch_size=3, head_num=8, seq_len=2, head_dim=16

使用 transpose(1, 2)

output_mid = output_mid.transpose(1, 2)
# output_mid.shape 
(batch_size, head_num, seq_len, head_dim)
=> (batch_size, seq_len, head_num, head_dim)
(3, 8, 2, 16)
=> (3, 2, 8, 16)

使用 permute(0, 2, 1, 3)

output_mid = output_mid.permute(0, 2, 1, 3)
# output_mid.shape => (3, 2, 8, 16)

两者等价,但 permute() 更具通用性。


🛠 实际应用建议
  • 如果只是交换两个维度 → transpose()
  • 如果涉及多维重排 → permute()
  • 如果要合并/拆分某些维度 → permute() + contiguous() + view()

4. 🔍 view()

在 PyTorch 中,view() 是一个用于 改变张量形状(reshape) 的函数。它不会修改张量的数据,只是重新解释其形状。

语法:

tensor.view(shape)

示例代码:

output_mid = output_mid.view(batch_size, seq_len, self.hidden_dim)

前提条件:

  • output_mid 当前的 shape 是 (batch_size, seq_len, head_num, head_dim)
  • head_num * head_dim == hidden_dim
  • 所以 view 后变为 (batch_size, seq_len, hidden_dim)

作用:
将多头注意力中每个 head 的输出拼接起来,恢复成原始的 hidden_dim 维度。

比如:

# 假设 batch_size=3, seq_len=2, head_num=8, head_dim=16
output_mid.shape = (3, 8, 2, 16)  # transpose + contiguous 后
output_mid = output_mid.view(3, 2, 128)  # 8*16 = 128

⚠️ 注意:使用 view() 前必须保证张量是连续的(contiguous),否则会报错。所以前面通常有 .contiguous() 调用。


🔄 其他等效实现方式

除了 view(),还有以下几种方式可以实现类似功能:

1. torch.reshape(tensor, shape)

view() 类似,但更灵活,可以在非连续内存上运行。

output_mid = output_mid.reshape(batch_size, seq_len, self.hidden_dim)

✅ 推荐使用这个替代 view(),因为不需要关心是否是连续内存。


2. 使用 flatten(start_dim, end_dim) 合并维度
output_mid = output_mid.transpose(1, 2).flatten(start_dim=2, end_dim=3)

这相当于把第 2 和第 3 维合并,效果等同于 reshape 或 view。


3. 使用 einops.rearrange(推荐用于可读性)

来自 einops 库(einop库安装及介绍),提供更直观的维度操作方式:

from einops import rearrangeoutput_mid = rearrange(output_mid, 'b h s d -> b s (h d)')

优点:

  • 更易读
  • 不需要关心是否连续
  • 可扩展性强(支持更多复杂变换)

✅ 总结对比
方法是否要求连续易读性灵活性推荐场景
view()❌ 必须连续⬇️ 差⬇️ 一般小规模调试
reshape()✅ 不要求⬆️ 好⬆️ 强通用替换 view
flatten()✅ 不要求⬆️ 好⬆️ 强多维合并
einops.rearrange()✅ 不要求⬆️ 很好⬆️ 非常强工程项目

💡 实际应用建议

如果你在写正式项目或模型工程化,推荐使用:

from einops import rearrangeoutput_mid = rearrange(output_mid, 'b h s d -> b s (h d)')

或者安全版本(不依赖连续内存):

output_mid = output_mid.transpose(1, 2)
output_mid = output_mid.flatten(2)  # (b, s, h*d)

这样不仅代码清晰,也避免了对 .contiguous() 的依赖问题。

5. 🔍 masked_fill()

在 PyTorch 中,masked_fill() 是一个非常常用的函数,用于 根据布尔掩码(mask)对张量的某些位置进行填充。它常用于 NLP 任务中,比如 Transformer 模型中的 attention mask 处理。


🧠 函数定义
torch.Tensor.masked_fill(mask, value)

参数说明:

  • mask: 一个布尔类型的张量(True/False),形状必须与原张量相同。
  • value: 要填充的值,可以是标量或广播兼容的张量。

行为:

  • 对于 mask 中为 True 的位置,将原张量对应位置的值替换为 value
  • False 的位置保持不变。

示例解析
atten_weight = atten_weight.masked_fill(attention_mask == 0, float("-inf"))

解释:

  1. attention_mask == 0
    • 这是一个布尔操作,生成一个和 attention_mask 形状相同的布尔张量。
    • 所有等于 0 的位置变成 True,表示这些位置是 pad 或无效 token,不应该参与 attention 计算。
  2. float("-inf")
    • 将这些被 mask 的位置填入负无穷大。
    • 在后续 softmax 中,exp(-inf) 会变成 0,从而实现“忽略这些位置”的效果。

✅ 实际案例演示

输入示例:

import torch# 原始 attention 权重 (模拟)
atten_weight = torch.tensor([[0.1, 0.2, 0.3, 0.4],[0.5, 0.6, 0.7, 0.8]
])# attention mask (pad 位置为 0)
attention_mask = torch.tensor([[1, 1, 0, 0],[1, 0, 0, 0]
])# 应用 masked_fill
atten_weight = atten_weight.masked_fill(attention_mask == 0, float("-inf"))
print(atten_weight)

输出结果:

tensor([[ 0.1000,  0.2000,   -inf,   -inf],[ 0.5000,   -inf,   -inf,   -inf]])

后续 softmax 结果:

import torch.nn.functional as F
F.softmax(atten_weight, dim=-1)

输出:

tensor([[0.4621, 0.5379, 0.0000, 0.0000],[1.0000, 0.0000, 0.0000, 0.0000]])

可以看到,mask 为 0 的位置在 softmax 后变成了 0,不会影响最终注意力分布。


⚠️ 注意事项
  1. mask 张量的 shape 必须与目标张量一致

    • 如果你有一个 (batch_size, seq_len) 的 mask,而 atten_weight(batch_size, head_num, seq_len, seq_len),你需要通过 unsqueezeexpand 调整 mask 的维度。
    • 示例:
      attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # -> (batch_size, 1, 1, seq_len)
      attention_mask = attention_mask.expand(batch_size, num_heads, seq_len, seq_len)
      
  2. 不能直接使用 int 类型的 mask

    • masked_fill 只接受布尔类型作为 mask,所以要确保使用了比较操作如 ==, != 等。

💡 应用场景
场景描述
padding mask防止模型关注到 padding 的 token
look-ahead mask防止 decoder 在预测时看到未来 token
自定义屏蔽机制如屏蔽某些特定词、句子结构等

✅ 总结
方法作用推荐指数
masked_fill(mask == 0, -inf)屏蔽不需要关注的位置⭐⭐⭐⭐⭐
F.softmax(..., dim=-1)使屏蔽位置变为 0⭐⭐⭐⭐
mask 维度适配使用 unsqueeze + expand 调整 mask 到与 attn weight 相同⭐⭐⭐⭐⭐

📌 最佳实践建议
# 假设 attention_mask: (batch_size, seq_len)
# attn_weights: (batch_size, num_heads, seq_len_q, seq_len_k)# Step 1: 添加两个维度,使其匹配 attn_weights 的 shape
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # -> (B, 1, 1, S)# Step 2: 扩展 mask 使得其与 attn_weights 形状完全一致
attention_mask = attention_mask.expand_as(attn_weights)  # -> same shape as attn_weights# Step 3: 应用 mask,填入 -inf
attn_weights = attn_weights.masked_fill(attention_mask == 0, float('-inf'))

这样就能保证每个 head 和 query 的位置都能正确屏蔽掉 pad 或无效 token。

参考材料

https://bruceyuan.com/hands-on-code/from-self-attention-to-multi-head-self-attention.html#%E7%AC%AC%E5%9B%9B%E9%87%8D-multi-head-self-attention

einop库安装及介绍

相关文章:

  • fbdev驱动在rmmod的时候内核崩溃
  • 目标检测学习
  • Word2Vec 生成词向量
  • 考研系列—操作系统:第三章、内存管理
  • KVM——CPU独占
  • FreeRTOS通俗理解指南:基础概念 + 架构+ 内核组件+练手实验
  • LangChain-Tool和Agent结合智谱AI大模型应用实例2
  • 《数字世界的连接器:计算机网络应用全景解析》
  • 使用flex实现三栏布局,两边固定,中间自适应
  • 智能柜I立控信息I产品介绍
  • 八N皇后问题
  • LeetCode Hot100(动态规划)
  • YouTube视频广告指南:类型、投放策略与优劣势解析
  • 传输层核心技术解析
  • [CSS3]响应式布局
  • 主机号全0,代表网络本身地址; 主机号全1,代表广播地址
  • Spring Boot3.4.1 集成 mybatis plus
  • Linux | Shell脚本的常用命令
  • 2. JavaScript 基础:变量、运算符、分支
  • A类地址中最小网络号(0.x.x.x) 默认路由 / 无效/未指定地址
  • 广告网站做动图怎么做/佛山网站建设正规公司
  • 武汉网上商城网站建设/北京网络营销推广外包
  • 北京网站备案核验单/百度企业网盘
  • 做高仿包的网站有哪些/关键词优化排名软件怎么样
  • 网站登录流程图/苏州网站外包
  • 如何做线上赌博的网站/bt搜索引擎最好用的