当前位置: 首页 > news >正文

Transformer-输入部分

Transformer输入部分实现

一、学习目标

  • 了解文本嵌入层和位置编码的作用

  • 掌握文本嵌入层和位置编码的实现过程

二、输入部分介绍

Transformer的输入部分包含两个主要组件:

  • 源文本嵌入层及其位置编码器

  • 目标文本嵌入层及其位置编码器

这些组件共同作用,将文本中的词汇数字表示转换为包含位置信息的向量表示。

三、文本嵌入层(Embeddings)

3.1 作用

文本嵌入层的主要作用是将文本中词汇的数字表示转变为向量表示,在高维空间中捕捉词汇间的关系。无论是源文本嵌入还是目标文本嵌入,都遵循相同的原理。

3.2 实现代码分析

import torch
import torch.nn as nn
import math
from torch.autograd import Variableclass Embeddings(nn.Module):def __init__(self, d_model, vocab):"""参数说明:- d_model: 每个词汇的特征尺寸(词嵌入维度)- vocab: 词汇表大小"""super(Embeddings, self).__init__()self.d_model = d_modelself.vocab = vocab# 定义词嵌入层self.lut = nn.Embedding(self.vocab, self.d_model)def forward(self, x):# 将x传给self.lut并与根号下self.d_model相乘作为结果返回return self.lut(x) * math.sqrt(self.d_model)

3.3 关键点说明

  • ​缩放因子​​:乘以math.sqrt(self.d_model)是为了增大x的值,使得词嵌入后的embedding vector与位置编码信息的量纲相近

  • ​nn.Embedding​​:PyTorch内置的嵌入层,将索引映射为密集向量

3.4 使用示例

def test_Embeddings():d_model = 512   # 词嵌入维度是512维vocab = 1000    # 词表大小是1000# 实例化词嵌入层my_embeddings = Embeddings(d_model, vocab)# 输入数据:2个句子,每个句子4个词汇x = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))embed = my_embeddings(x)print('embed.shape:', embed.shape)  # 输出: torch.Size([2, 4, 512])

四、位置编码器(PositionalEncoding)

4.1 作用

由于Transformer的编码器结构没有专门处理词汇位置信息,需要在Embedding层后加入位置编码器,将词汇位置信息加入到词嵌入张量中,以弥补位置信息的缺失。

4.2 实现代码分析

class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout, max_len=5000):"""参数说明:- d_model: 词嵌入维度- dropout: 丢弃率- max_len: 最大序列长度"""super(PositionalEncoding, self).__init__()# 定义dropout层self.dropout = nn.Dropout(p=dropout)# 初始化位置编码矩阵pe = torch.zeros(max_len, d_model)# 创建位置序列 [0, 1, 2, ..., max_len-1]position = torch.arange(0, max_len).unsqueeze(1)# 计算变化矩阵(用于位置编码的缩放)div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))# 应用正弦和余弦函数到奇偶数列pe[:, 0::2] = torch.sin(position * div_term)  # 偶数列使用正弦pe[:, 1::2] = torch.cos(position * div_term)  # 奇数列使用余弦# 调整形状并注册为缓冲区(不参与训练)pe = pe.unsqueeze(0)self.register_buffer('pe', pe)def forward(self, x):# 将位置编码添加到输入中x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)return self.dropout(x)

4.3 关键点说明

  • ​正弦余弦编码​​:使用不同频率的正弦和余弦函数为不同位置生成独特的编码

  • ​缓冲区注册​​:使用register_buffer将位置编码矩阵注册为模型缓冲区,不参与训练但会随模型保存/加载

  • ​位置信息添加​​:通过简单的加法将位置编码信息融合到词嵌入中

4.4 使用示例

def test_PositionalEncoding():d_model = 512vocab = 1000# 创建词嵌入层和位置编码器embeddings = Embeddings(d_model, vocab)positional_encoding = PositionalEncoding(d_model, dropout=0.1, max_len=60)# 处理输入数据x = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))embed = embeddings(x)pe_result = positional_encoding(embed)print('pe_result.shape:', pe_result.shape)  # 输出: torch.Size([2, 4, 512])

