当前位置: 首页 > news >正文

Transformer 代码剖析7 - 词元嵌入(TokenEmbedding) (pytorch实现)

一、类定义与继承关系剖析

1.1 代码结构图示

神经网络基础模块
词嵌入基类
自定义词元嵌入
构造函数定义
基类初始化
词汇量参数
维度参数
填充标识参数

1.2 代码实现精讲

"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""
from torch import nn

class TokenEmbedding(nn.Embedding):
    """
    基于PyTorch实现的动态词元嵌入模块
    实现词元索引到高维向量的可学习映射
    核心功能:将离散的词元序列转换为连续的语义空间表示
    """
    
    def __init__(self, vocab_size, d_model):
        """
        词元嵌入构造器
        
        :param vocab_size: 词表容量(不同词元的总数)
        :param d_model: 嵌入维度(与Transformer模型维度一致)
        设计要点:
        - 继承nn.Embedding的矩阵运算特性
        - 固化填充索引为可训练参数
        - 保持维度与模型其他组件兼容
        """
        super(TokenEmbedding, self).__init__(
            vocab_size, # 嵌入数量 num_embeddings # 嵌入矩阵行数 = 词表大小
            d_model, # 嵌入维度 embedding_dim # 嵌入矩阵列数 = 模型维度
            padding_idx=1 # 填充符索引的特殊处理
        )

二、核心参数深度解读

2.1 参数矩阵可视化

假设词表容量vocab_size=10000,模型维度d_model=512时:

参数维度元素数量数学意义
weight[10000,512]5,120,000可训练的嵌入查询矩阵
padding_idxscalar1动态掩码位置标识

2.2 关键参数说明

1. vocab_size

  • 控制嵌入矩阵的行维度
  • 决定模型可处理的词元种类上限
  • 典型值域:BERT系列(~30000),GPT系列(~50000)

2. d_model

  • 控制嵌入向量的列维度
  • 与Transformer隐藏层维度严格对齐
  • 典型值域:512(原始论文)、768(BERT-base)、1024(大型模型)

3. padding_idx

  • 实现动态序列掩码的关键参数
  • 索引位置对应的梯度会被自动抑制
  • 防止填充符影响模型语义理解

三、运算过程分步推演

3.1 前向传播示例

输入序列:[3, 28, 1, 0] (1为填充符)

运算步骤:

1. 建立索引映射:

[[3],[[0.2, -0.5, ..., 1.2],  # 索引3的嵌入
 [28],[0.7, 1.1, ..., -0.3],   # 索引28的嵌入
 [1],[0.0, 0.0, ..., 0.0],    # 填充符固定值
 [0]][-0.9, 0.4, ..., 0.1]]   # 索引0的嵌入

2. 矩阵缩放(后续处理):

embeddings * sqrt(d_model)  # 维度对齐的数学技巧

3.2 梯度传播特性

  • 可微分性: 整个映射过程保持梯度通路
  • 参数更新: 通过反向传播调整嵌入矩阵
  • 特殊处理: padding_idx位置梯度始终为0

四、设计哲学解析

4.1 继承关系价值

TokenEmbedding
torch.nn.Embedding
torch.nn.Module
PyTorch基础设施

优势分析:

  • 复用性:继承矩阵运算和参数管理功能
  • 扩展性:保留自定义前向传播的可能性
  • 兼容性:无缝对接PyTorch生态工具

4.2 工程实践建议

1. 初始化技巧:

  • 默认采用均匀分布 U ( − 1 d m o d e l , 1 d m o d e l ) U(-\sqrt{\frac{1}{d_{model}}}, \sqrt{\frac{1}{d_{model}}}) U(dmodel1 ,dmodel1 )
  • 可扩展为Xavier/Kaiming初始化:
    # Xavier均匀初始化(默认)
    nn.init.xavier_uniform_(self.weight)
    
    # 特殊处理填充符
    self.weight.data[1].zero_()
    

2. 维度对齐策略:

# 与位置编码相加前的缩放
embeddings = embeddings * math.sqrt(d_model)

3. 混合精度训练:

# 自动转换为半精度
with autocast():
    embeddings = embedding_layer(input_ids)

4. 填充符处理机制:

  • 训练阶段自动跳过无效位置的计算
  • 推理阶段维持序列形状一致性

5. 计算复杂度分析:

  • 时间复杂度: O ( B ⋅ S ⋅ D ) O(B \cdot S \cdot D) O(BSD)
  • 空间复杂度: O ( V ⋅ D ) O(V \cdot D) O(VD)

完整实现细节可参考PyTorch中sparse.py 模块解析的相关文章(嵌入(Embedding)基类代码解析)或PyTorch官方Embedding文档。

相关文章:

  • 【wordpress】服务器已有LNMP环境(已运行WordPress),如何配置文档访问功能?
  • 【笔记】用大预言模型构建专家系统
  • DeepSeek模型本地部署与应用构建
  • C++ Primer Plus第九章课后习题总结
  • 全星研发项目管理APQP软件系统:铸造芯片集成电路产业研发体系化建设平台
  • C++中的“结界”机制:作用域与变量可见性探秘
  • 【前端面试】如何不通过正则:验证IP地址合法性
  • PartitionFinder2 安装与使用-bioinfomatics tools 051
  • 从源到目标:深度学习中的迁移学习与领域自适应实践
  • 谈谈单例模式中通过Htools包的SpringUtil.getBean获取Bean的好处
  • 探索DEHP与睾酮素的隐秘关联
  • 【免费压测靶场开放】性能测试练习靶场,GET/POST双模式支持
  • SpringMVC学习(初识与复习Web程序的工作流程)(1)
  • 系统架构设计师—计算机基础篇—存储管理
  • Vim 常用快捷键大全:跳转、编辑、查找替换全解析
  • 【前端知识】Vue2.x与3.x之间的区别以及升级过程需要关注的地方
  • ​Java 加密技术全面解析:SM2、SM4、MD5 及常用加密方法​
  • Python Cookbook-2.29 带版本号的文件名
  • Java获取本机Mac地址
  • C++string类
  • 标识设计公司网站/识别关键词软件
  • 网站空间2G一年多少钱/网站域名服务器查询
  • 做电信宽带合适做网站吗/广告营销推广方案
  • 做企业网站设计价格是多少钱/襄阳seo培训
  • 织梦网站后台密码错误/怎么做好网络营销
  • jsp简述网站开发流程/建设网站费用