当前位置: 首页 > news >正文

仅仅使用pytorch来手撕transformer架构(1):位置编码的类的实现和向前传播

如果对transformer架构原理模糊的话可以看最适合小白入门的Transformer介绍
P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d model ) , PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), PE(pos,2i)=sin(100002i/dmodelpos),

P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d model ) . PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right). PE(pos,2i+1)=cos(100002i/dmodelpos).

其中,pos表示位置,d_model表示模型的维度,一般设置为512。

PositionalEncoding 类是 Transformer 模型中一个非常重要的组件,它的作用是为输入的嵌入向量(embedding)添加位置信息。Transformer 是一个基于序列的模型,但它本身并不像 RNN 那样能够自然地处理序列中的位置信息。因此,需要通过某种方式将位置信息注入到模型中,PositionalEncoding 就是实现这一功能的关键部分。

代码解析

1. 初始化方法 __init__

def __init__(self, d_model, dropout=0.1, max_len=5000):
    super(PositionalEncoding, self).__init__()
    self.dropout = nn.Dropout(p=dropout)
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0).transpose(0, 1)
    self.register_buffer('pe', pe)
  • d_model:表示嵌入向量的维度,即每个位置编码的大小。
  • dropout:用于正则化,防止过拟合。
  • max_len:表示模型能够处理的最大序列长度。

1.1生成位置编码

位置编码的生成基于论文《Attention Is All You Need》中提出的方法。具体来说,位置编码是一个固定大小的向量(维度为 d_model),它通过正弦和余弦函数生成。这种编码方式能够捕捉到位置信息,并且可以处理比训练时序列长度更长的序列。

  • torch.zeros(max_len, d_model):创建一个形状为 (max_len, d_model) 的零矩阵,用于存储位置编码。
  • torch.arange(0, max_len, dtype=torch.float).unsqueeze(1):生成一个从 0 到 max_len-1 的序列,并将其转换为浮点数,然后通过 unsqueeze(1) 将其扩展为二维张量,形状为 (max_len, 1)。这表示每个置的索引。
  • torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)):计算分母部分,用于调整正弦和余弦函数的频率。torch.arange(0, d_model, 2) 表示每隔一个维度取一个值(因为 d_model 是偶数),-math.log(10000.0) / d_model 是一个缩放因子,用于控制频率的变化范围。
  • pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term):分别将正弦和余弦函数的值赋值到位置编码矩阵的偶数和奇数位置上。这样,每个位置编码的偶数维度由正弦函数生成,奇数维度由余弦函数生成。
  • pe = pe.unsqueeze(0).transpose(0, 1):将位置编码矩阵的形状调整为 (max_len, 1, d_model),然后转置为 (1, max_len, d_model),以便在后续操作中可以直接与输入张量相加。
  • self.register_buffer('pe', pe):将位置编码矩阵注册为一个缓冲区(buffer),这样它不会被视为模型的参数,但会在模型的 state_dict 中保存,并且会随着模型的移动(如从 CPU 移动到 GPU)而自动移动。
    在代码中,位置编码(Positional Encoding)的生成是基于论文《Attention Is All You Need》中提出的公式。具体公式如下:

对于每个位置 ( pos ) 和每个维度 ( i ),位置编码 ( PE(pos, i) ) 的计算公式为:

P E ( p o s , i ) = { sin ⁡ ( p o s 1000 0 2 i / d m o d e l ) if  i  is even cos ⁡ ( p o s 1000 0 2 i / d m o d e l ) if  i  is odd PE(pos, i) = \begin{cases} \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) & \text{if } i \text{ is even} \\ \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) & \text{if } i \text{ is odd} \end{cases} PE(pos,i)= sin(100002i/dmodelpos)cos(100002i/dmodelpos)if i is evenif i is odd

其中:

  • p o s pos pos 是位置索引(从 0 开始)。
  • i i i 是维度索引(从 0 开始)。
  • d m o d e l d_{model} dmodel 是嵌入向量的维度。
  • 10000 10000 10000 是一个常数,用于控制频率的变化范围。

1.2代码与公式对应关系

在代码中,位置编码的生成过程与上述公式一一对应:

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
  1. position

    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    

    这里生成了一个从 0 到 max_len-1 的序列,表示每个位置的索引 ( pos )。unsqueeze(1) 将其扩展为二维张量,形状为 (max_len, 1),方便后续的广播操作。

  2. div_term

    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    

    这里计算了公式中的分母部分:
    1 1000 0 2 i / d m o d e l \frac{1}{10000^{2i/d_{model}}} 100002i/dmodel1
    其中:

    • torch.arange(0, d_model, 2) 表示每隔一个维度取一个值(因为偶数维度用正弦,奇数维度用余弦)。
    • (-math.log(10000.0) / d_model) 是公式中的指数部分。
    • torch.exp(...) 计算 ( e^{x} ),从而得到分母的值。
  3. pe[:, 0::2]pe[:, 1::2]

    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    

    这里分别计算了偶数维度和奇数维度的位置编码:

    • pe[:, 0::2] 对应公式中的正弦部分:
      P E ( p o s , i ) = sin ⁡ ( p o s 1000 0 2 i / d m o d e l ) (for even  i ) PE(pos, i) = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \quad \text{(for even } i\text{)} PE(pos,i)=sin(100002i/dmodelpos)(for even i)
    • pe[:, 1::2] 对应公式中的余弦部分:
      P E ( p o s , i ) = cos ⁡ ( p o s 1000 0 2 i / d m o d e l ) (for odd  i ) PE(pos, i) = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \quad \text{(for odd } i\text{)} PE(pos,i)=cos(100002i/dmodelpos)(for odd i)

