深入解析PyTorch中MultiheadAttention的隐藏参数add_bias_kv与add_zero_attn
关键背景
最近在学习pytorch中的源码尤其是nn.modules下算子的实现,针对activation.py
下MultiheadAttention
下有两个不常见的参数的使用比较有趣,因为时序领域很少使用这两个参数(add_bias_kv
和add_zero_attn
),但是其目的似乎很适配时序场景,尽管逻辑上听起来其直接简单,但是还是打算手动推导分析其具体的变换。以熟悉其具体的变换。
参数作用
源码中针对其解释如下:
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.Default: ``False``.
- add_bias_kv
- 描述:当设置
add_bias_kv=True
时,PyTorch 会在每个 attention 计算的key和value前,拼接一个 earnable向量(即bias_k, bias_v),其shape是 (1, batch_size, embed_dim)。这些bias会被视作额外的“位置”,放在序列的开头。 - 作用:这个机制可以让模型在每个序列前加上一个全局性的learnable token,类似于BERT 的 [CLS] token、ViT(Vision Transformer)的 class token、引入统一的上下文引导 token
- 描述:当设置
- add_zero_attn
- 描述:意思是,在 每个位置的 key/value 向量中都加上一个额外的
“零向量”
,并作为一个新元素插入到 batch 的 attention 范围内。 - 作用:加一个 zero-attention vector 的目的是给模型留
一个“什么都不关注”的退路
;在 encoder-decoder attention 中,decoder 可以选择性不关注任何输入位置,退回到这个零向量
;对某些任务(如语义空洞填充、视觉 patch drop)来说,zero attention 能表示“不存在”或“缺失”的信息源
;
- 描述:意思是,在 每个位置的 key/value 向量中都加上一个额外的
分析流程
✅1. 从二维矩阵开始理解
假设我们有一个序列(时序 or 文本),长度是5,模型的维度是8维,即每个token将被嵌入成一个8维的向量,那么基于seq_len = 5、embed_dim = 8显然可以得到一个shape为(5,8)的矩阵,使用torch随机生成这样一个矩阵,如下:
import torchseq_len = 5
embed_dim = 8
m = torch.randn(seq_len,embed_dim)print(m.shape)
print(m)
输出如下:
torch.Size([5, 8])
tensor([[-7.4508e-01, 1.1659e+00, -7.0335e-02, 4.3215e-01, 1.3831e-01,-1.6028e+00, 2.6052e+00, 2.9472e-02],[ 6.7109e-01, -1.2629e-01, 1.3738e+00, 7.8396e-01, 1.5244e+00,5.7940e-01, -1.1636e+00, 1.1213e+00],[-3.4683e-01, 3.3295e-01, -9.1225e-02, 4.5248e-01, 1.8235e+00,-2.4852e-01, 1.0417e+00, -1.2556e-01],[-1.0605e+00, 4.5711e-01, -7.9260e-01, -2.0586e+00, -2.5313e-01,-8.0461e-01, 9.3312e-01, 4.7544e-01],[-5.2117e-01, 3.4502e-01, 1.4715e+00, 2.4684e+00, 9.4748e-01,2.0253e-03, 1.0036e+00, 4.8027e-01]])
✅2. 引入 batch 的概念(3D 张量)
一般,我们需要在一次性处理多个序列(多个样本),假设3个序列样本组成一个batch。每个序列的长度仍是5,token的维度为8。那么此时这个张量的形状就变成了:
import torchseq_len = 5
embed_dim = 8
batch_size = 3
q = torch.randn(batch_size, seq_len, embed_dim)
print(q.shape)
print(q)
输出为:
torch.Size([3, 5, 8])
tensor([[[-0.3007, -0.3443, 0.0515, 0.9153, 0.1486, 0.8630, -0.1750,-0.6688],[ 0.1970, -0.8177, -0.0302, 1.1665, 0.3290, 0.6600, -0.7473,-1.2262],[-0.3780, 0.6538, 1.3766, 0.1920, -1.0980, 0.0694, 0.8015,0.3631],[ 1.1727, -0.1484, 1.5107, 1.4208, -0.2864, -1.7283, 0.5781,-1.4435],[ 1.3020, -0.1518, -0.9987, 0.5897, -1.1685, 1.1592, 0.0360,-1.1931]],[[ 1.3530, 0.0892, -1.2635, 1.8082, 1.3397, 1.0009, -1.3071,0.0946],[-0.3749, -1.5674, -0.8663, -1.3531, 0.9437, -1.1769, -1.3152,-1.1854],[ 0.6995, 0.6464, -0.8311, 0.4104, 1.4770, -0.2067, 0.8549,-0.0366],[-0.3462, 1.0118, -1.3090, -1.5885, -0.1143, 0.1957, -1.1694,-0.1317],[-0.0216, 0.7810, 1.6990, -0.2328, -0.0163, -1.5569, -0.9106,-1.5693]],[[ 0.0365, -0.8511, -0.6117, -1.4029, -0.5794, 0.7073, 0.0607,2.2900],[ 1.6539, 0.4874, 1.0456, 0.2727, 1.0852, 1.7963, -0.4513,-0.9612],[-1.4896, 1.8739, -0.3650, -0.0476, -2.5191, -1.4645, 0.5743,0.4616],[-1.2099, 0.3355, -0.8877, 2.6665, -0.6601, -1.2705, 1.0287,-0.6931],[-0.3273, 0.2364, -1.2982, -0.6908, 1.5833, -0.2403, 1.2128,1.4706]]])
我们可以尝试使用transpose
转置该3D张量对应的维度,如下:
q1 = q.transpose(0,1)
print(q1.shape)
print(q1)
输出如下:
torch.Size([5, 3, 8])
tensor([[[-0.0687, 0.3209, -0.5212, -0.1787, 0.4720, 1.0013, 0.6243,-1.8285],[-0.6348, -0.4394, -0.1964, 0.2261, -0.1205, 0.4492, 0.9841,-0.4095],[ 0.7714, -0.1087, -0.7359, 0.2492, -0.0391, -0.2462, -0.1695,1.4089]],[[-0.5181, -0.5992, 0.1055, 0.4877, 0.1648, 0.5122, -0.3526,1.7066],[ 1.5172, 1.4660, -0.2405, 2.1547, -0.8794, -1.6543, -2.0169,-0.5331],[ 0.3779, 0.4134, 1.9286, 0.3782, 1.5611, -1.6187, 1.6274,1.0527]],[[ 0.5642, -0.3944, -1.4383, -1.1361, 0.0242, 0.1435, 0.5510,-0.0472],[-1.0935, 0.0820, 0.5193, 0.1174, 0.3282, -1.9772, 0.4186,-0.5007],[-0.0845, -2.0364, 0.1124, 1.7474, -0.3131, -1.4156, 0.4046,-0.3282]],[[-2.4212, -0.4703, 1.5794, 0.8093, 0.9247, -1.4775, 0.4462,0.4256],[-0.0934, -0.2569, 0.4803, -0.3651, 0.7175, -1.0460, 0.9095,0.6421],[ 0.1579, 2.0790, -0.4982, 1.7707, -0.3657, 0.7336, -0.1482,-1.5648]],[[ 2.7056, 2.2962, 0.7005, 0.6427, 0.7578, -0.4191, 0.9064,-0.3934],[ 0.1987, 1.6104, 0.4723, 1.5453, 0.0500, -0.5176, -1.8852,-1.2235],[ 1.2145, 1.7694, -0.1546, 0.3803, 0.0489, 1.0129, 0.0513,-0.6902]]])
实际上,未转置之前是很容易理解的,即将将样本按照batch数进行了聚合,但是转置之后如何直观理解是非常重要的一个事情:
- 转换前
(batch_size, seq_len, embed_dim)
我们可以把它看成一个“batch_size 个句子”,每个句子有 seq_len 个词,每个词是一个 embed_dim 维的向量。
- 转换后
现在你可以把它理解成:我们将多个句子按词的位置对齐,从每个 batch 的样本中抽出相同位置的词,堆叠在一起。即第 i 个时间步(即位置 i)下,batch 中所有样本的向量被统一收集起来。
✅3. 输入到MultiheadAttention中
import torch
import torch.nn as nn# 定义参数
embed_dim = 8
num_heads = 2
seq_len = 5
batch_size = 3# 创建 MultiheadAttention 实例
mha_with_bias = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)# 输入的 query, key, value (shape: seq_len, batch_size, embed_dim)
query = torch.randn(batch_size, seq_len, embed_dim)
key = torch.randn(batch_size, seq_len, embed_dim)
value = torch.randn(batch_size, seq_len, embed_dim)# 计算 attention 输出
out_with_bias, _ = mha_with_bias(query, key, value)# 输出两种情况下的结果
print(out_with_bias.shape)
print(out_with_bias)
输出为
torch.Size([3, 5, 8])
tensor([[[ 0.0749, 0.0814, -0.1367, -0.0314, 0.0061, -0.2369, -0.1453,-0.0241],[-0.0945, -0.0934, 0.0407, -0.2193, -0.2172, -0.1352, -0.0820,-0.1239],[-0.0591, -0.0986, 0.1017, -0.2430, -0.2136, -0.0989, -0.1641,-0.1050],[-0.0588, -0.0747, 0.0355, -0.2210, -0.1803, -0.1326, -0.0864,-0.1240],[-0.1035, -0.0213, -0.0752, -0.1356, -0.1339, -0.1776, -0.0013,-0.1460]],[[-0.0217, -0.0490, 0.2707, 0.2102, -0.0289, 0.2783, -0.2179,0.1797],[-0.1059, -0.0113, 0.2640, 0.3222, 0.0088, 0.3493, -0.2244,0.2725],[-0.0965, -0.0284, 0.4591, 0.2321, 0.0041, 0.3915, -0.2269,0.2789],[-0.0432, -0.0149, 0.5781, 0.2857, 0.0851, 0.5095, -0.2417,0.3853],[-0.3742, 0.0287, 0.2859, 0.4340, -0.0349, 0.4007, -0.2363,0.3838]],[[-0.5256, 0.0713, -0.3210, 0.1768, -0.0383, 0.0495, 0.1960,-0.1594],[-0.4588, 0.1027, -0.4966, 0.0540, -0.0393, -0.1755, 0.3327,-0.2415],[-0.4845, 0.0701, -0.2647, 0.2300, -0.0053, 0.0749, 0.1903,-0.0888],[-0.4795, 0.1241, -0.3832, 0.1589, 0.0214, -0.1044, 0.3101,-0.1343],[-0.4665, 0.1103, -0.3181, 0.2179, 0.0319, -0.0670, 0.3088,-0.0676]]], grad_fn=<TransposeBackward0>)
为什么 PyTorch 使用 (seq_len, batch_size, embed_dim)?这是为了兼容早期的 RNN 接口(nn.RNN, nn.LSTM)以及底层高效的张量操作。可以用 batch_first=True 的方式让维度变成 [batch_size, seq_len, embed_dim],这是更自然的顺序,但默认 MultiheadAttention 期望 [seq_len, batch_size, embed_dim]。
✅4. 分析add_bias_kv
按照上面的理解官方实现中要求[batch_size, seq_len, embed_dim],并且按照之前的解释其在每个 attention 计算的key和value前,拼接一个 learnable向量(即bias_k, bias_v),其shape是 (1, batch_size, embed_dim),这个拼接逻辑如何理解呢?具体的,
- 其会创建两个可学习的参数向量bias_k和bias_v,且形状为(1,1,embed_dim)
- 然后在前向传播时,会将这个 bias 向量沿着 batch 维复制为 [batch_size, 1, embed_dim],然后 附加在原始的 key 和 value 张量的序列维 dim=1 上(注意是 seq_len 那一维)
key = torch.cat([key, bias_k.expand(3, 1, 8)], dim=1)
value = torch.cat([value, bias_v.expand(3, 1, 8)], dim=1)
新shape:[seq_len + 1, batch_size, embed_dim]
✅5. 分析add_zero_attn
add_zero_attn=True 表示在 key 和 value 的序列维(即 seq_len 那一维)前面追加一行全为零的向量。注意query始终不变。即
zero_tensor = torch.zeros(batch_size, 1, embed_dim)
key = torch.cat([zero_tensor, key], dim=1)
value = torch.cat([zero_tensor, value], dim=1)
[seq_len, batch_size + 1, embed_dim]
✅6.特别注意
- add_bias_kv 添加的是 bias_k 和 bias_v,这两个是模型的参数,会随着训练学习出有意义的表示(比如代表某类全局特征或辅助 attention 的 anchor 点)。
- 而 add_zero_attn 加的是一个 固定为 0 的向量,它不会被训练,更多是出于工程目的(例如 decoder 中让 query 有可能对“无意义”的位置分配注意力,从而在 masking 时更稳定)。