当前位置: 首页 > news >正文

Transformer的Word Embedding

一、Transformer 中的词嵌入是什么?

1. 定义与作用

词嵌入(Word Embedding):将离散的词语映射为低维连续向量,捕捉语义和语法信息。
在 Transformer 中的位置
• 输入层:每个词通过嵌入层转换为向量(如 embedding_dim=512)。
• 输出层:解码器输出的向量通过反向嵌入映射回词表概率(如 logits = decoder_output * embedding_matrix^T)。

2. 与 Word2Vec 的对比
特性Word2VecTransformer 中的词嵌入
上下文相关性静态(每个词仅一个向量)动态(同一词在不同上下文中向量不同)
训练方式独立预训练(无监督)端到端学习(通常结合预训练任务)
多义词处理无法区分多义词基于上下文动态调整(如 BERT)
位置信息通过位置编码(Positional Encoding)
参数规模较小(仅词表大小 × 嵌入维度)较大(嵌入层是模型的一部分)

二、Transformer 词嵌入的核心革新

1. 上下文相关(Contextualized Embeddings)

问题:Word2Vec 的静态词向量无法处理一词多义(例如“bank”在“river bank”和“bank account”中的不同含义)。
解决方案:Transformer 通过自注意力机制动态调整词向量:
• 输入序列中的每个词向量在编码过程中与其他词交互,生成上下文相关的表示。
示例:在句子 Apple launched a new phone 中,“Apple”的向量会包含“phone”的语义;而在 Apple pie is delicious 中,“Apple”的向量会包含“pie”的语义。

2. 位置编码(Positional Encoding)

问题:Transformer 抛弃了 RNN 的时序结构,需显式注入位置信息。
实现方式
绝对位置编码:通过正弦函数或可学习向量编码词的位置(原始论文方法):
P E ( p o s , 2 i ) = sin ⁡ ( p o s / 1000 0 2 i / d model ) PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{\text{model}}}) PE(pos,2i)=sin(pos/100002i/dmodel) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d model ) PE_{(pos, 2i+1)} = \cos(pos10000^{2i/d_{\text{model}}}) PE(pos,2i+1)=cos(pos100002i/dmodel)
相对位置编码:某些变体(如 Transformer-XL)编码词之间的相对距离。

3. 预训练任务驱动

预训练任务:Transformer 的词嵌入通常通过大规模预训练任务学习:
BERT:掩码语言模型(Masked Language Model, MLM) + 下一句预测(Next Sentence Prediction, NSP)。
GPT:自回归语言模型(预测下一个词)。
优势
• 词嵌入不仅包含通用语义,还编码了任务相关的知识(如句间关系、长程依赖)。


三、Transformer 词嵌入的技术细节

1. 嵌入层的数学表示

• 给定词表大小为 V V V,嵌入维度为 d d d,嵌入层是一个 V × d V \times d V×d的矩阵。
• 输入序列 [ w 1 , w 2 , . . . , w n ] [w_1, w_2, ..., w_n] [w1,w2,...,wn] 经过嵌入层后得到矩阵 X ∈ R n × d X \in \mathbb{R}^{n \times d} XRn×d,再与位置编码 P P P 相加:
X final = X + P X_{\text{final}} = X + P Xfinal=X+P

2. 与自注意力的交互

• 自注意力机制通过查询(Query)、键(Key)、值(Value)矩阵对词向量进行交互:
Attention ( Q , K , V ) = softmax ( Q K ⊤ d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QK)V
结果:每个词的输出向量是所有词向量的加权和,权重由语义相关性决定。

3. 跨层信息传递

• Transformer 的每一层(Layer)都会更新词向量:
• 底层编码局部语法(如词性)。
• 高层编码全局语义(如指代消解、情感倾向)。


四、实例分析:BERT 的嵌入层

1. 输入表示

