python 实现 transformer 的 position embeding
import numpy as np
import matplotlib.pyplot as pltclass PositionalEmbedding:def __init__(self, d_model, max_seq_len):"""初始化位置嵌入参数:d_model: 嵌入维度max_seq_len: 最大序列长度"""self.d_model = d_modelself.max_seq_len = max_seq_lenself.pos_embedding = self._create_positional_embedding()def _create_positional_embedding(self):"""创建正弦余弦位置嵌入"""# 初始化位置嵌入矩阵pos_embedding = np.zeros((self.max_seq_len, self.d_model))# 遍历每个位置for pos in range(self.max_seq_len):# 遍历每个维度for i in range(self.d_model // 2):# 计算正弦和余弦值angle = pos / np.power(10000, 2 * i / self.d_model)pos_embedding[pos, 2*i] = np.sin(angle)pos_embedding[pos, 2*i + 1] = np.cos(angle)return pos_embeddingdef get_embedding(self, pos):"""获取指定位置的嵌入向量"""if pos < 0 or pos >= self.max_seq_len:raise ValueError(f"位置必须在[0, {self.max_seq_len-1}]范围内")return self.pos_embedding[pos]def visualize_embedding(self, num_positions=10, figsize=(12, 8)):"""可视化位置嵌入"""plt.figure(figsize=figsize)# 只显示前num_positions个位置和所有维度plt.imshow(self.pos_embedding[:num_positions, :], cmap='viridis')plt.xlabel('嵌入维度')plt.ylabel('位置')plt.title('Transformer位置嵌入可视化')plt.colorbar()plt.show()def visualize_dimensions(self, dim1=0, dim2=1, num_positions=50, figsize=(10, 10)):"""可视化不同位置在两个维度上的分布"""plt.figure(figsize=figsize)x = self.pos_embedding[:num_positions, dim1]y = self.pos_embedding[:num_positions, dim2]plt.scatter(x, y)# 为每个点添加位置标签for i in range(num_positions):plt.annotate(str(i), (x[i], y[i]))plt.xlabel(f'维度 {dim1}')plt.ylabel(f'维度 {dim2}')plt.title(f'位置在维度 {dim1} 和 {dim2} 上的分布')plt.grid(True)plt.show()# 示例用法
if __name__ == "__main__":# 创建位置嵌入实例,模型维度为512,最大序列长度为100pos_embed = PositionalEmbedding(d_model=512, max_seq_len=100)# 获取位置10的嵌入向量position_10_embedding = pos_embed.get_embedding(10)print(f"位置10的嵌入向量形状: {position_10_embedding.shape}")# 可视化位置嵌入pos_embed.visualize_embedding(num_positions=20)# 可视化不同位置在两个维度上的分布pos_embed.visualize_dimensions(dim1=0, dim2=1)pos_embed.visualize_dimensions(dim1=2, dim2=3)