Transformer 小记(一):深入理解 Transformer 中的位置关系
Transformer 小记(一):深入理解 Transformer 中的位置关系
- 深入理解 Transformer 中的位置编码:从原理到代码实现
- 1. 为什么需要位置编码?
- 2. 正弦位置编码的原理:波浪的组合
- 3. 代码实现详解:`SinusoidalPosEmb` 类
- 数值例子:`dim=512`, `x=torch.tensor([0, 1, 2, 3, 4, 5])`
- 步骤 1: 计算 `half_dim`
- 步骤 2. 计算 `freq_list` (256 个不同的“频率基数”)
- 步骤 3. 计算 `angle_at_pos_and_freq` (“在哪个角度”)
- 步骤 4. 计算 `sin_values`, `cos_values` 并拼接成 `identity_card`
- `emb` 是什么?
- `emb` 的作用是什么?
深入理解 Transformer 中的位置编码:从原理到代码实现
在深度学习,特别是自然语言处理领域,Transformer 模型凭借其强大的并行处理能力和对长距离依赖的建模优势,彻底改变了我们处理序列数据的方式。然而,Transformer 的核心——自注意力机制——本身并不包含任何关于序列中元素顺序或位置的信息。这意味着,如果仅仅依赖自注意力,打乱一个句子的词序,模型也可能得到相同的输出。
为了弥补这一缺陷,位置编码 (Positional Encoding) 应运而生。它负责将每个元素在序列中的相对或绝对位置信息编码成一个向量,然后将其注入到模型的输入表示中。本文将深入探讨 Transformer 中一种经典且无需学习的位置编码方式:正弦位置编码 (Sinusoidal Positional Embeddings)。
1. 为什么需要位置编码?
传统的循环神经网络 (RNN) 通过顺序处理数据来自然地捕获位置信息。但 Transformer 采用并行计算,其自注意力机制会同时关注序列中的所有位置。如果不对输入进行额外处理,模型将无法区分 “我爱你” 和 “你爱我” 中词语的顺序。
因此,为了让 Transformer 理解序列的顺序,我们需要将位置信息明确地编码到输入中。
2. 正弦位置编码的原理:波浪的组合
正弦位置编码的核心思想是利用不同频率的正弦和余弦函数来为每个位置生成一个独特且固定维度的向量。这个向量就像是每个位置的“指纹”或“身份卡”。
它的巧妙之处在于:
-
高维唯一性: 单独一个正弦或余弦函数由于其周期性,可能会在不同位置产生相同的值。但正弦位置编码不只使用一个,而是使用多对频率不同的正弦和余弦函数。将所有这些函数在某个位置产生的值拼接起来,就形成了一个高维向量。
- 想象你有很多把同样长,但波浪纹路(频率)疏密程度不同的尺子。
- 要给位置
P
制作“身份卡”:你将位置P
对准每一把尺子,记录下它在每把尺子上波浪的高度(正弦值)和坡度(余弦值)。 - 最后,把所有尺子的读数串起来,就成了位置
P
独一无二的“身份卡”。 - 即使某一把尺子在不同位置
P1
和P2
处读数相同,但因为其他尺子的波浪疏密不同,所有尺子在P1
的读数组合与在P2
的读数组合,几乎不可能完全相同。维度越高,这种独特性就越强。
-
编码相对位置: 这种设计还能够有效捕获相对位置信息。对于任意两个位置
pos
和pos + k
,它们的编码之间存在着线性关系。这意味着模型可以通过学习一些简单的线性变换,来理解“相距k
个位置”这样的相对距离,而无需为每个可能的相对距离单独学习一个表示。
3. 代码实现详解:SinusoidalPosEmb
类
以https://github.com/anthonysimeonov/rpdiff代码库中的位置编码为例
在给定的代码中,SinusoidalPosEmb
类负责生成这种位置编码。
class SinusoidalPosEmb:def __init__(self, dim, max_pos=None):super().__init__()self.dim = dim # 最终位置嵌入向量的维度self.max_pos = max_pos # 最大位置值,用于裁剪输入def __call__(self, x): # x 就是输入的原始位置索引,比如 [0, 1, 2, 3, 4, 5]device = x.deviceif self.max_pos is not None:x = torch.clip(x, 0, self.max_pos) # 可选:将位置索引裁剪到最大值half_dim = self.dim // 2 # 维度减半,因为一半用于sin,一半用于cos# 计算频率基数:这决定了每对sin/cos波浪的“疏密程度”# 公式中的 1 / (10000^(2i/dim)) 在这里通过 exp(log(1/10000) * 2i/dim) 实现emb_scale_factor = math.log(10000) / (half_dim - 1)freq_list = torch.exp(torch.arange(half_dim, device=device) * -emb_scale_factor)# 此时,freq_list 是一个形状为 (half_dim,) 的张量,其元素就是 f_0, f_1, ..., f_{half_dim-1}# f_0 最大(频率最低),f_{half_dim-1} 最小(频率最高)# 核心:将每个位置 x 乘以每个频率,得到“在哪个角度”# x_expanded 形状 (batch_size, 1), freq_list_expanded 形状 (1, half_dim)# 结果 angle_at_pos_and_freq 形状 (batch_size, half_dim)# 举例:angle_at_pos_and_freq[i, j] = 位置 x_i * 第 j 个频率 f_jx_expanded = x[:, None]freq_list_expanded = freq_list[None, :]angle_at_pos_and_freq = x_expanded * freq_list_expanded# 计算正弦和余弦分量,并拼接sin_values = angle_at_pos_and_freq.sin() # 形状 (batch_size, half_dim)cos_values = angle_at_pos_and_freq.cos() # 形状 (batch_size, half_dim)# 最终拼接成一个 (batch_size, dim) 的完整位置编码向量identity_card = torch.cat((sin_values, cos_values), dim=-1)return identity_card # 返回所有位置的身份卡(位置编码向量)
数值例子:dim=512
, x=torch.tensor([0, 1, 2, 3, 4, 5])
- 假设:
dim
(位置编码向量的最终维度) = 512 (这是 Transformer 中常见的维度)x
(输入的位置,一个批次中的 6 个连续位置索引) =torch.tensor([0., 1., 2., 3., 4., 5.])
device
='cpu'
(为了简化,不涉及 GPU)
步骤 1: 计算 half_dim
half_dim = self.dim // 2
self.dim
是 512,所以half_dim
=512 // 2
= 256。- 这意味着我们会有 256 对正弦/余弦函数,也就是 256 把“波浪稀疏度不同的尺子”。
步骤 2. 计算 freq_list
(256 个不同的“频率基数”)
emb = math.log(10000) / (half_dim - 1)
freq_list = torch.exp(torch.arange(half_dim, device=device) * -emb)
-
计算
emb
(衰减常数):emb
=math.log(10000)
/ (256 - 1
) ≈0.03612
。
-
生成
torch.arange(half_dim, device=device)
:torch.arange(256, device='cpu')
会得到张量[0., 1., 2., ..., 255.]
。
-
计算
freq_list
(256 个频率基数):exp(0 * -0.03612)
=1.0
(f_0)exp(1 * -0.03612)
≈0.9645
(f_1)exp(2 * -0.03612)
≈0.9292
(f_2)- …
exp(255 * -0.03612)
≈0.0001
(f_255)
freq_list
现在是包含这 256 个不同频率的张量,从1.0
递减到0.0001
。
步骤 3. 计算 angle_at_pos_and_freq
(“在哪个角度”)
x_expanded = x[:, None]
freq_list_expanded = freq_list[None, :]
angle_at_pos_and_freq = x_expanded * freq_list_expanded
输入 x
是 torch.tensor([0., 1., 2., 3., 4., 5.])
。
-
x_expanded
: 形状(6, 1)
,内容是[[0.], [1.], ..., [5.]]
。 -
freq_list_expanded
: 形状(1, 256)
,内容是[[f_0, f_1, ..., f_255]]
。 -
angle_at_pos_and_freq
(广播相乘): 结果形状是(6, 256)
。-
对于位置 0 (第一行):
[0. * f_0, 0. * f_1, ..., 0. * f_255]
=[0.0, 0.0, ..., 0.0]
-
对于位置 1 (第二行):
[1. * f_0, 1. * f_1, ..., 1. * f_255]
=[1.0, 0.9645, ..., 0.0001]
-
对于位置 2 (第三行):
[2. * f_0, 2. * f_1, ..., 2. * f_255]
=[2.0, 1.929, ..., 0.0002]
-
… (依此类推,直到位置 5)
angle_at_pos_and_freq
看起来像这样 (只展示前两行和部分列,以及最后一行):[[0.0, 0.0, ..., 0.0 ], <- 位置 0 的角度[1.0, 0.9645, ..., 0.0001], <- 位置 1 的角度[2.0, 1.929, ..., 0.0002], <- 位置 2 的角度...[5.0, 4.8225, ..., 0.0005]] <- 位置 5 的角度
-
步骤 4. 计算 sin_values
, cos_values
并拼接成 identity_card
sin_values = angle_at_pos_and_freq.sin()
cos_values = angle_at_pos_and_freq.cos()
identity_card = torch.cat((sin_values, cos_values), dim=-1)
-
sin_values
(形状(6, 256)
):- 位置 0 的
sin
值:[sin(0.0), sin(0.0), ..., sin(0.0)]
=[0.0, 0.0, ..., 0.0]
- 位置 1 的
sin
值:[sin(1.0), sin(0.9645), ..., sin(0.0001)]
sin(1.0)
≈0.841
sin(0.9645)
≈0.823
sin(0.0001)
≈0.0001
- 位置 2 的
sin
值:[sin(2.0), sin(1.929), ..., sin(0.0002)]
sin(2.0)
≈0.909
sin(1.929)
≈0.938
sin(0.0002)
≈0.0002
- …
- 位置 0 的
-
cos_values
(形状(6, 256)
):- 位置 0 的
cos
值:[cos(0.0), cos(0.0), ..., cos(0.0)]
=[1.0, 1.0, ..., 1.0]
- 位置 1 的
cos
值:[cos(1.0), cos(0.9645), ..., cos(0.0001)]
cos(1.0)
≈0.540
cos(0.9645)
≈0.572
cos(0.0001)
≈1.0
- 位置 2 的
cos
值:[cos(2.0), cos(1.929), ..., cos(0.0002)]
cos(2.0)
≈-0.416
cos(1.929)
≈-0.366
cos(0.0002)
≈1.0
- …
- 位置 0 的
-
identity_card
(最终拼接,形状(6, 512)
):- 位置 0 的“身份卡”:
[0.0, ..., 0.0 (256个), | 1.0, ..., 1.0 (256个)]
- 位置 1 的“身份卡”:
[0.841, 0.823, ..., 0.0001, | 0.540, 0.572, ..., 1.0]
- 位置 2 的“身份卡”:
[0.909, 0.938, ..., 0.0002, | -0.416, -0.366, ..., 1.0]
- … (依此类推,直到位置 5)
- 位置 0 的“身份卡”:
emb
是什么?
简单来说,在 SinusoidalPosEmb
类中,最终返回的 emb
(在代码中是 identity_card
)是一个包含了位置信息的向量(或者说“身份卡”)。
emb
是一个稠密的、固定维度(self.dim
)的浮点数向量。这个向量通过结合不同频率的正弦和余弦函数,编码了你输入的位置信息。
我们可以用一个比喻来理解它:
想象一下,你想要给一个数字(比如位置 5
)打上一个特殊的“指纹”。这个指纹不是简单的图案,而是一串数字。
- 你有许多不同粗细(频率)的“波纹扫描仪”。
- 你把位置
5
这个数字放入每一个扫描仪。 - 每个扫描仪都会根据它自己的波纹粗细,给出一个关于位置
5
的独特读数(包括波纹的高度和坡度)。 - 你把所有这些不同扫描仪的读数收集起来,串成一长串数字。
这条长串数字,就是最终返回的 emb
。
emb
的作用是什么?
这个 emb
向量(“身份卡”或“指纹”)的作用是:
-
为模型提供位置信息: Transformer 模型本身不理解顺序。通过将这个
emb
向量与原始的输入特征向量(例如词嵌入)相加,模型就能够感知到序列中每个元素的位置。- 比如,如果一个词在句子中的位置是
5
,它的词嵌入向量会加上位置5
对应的emb
向量。这样,模型在处理这个词时,不仅知道它是什么词,还知道它在句子的哪个位置。
- 比如,如果一个词在句子中的位置是
-
区分不同位置: 即使两个位置的原始特征相同,它们的
emb
向量不同,相加后得到的最终表示也就不同,从而模型能够区分它们。 -
编码相对位置: 这种正弦编码的数学特性使得模型能够更容易地学习和理解位置之间的相对关系(例如,“在它前面一个位置”或“距离它五个位置”)。这是通过模型学习对这些
emb
向量进行线性变换来实现的。