T5相对位置编码
文章目录
- 核心功能与设计背景
- 代码举例和解读(以T5为例)
- 分步代码解读
- 1. __init__ 初始化方法
- 2. forward 前向传播方法
- 核心特点总结
核心功能与设计背景
在 Transformer 模型中,注意力机制本身是 “位置无关” 的(仅关注内容相似性)。为了让模型理解序列的顺序关系,需要加入位置信息。T5 采用相对位置偏差方案:通过学习不同相对距离对应的偏差值,在注意力权重计算时(在softmax之前,即作为v相乘之后的偏置)进行调整。
代码举例和解读(以T5为例)
class T5PositionalEncoding(nn.Module):"""T5使用的相对位置编码"""# 修正:相对位置偏差应基于注意力头数,而非模型维度def __init__(self, nhead, max_len=5000):super().__init__()self.nhead = nhead # 每个注意力头有独立的相对位置偏差self.max_len = max_len# 相对位置编码参数:嵌入维度改为注意力头数nheadself.relative_attention_bias = nn.Embedding(2 * max_len - 1, nhead)def forward(self, seq_len_q, seq_len_k, device):"""计算相对位置偏差,返回形状为[seq_len_q, seq_len_k, nhead]"""range_vec_q = torch.arange(seq_len_q, device=device)range_vec_k = torch.arange(seq_len_k, device=device)distance_mat = range_vec_k[None, :] - range_vec_q[:, None] # [seq_len_q, seq_len_k]distance_mat_clamped = torch.clamp(distance_mat, -self.max_len + 1, self.max_len - 1)final_mat = distance_mat_clamped + self.max_len - 1 # 偏移到非负索引return self.relative_attention_bias(final_mat) # [seq_len_q, seq_len_k] 经过[2 * max_len - 1, nhead] --> [seq_len_q, seq_len_k, nhead] ,即有seq_len_q x seq_len_k 个词嵌入编码
这段代码实现了 T5 模型中使用的相对位置编码(Relative Positional Encoding) 机制,用于在注意力计算中引入位置信息。与绝对位置编码不同,相对位置编码关注的是序列中元素之间的相对距离,更符合自然语言中 “位置关系比绝对位置更重要” 的特性。
分步代码解读
1. init 初始化方法
def __init__(self, nhead, max_len=5000):super().__init__()self.nhead = nhead # 注意力头数量self.max_len = max_len # 最大序列长度(限制相对距离范围)# 相对位置偏差的嵌入层self.relative_attention_bias = nn.Embedding(2 * max_len - 1, nhead)
核心参数:
nhead:注意力头的数量(每个头独立学习相对位置偏差)。
max_len:允许的最大序列长度,用于限制相对距离的范围(避免距离过大导致偏差学习不稳定)。
关键设计:
nn.Embedding(2 * max_len - 1, nhead):
嵌入层的输入维度是 2 * max_len - 1(对应可能的相对距离范围),输出维度是 nhead(每个注意力头有独立的偏差参数)。
例如:当 max_len=5000 时,相对距离范围是 [-4999, 4999],共 2*5000-1=9999 种可能的距离,因此嵌入层输入维度为 9999。
2. forward 前向传播方法
def forward(self, seq_len_q, seq_len_k, device):"""返回形状为 [seq_len_q, seq_len_k, nhead] 的相对位置偏差"""# 1. 生成查询和键的位置索引range_vec_q = torch.arange(seq_len_q, device=device) # [seq_len_q]range_vec_k = torch.arange(seq_len_k, device=device) # [seq_len_k]# 2. 计算相对距离矩阵distance_mat = range_vec_k[None, :] - range_vec_q[:, None] # [seq_len_q, seq_len_k]# 示例:若q长度=2,k长度=3,结果为:# [[0-0, 1-0, 2-0],# [0-1, 1-1, 2-1]] → [[0,1,2], [-1,0,1]]# 3. 限制距离范围(防止超出max_len)distance_mat_clamped = torch.clamp(distance_mat, -self.max_len + 1, # 最小距离(如-4999)self.max_len - 1 # 最大距离(如4999))# 4. 将距离转为非负索引(嵌入层需要非负输入)final_mat = distance_mat_clamped + self.max_len - 1 # 偏移量:max_len-1# 示例:距离-4999 → 0,距离0 → 4999,距离4999 → 9998# 5. 查找对应的相对位置偏差return self.relative_attention_bias(final_mat) # [seq_len_q, seq_len_k, nhead]
输入参数:
seq_len_q:查询序列(Query)的长度。
seq_len_k:键序列(Key)的长度。
device:计算设备(CPU/GPU),确保张量位置正确。
核心计算步骤:
生成查询和键的位置索引(0 到长度 - 1)。
计算每个 Query 位置与 Key 位置的相对距离(k的位置 - q的位置),得到二维距离矩阵。
限制距离范围(超出max_len的距离被截断),避免极端值影响。
将负距离转为非负索引(通过加偏移量),以便作为嵌入层的输入。
通过嵌入层获取每个相对距离对应的偏差值,最终输出形状为 [seq_len_q, seq_len_k, nhead]。
与注意力机制的结合
该模块的输出(相对位置偏差)会在注意力权重计算时被加入,公式大致为:
scores = (Q @ K.T) / sqrt(d_k) # 原始内容相似度分数
scores += relative_bias # 加入相对位置偏差(本文代码的输出)
attention_weights = F.softmax(scores, dim=-1)
通过这种方式,模型在计算注意力时不仅考虑内容相似度,还会受到位置关系的影响(例如 “附近的词权重更高”)。
核心特点总结
相对位置建模:不依赖绝对位置,而是关注元素间的相对距离,更适合长序列和动态位置场景。
多头独立学习:每个注意力头有独立的相对位置偏差参数(nhead维度),适配不同头的关注重点。([seq_len_q, seq_len_k] 经过[2 * max_len - 1, nhead] --> [seq_len_q, seq_len_k, nhead] ,即有seq_len_q x seq_len_k 个词嵌入编码)
范围限制:通过max_len控制最大相对距离,避免模型学习过多稀疏的远距离偏差,提升效率。
T5 位置编码方式是其注意力机制的重要组成部分,广泛应用于文本生成、机器翻译等任务。