Transformer-输入部分
Transformer输入部分实现
一、学习目标
-
了解文本嵌入层和位置编码的作用
-
掌握文本嵌入层和位置编码的实现过程
二、输入部分介绍
Transformer的输入部分包含两个主要组件:
-
源文本嵌入层及其位置编码器
-
目标文本嵌入层及其位置编码器
这些组件共同作用,将文本中的词汇数字表示转换为包含位置信息的向量表示。
三、文本嵌入层(Embeddings)
3.1 作用
文本嵌入层的主要作用是将文本中词汇的数字表示转变为向量表示,在高维空间中捕捉词汇间的关系。无论是源文本嵌入还是目标文本嵌入,都遵循相同的原理。
3.2 实现代码分析
import torch
import torch.nn as nn
import math
from torch.autograd import Variableclass Embeddings(nn.Module):def __init__(self, d_model, vocab):"""参数说明:- d_model: 每个词汇的特征尺寸(词嵌入维度)- vocab: 词汇表大小"""super(Embeddings, self).__init__()self.d_model = d_modelself.vocab = vocab# 定义词嵌入层self.lut = nn.Embedding(self.vocab, self.d_model)def forward(self, x):# 将x传给self.lut并与根号下self.d_model相乘作为结果返回return self.lut(x) * math.sqrt(self.d_model)
3.3 关键点说明
-
缩放因子:乘以
math.sqrt(self.d_model)
是为了增大x的值,使得词嵌入后的embedding vector与位置编码信息的量纲相近 -
nn.Embedding:PyTorch内置的嵌入层,将索引映射为密集向量
3.4 使用示例
def test_Embeddings():d_model = 512 # 词嵌入维度是512维vocab = 1000 # 词表大小是1000# 实例化词嵌入层my_embeddings = Embeddings(d_model, vocab)# 输入数据:2个句子,每个句子4个词汇x = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))embed = my_embeddings(x)print('embed.shape:', embed.shape) # 输出: torch.Size([2, 4, 512])
四、位置编码器(PositionalEncoding)
4.1 作用
由于Transformer的编码器结构没有专门处理词汇位置信息,需要在Embedding层后加入位置编码器,将词汇位置信息加入到词嵌入张量中,以弥补位置信息的缺失。
4.2 实现代码分析
class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout, max_len=5000):"""参数说明:- d_model: 词嵌入维度- dropout: 丢弃率- max_len: 最大序列长度"""super(PositionalEncoding, self).__init__()# 定义dropout层self.dropout = nn.Dropout(p=dropout)# 初始化位置编码矩阵pe = torch.zeros(max_len, d_model)# 创建位置序列 [0, 1, 2, ..., max_len-1]position = torch.arange(0, max_len).unsqueeze(1)# 计算变化矩阵(用于位置编码的缩放)div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))# 应用正弦和余弦函数到奇偶数列pe[:, 0::2] = torch.sin(position * div_term) # 偶数列使用正弦pe[:, 1::2] = torch.cos(position * div_term) # 奇数列使用余弦# 调整形状并注册为缓冲区(不参与训练)pe = pe.unsqueeze(0)self.register_buffer('pe', pe)def forward(self, x):# 将位置编码添加到输入中x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)return self.dropout(x)
4.3 关键点说明
-
正弦余弦编码:使用不同频率的正弦和余弦函数为不同位置生成独特的编码
-
缓冲区注册:使用
register_buffer
将位置编码矩阵注册为模型缓冲区,不参与训练但会随模型保存/加载 -
位置信息添加:通过简单的加法将位置编码信息融合到词嵌入中
4.4 使用示例
def test_PositionalEncoding():d_model = 512vocab = 1000# 创建词嵌入层和位置编码器embeddings = Embeddings(d_model, vocab)positional_encoding = PositionalEncoding(d_model, dropout=0.1, max_len=60)# 处理输入数据x = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))embed = embeddings(x)pe_result = positional_encoding(embed)print('pe_result.shape:', pe_result.shape) # 输出: torch.Size([2, 4, 512])
五、位置编码特征可视化
5.1 可视化代码
import matplotlib.pyplot as plt
import numpy as npdef draw_PE_feature():# 创建位置编码器(小维度便于可视化)my_pe = PositionalEncoding(d_model=20, dropout=0)# 创建输入数据y = my_pe(Variable(torch.zeros(1, 100, 20)))# 绘制特征曲线plt.figure(figsize=(20, 20))plt.plot(np.arange(100), y[0, :, 4:8].numpy())plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])plt.show()
5.2 可视化结果分析
-
曲线特征:每条颜色的曲线代表某一个词汇特征在不同位置的含义
-
位置敏感性:保证同一词汇随着所在位置不同,其对应位置嵌入向量会发生变化
-
数值控制:正弦波和余弦波的值域范围为[-1, 1],有助于控制嵌入数值大小和梯度计算
六、总结
6.1 文本嵌入层
-
作用:将词汇数字表示转换为向量表示,捕捉词汇间关系
-
实现要点:
-
使用nn.Embedding进行词嵌入
-
通过乘以
sqrt(d_model)
进行数值缩放 -
输出形状为[batch_size, seq_len, d_model]
-
6.2 位置编码器
-
作用:为模型提供词汇位置信息,弥补Transformer架构的位置不敏感性
-
实现要点:
-
使用正弦余弦函数生成位置编码
-
奇偶数列分别使用正弦和余弦函数
-
通过加法将位置信息融合到词嵌入中
-
6.3 整体流程
-
输入词汇索引 → 文本嵌入层 → 得到词向量
-
词向量 + 位置编码 → 得到包含位置信息的词表示
-
通过Dropout进行正则化处理
通过这种设计,Transformer的输入部分能够有效地将词汇的语义信息和位置信息结合起来,为后续的编码器和解码器提供丰富的输入表示。