BERT 的输入嵌入由三部分组成:

  1. 词嵌入(Token Embeddings):将词语映射为向量。
  2. 位置嵌入(Position Embeddings):可学习的位置编码。
  3. 段嵌入(Segment Embeddings):区分句子对(如问答任务中的问题和答案)。
2. 掩码语言模型(MLM)

训练任务:随机遮盖输入中的某些词(如替换为 [MASK]),让模型预测被遮盖的词。
对词嵌入的影响
• 迫使模型通过上下文推断被遮盖词,增强嵌入的上下文敏感性。
• 示例:在句子 The [MASK] sat on the mat 中,模型需根据 satmat 推断 [MASK] 可能是 cat

3. 输出示例

• 输入词 bank 在不同上下文中的 BERT 嵌入向量:
• 上下文 1:river bank → 向量靠近 shore, water
• 上下文 2:bank account → 向量靠近 money, finance


五、代码示例:Transformer 嵌入层的实现(PyTorch)

import torch
import torch.nn as nn

class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len, dropout=0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)  # 可学习的位置编码
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x: [batch_size, seq_len]
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        token_emb = self.token_embed(x)        # [batch_size, seq_len, embed_dim]
        pos_emb = self.pos_embed(positions)    # [1, seq_len, embed_dim]
        return self.dropout(token_emb + pos_emb)

# 使用示例
vocab_size = 10000
embed_dim = 512
max_seq_len = 128
model = TransformerEmbedding(vocab_size, embed_dim, max_seq_len)

input_ids = torch.randint(0, vocab_size, (32, max_seq_len))  # 模拟输入(batch_size=32)
output_emb = model(input_ids)  # [32, 128, 512]

六、总结:Transformer 的词嵌入

  1. 上下文动态调整:通过自注意力机制捕捉长距离依赖,解决一词多义。
  2. 预训练赋能:在大规模语料上预训练,使词嵌入包含丰富的世界知识。
  3. 位置感知:显式编码位置信息,弥补无时序结构的缺陷。
  4. 端到端学习:嵌入层与模型其他部分联合优化,适应具体任务需求。

相关文章:

  • Spring Boot 项目集成 License 授权与续期完整指南
  • GS+:地统计分析与空间插值工具
  • 【区块链安全 | 第三十五篇】溢出漏洞
  • HackMyVM-Preload
  • SSRF漏洞利用的小点总结和实战演练
  • 内存池项目(2)——内存池设计之边界标识法
  • File 类的用法和 InputStream, OutputStream 的用法
  • 【虚拟化安全】虚拟化安全知识全攻略:保障云端数据安全
  • 数据库设计工具drawDB本地部署与远程在线协作实测让效率翻倍
  • Hibernate核心方法总结
  • 阿里云oss视频苹果端无法播放问题记录
  • 项目二 - 任务5:打印乘法九九表
  • Qt饼状图在图例上追踪鼠标落点
  • 人脸表情识别数据集分享(AffectNet、RAF-DB、FERPlus、FER2013、ck+)
  • NVIDIA Jetson 环境安装指导 PyTorch | Conda | cudnn | docker
  • 【qiankun】简易前端微应用搭建
  • 企业工厂生产线马达保护装置 功能参数介绍
  • 4.6学习总结
  • 网络中级(HCIP)项目实践一MGRE的两种架构的私有网段 OSPF 动态路由协议的互联实验(手把手教您,包学会的)
  • 使用 STM32F103C8 连接 ESP8266:创建 Web 服务器
  • 网页站点怎么命名/建立网站需要什么
  • 施工企业的安全生产管理机构以及安全生产管理人员履行下列职责:( )/南京seo整站优化技术
  • 俄文网站建设 俄文网站设计/网站收录查询工具
  • 福建建设厅网站工程履约保险/外贸seo优化公司
  • 2015做网站前景/网站关键词快速排名技术
  • 深圳做手机网站设计/北京seo网站优化培训