五、位置编码特征可视化

5.1 可视化代码

import matplotlib.pyplot as plt
import numpy as npdef draw_PE_feature():# 创建位置编码器(小维度便于可视化)my_pe = PositionalEncoding(d_model=20, dropout=0)# 创建输入数据y = my_pe(Variable(torch.zeros(1, 100, 20)))# 绘制特征曲线plt.figure(figsize=(20, 20))plt.plot(np.arange(100), y[0, :, 4:8].numpy())plt.legend(["dim %d" % p for p in [4, 5, 6, 7]])plt.show()

5.2 可视化结果分析

  • ​曲线特征​​:每条颜色的曲线代表某一个词汇特征在不同位置的含义

  • ​位置敏感性​​:保证同一词汇随着所在位置不同,其对应位置嵌入向量会发生变化

  • ​数值控制​​:正弦波和余弦波的值域范围为[-1, 1],有助于控制嵌入数值大小和梯度计算

六、总结

6.1 文本嵌入层

  • ​作用​​:将词汇数字表示转换为向量表示,捕捉词汇间关系

  • ​实现要点​​:

    • 使用nn.Embedding进行词嵌入

    • 通过乘以sqrt(d_model)进行数值缩放

    • 输出形状为[batch_size, seq_len, d_model]

6.2 位置编码器

  • ​作用​​:为模型提供词汇位置信息,弥补Transformer架构的位置不敏感性

  • ​实现要点​​:

    • 使用正弦余弦函数生成位置编码

    • 奇偶数列分别使用正弦和余弦函数

    • 通过加法将位置信息融合到词嵌入中

6.3 整体流程

  1. 输入词汇索引 → 文本嵌入层 → 得到词向量

  2. 词向量 + 位置编码 → 得到包含位置信息的词表示

  3. 通过Dropout进行正则化处理

通过这种设计,Transformer的输入部分能够有效地将词汇的语义信息和位置信息结合起来,为后续的编码器和解码器提供丰富的输入表示。

http://www.dtcms.com/a/483239.html

相关文章:

  • Python接口与抽象基类详解:从规范定义到高级应用
  • 免费网站建设价格费用.net做网站用什么的多
  • 专业高端网站建设服务公司百度指数趋势
  • AI商品换模特及场景智能化
  • 网站开发定制推广杭州视频在线生成链接
  • 异步任务使用场景与实践
  • 300多个Html5小游戏列表和下载地址
  • 企业门户网站方案建网站有报价单吗
  • 企业网站开发价钱低免费开个人网店
  • 建网站软件下载那个软件可以做三个视频网站
  • Excel使用教程笔记
  • 论文阅读《LIMA:Less Is More for Alignment》
  • wordpress 网站暂停app建设网站
  • 考研408--组成原理--day1
  • 网络公司构建网站杭州旅游团购网站建设
  • 【数值分析】非线性方程与方程组的数值解法的经典算法(附MATLAB代码)
  • 文件外链网站智慧团建官网登录入口电脑版
  • 如何在Windows上为Java配置多个版本的环境变量
  • 如何将自己做的网站放到网上去如何做电商创业
  • 杭州市建设信用网郑州优化网站关键词
  • 农业与供应链类 RWA 落地研究报告
  • p2p理财网站开发cms和wordpress
  • 合肥seo整站优化网站做跳转付款
  • 物联网的调试
  • React项目开发(代码架构/规范怎么做)?
  • 做视频网站要准备哪些资料广告设计与制作好找工作吗
  • 双token登录
  • [Backstage] 认证请求的流程 | JWT令牌
  • 简述网站规划的一般步骤马鞍山集团网站设计
  • 使用 Rufus 制作启动盘安装 Windows 与 Ubuntu 系统全流程教程(图文详解+避坑指南)