1.3max_len的选择

在代码中,max_len 表示模型能够处理的最大序列长度。这个值是预先设定的,用于定义位置编码矩阵的大小。在实际应用中,max_len 应该足够大,以覆盖你预期中可能遇到的最长序列。

1.3.1 代码中的 max_len
pe = torch.zeros(max_len, d_model)

这里,pe 是一个形状为 (max_len, d_model) 的零矩阵,用于存储位置编码。max_len 是这个矩阵的行数,表示可以处理的最大序列长度。

1.3.2选择 max_len 的值

在实际应用中,选择 max_len 的值需要考虑以下因素:

  1. 数据集的特性:如果你的数据集中序列的长度变化很大,那么 max_len 应该足够大,以覆盖最长的序列。
  2. 模型的容量:较大的 max_len 会增加模型的参数量,因为位置编码矩阵的大小与 max_len 直接相关。这可能会增加模型的复杂度和训练时间。
  3. 计算资源:较大的 max_len 会占用更多的内存和计算资源,特别是在使用 GPU 进行训练时。
1.3.3 实际应用

在实际应用中,max_len 的值通常根据数据集的统计特性来确定。例如,如果数据集中 99% 的序列长度都小于 100,那么可以选择 max_len = 100。这样可以确保模型能够处理大多数序列,同时避免不必要的计算开销。

1.3.4总结

max_len 是一个重要的超参数,它定义了模型能够处理的最大序列长度。选择合适的 max_len 值需要考虑数据集的特性、模型的容量和计算资源。在实际应用中,可以通过分析数据集的统计特性来确定 max_len 的值。

2. 前向传播方法 forward
def forward(self, x):
    x = x + self.pe[:x.size(0), :]
    return self.dropout(x)
  • x:输入的嵌入向量,形状为 (seq_len, batch_size, d_model)
  • self.pe[:x.size(0), :]:根据输入序列的实际长度(x.size(0)),从预定义的位置编码矩阵中取出相应长度的位置编码。
  • x + self.pe[:x.size(0), :]:将位置编码加到输入的嵌入向量上。由于位置编码的形状为 (seq_len, 1, d_model),而输入嵌入向量的形状为 (seq_len, batch_size, d_model),PyTorch 会自动进行广播操作,将位置编码应用到每个样本上。
  • self.dropout(x):对加了位置编码后的嵌入向量应用 Dropout,以防止过拟合。

完整代码:

# 位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

总结

PositionalEncoding 类的作用是为输入的嵌入向量添加位置信息,使得 Transformer 模型能够感知序列中每个元素的位置。位置编码通过正弦和余弦函数生成,能够捕捉到位置信息,并且可以处理比训练时序列长度更长的序列。

相关文章:

  • 系统架构设计师知识小科普:系统架构评估
  • 【文献阅读】SPRec:用自我博弈打破大语言模型推荐的“同质化”困境
  • Linux上位机开发实战(qt编译之谜)
  • vue 仿deepseek前端开发一个对话界面
  • 3分钟复现 Manus 超强开源项目 OpenManus
  • 使用netlify部署github的vue/react项目或本地的dist,国内也可以正常访问
  • 人工智能混合编程实践:Python ONNX进行图像超分重建
  • PyTorch 和 Python关系
  • 先进制造aps专题三十一 免费企业高级计划和优化(Advanced Planning and Optimizer)产品FreeAPO简介
  • ELK traceId实现跨服务日志追踪
  • 【MySQL】MySQL程序解析
  • Leetcode 95-不同的二叉搜索树 II
  • Python----计算机视觉处理(opencv:像素,RGB颜色,图像的存储,opencv安装,代码展示)
  • 当量子计算遇上互联网安全:挑战与革新之路
  • Java 序列化和反序列化为什么要实现Serializable接口
  • Redis存数据就像存钱:RDB定期存款 vs AOF实时记账
  • 计算机视觉图像点运算【灰度直方图均衡化图形界面实操理解 +开源代码】
  • 深度学习 模型和代码
  • 【经验】Ubuntu|VMware 新建虚拟机后打开 SSH 服务、在主机上安装vscode并连接、配置 git 的 ssh
  • Spring Security的作用
  • 大四本科生已发14篇SCI论文?重庆大学:成立工作组核实
  • 视频丨习近平主席专机抵达莫斯科,俄战机升空护航
  • 国铁集团:铁路五一假期运输收官,多项运输指标创历史新高
  • 公积金利率降至历史最低!多项房地产利好政策落地,购房者置业成本又降了
  • 央行:增加支农支小再贷款额度3000亿元
  • 江苏淮安优化村级资源配置:淮安区多个空心村拟并入邻村