当前位置: 首页 > news >正文

pytorch小记(十四):pytorch中 nn.Embedding 详解

pytorch小记(十四):pytorch中 nn.Embedding 详解

  • PyTorch 中的 nn.Embedding 详解
    • 1. 什么是 nn.Embedding?
    • 2. nn.Embedding 的基本使用
      • 示例 1:基础用法
      • 示例 2:处理批次输入
    • 3. nn.Embedding 与 nn.Linear 的区别
      • 3.1 nn.Embedding
      • 3.2 nn.Linear
    • 4. nn.Embedding 与 nn.Sequential 的区别
    • 5. 应用场景
    • 6. 总结


PyTorch 中的 nn.Embedding 详解

在自然语言处理、推荐系统以及其他处理离散输入的任务中,我们常常需要将离散的标识符(例如单词、字符、用户 ID 等)转换为连续的、低维的向量表示。PyTorch 提供了专门的模块——nn.Embedding,用来实现这种“嵌入”操作。本文将详细解释 nn.Embedding 的工作原理、使用方法以及与普通线性层(nn.Linear)和顺序模块(nn.Sequential)的区别,并给出清晰的代码示例。


1. 什么是 nn.Embedding?

nn.Embedding 实际上是一个查找表(lookup table),它内部维护一个矩阵,每一行对应一个离散标识符的向量表示。

  • 假设你有一个词汇表,大小为 num_embeddings,每个词将映射到一个 embedding_dim 维的向量上。
  • nn.Embedding 会创建一个形状为 [num_embeddings, embedding_dim] 的矩阵。
  • 当你输入一个包含单词索引的张量时,模块会直接从这个矩阵中查找出相应行的向量,作为单词的嵌入表示。

这种方式的好处是直接“查找”而非进行繁琐的矩阵乘法计算,既高效又直观。


2. nn.Embedding 的基本使用

示例 1:基础用法

下面的例子展示了如何使用 nn.Embedding 将一组单词索引转换为对应的嵌入向量。

import torch
import torch.nn as nn

# 定义一个嵌入层
# 假设词汇表大小为 10,每个单词用 5 维向量表示
embedding = nn.Embedding(num_embeddings=10, embedding_dim=5)

# 打印嵌入矩阵的形状
print("嵌入矩阵形状:", embedding.weight.shape)
# 输出: torch.Size([10, 5])

# 定义一个包含单词索引的张量,例如 [3, 7, 1]。索引 embedding 表中[3],[7],[1]行
indices = torch.tensor([3, 7, 1])

# 使用嵌入层查找对应的嵌入向量
embedded_vectors = embedding(indices)
print("查找到的嵌入向量:")
print(embedded_vectors)

说明:

  • 输入是一个包含索引 [3, 7, 1] 的 1D 张量,输出是一个形状为 [3, 5] 的张量。
  • 每一行就是词汇表中对应索引的嵌入向量。

示例 2:处理批次输入

在实际任务中,我们通常一次处理多个样本。例如,一个批次中包含多个句子,每个句子由若干单词索引组成。下面的例子展示了如何处理批次数据。

# 假设有一个批次,包含 2 个样本,每个样本包含 4 个单词索引
'''
对应原数据的
[[[1],[2],[3],[4]],
 [[5],[6],[7],[8]]]
'''
batch_indices = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8]
])

# 使用嵌入层查找嵌入向量
batch_embeddings = embedding(batch_indices)

print("批次嵌入向量形状:", batch_embeddings.shape)
# 输出形状: torch.Size([2, 4, 5])

说明:

  • 输入的 batch_indices 形状为 [2, 4],表示 2 个样本,每个样本 4 个单词索引。
  • 输出为 [2, 4, 5],每个单词索引转换成 5 维嵌入向量。

3. nn.Embedding 与 nn.Linear 的区别

虽然 nn.Embedding 和 nn.Linear 都涉及到矩阵的操作,但二者解决的问题大不相同。

3.1 nn.Embedding

  • 用途:专门用于将离散的索引(如单词 ID)转换为连续的向量表示,是一种查找表操作。
  • 输入:通常为整数索引。
  • 输出:直接返回查找表中对应的向量,效率高,不进行额外的计算。

3.2 nn.Linear

  • 用途:用于实现线性变换,即对输入做矩阵乘法加上偏置,计算公式为 y = x W T + b y = xW^T + b y=xWT+b
  • 输入:需要连续数值的张量。
  • 应用:若要模拟嵌入操作,需要先将整数索引转换成 one-hot 编码,再通过 nn.Linear 进行计算,这样既低效又不直观。

总结

  • 使用 nn.Embedding 更直接、更高效,因为它只进行查找操作;
  • nn.Linear 则用于对连续特征进行线性变换。

4. nn.Embedding 与 nn.Sequential 的区别

  • nn.Sequential 是一个模块容器,用于按顺序组合多个层,适用于前向传播流程固定的情况。
  • nn.Embedding 则是一个具体的层,用于实现查找表功能。
  • 在模型中,我们通常将 nn.Embedding 放在最前面,将离散输入转换为连续向量,再结合 nn.Sequential 里的其他层进行进一步处理。

例如,在 NLP 模型中常常这样使用:

class TextModel(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(TextModel, self).__init__()
        # 使用 nn.Embedding 将单词索引映射为嵌入向量
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # 使用 nn.Sequential 组合后续的线性层
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, 10),
            nn.ReLU(),
            nn.Linear(10, 2)
        )
    
    def forward(self, x):
        # x 的形状可能为 (batch_size, sequence_length)
        x_embed = self.embedding(x)  # 变为 (batch_size, sequence_length, embed_dim)
        # 对嵌入向量进行池化,变为 (batch_size, embed_dim)
        x_pool = x_embed.mean(dim=1)
        out = self.fc(x_pool)
        return out

# 假设词汇表大小 100,嵌入维度 8
model = TextModel(vocab_size=100, embed_dim=8)

在这个例子中,nn.Embedding 将离散单词转换为连续向量,而 nn.Sequential 则定义了后续的前向传播步骤。


5. 应用场景

nn.Embedding 常用于:

  • 自然语言处理(NLP):将单词、子词、字符等离散输入转换为低维向量表示,为后续的 RNN、Transformer 模型提供输入。
  • 推荐系统:将用户 ID、商品 ID 映射为嵌入向量,用于捕捉用户和物品之间的相似性。
  • 图神经网络:将节点或边的离散标签转换为连续向量表示。

6. 总结

  • nn.Embedding 是一个查找表,用于将离散索引映射为连续向量。
  • 它的输入通常是整数张量,输出是对应的嵌入向量。
  • 与 nn.Linear 相比,nn.Embedding 不需要进行大量的计算,只是直接查找,所以更高效。
  • nn.Embedding 经常与 nn.Sequential 结合使用:先将离散数据转换为嵌入向量,再通过连续层进行处理。

通过以上详细解释和分步代码示例,希望大家能对 nn.Embedding 有一个全面的理解,并能在实际项目中正确使用它来提升模型的表现。

🚀 写在最后
利用 nn.Embedding,你可以轻松将离散数据转换为高质量的连续表示,这在各种深度学习任务中都是至关重要的!

相关文章:

  • 机器学习之梯度消失和梯度爆炸
  • 1.5.2 掌握Scala内建控制结构 - 块表达式
  • 【css酷炫效果】纯CSS实现虫洞穿越效果
  • Rust + WebAssembly 实现康威生命游戏
  • java 之枚举问题(超详细!!!!)
  • MySQL(索引)
  • 华为ISC+战略规划项目数字化转型驱动的智慧供应链革新(169页PPT)(文末有下载方式)
  • 架构师面试(十七):总体架构
  • numpy学习笔记4:np.arange(0, 10, 2) 的详细解释
  • 深度学习零碎知识
  • 【C语言】自定义类型:结构体
  • Android 15 获取网络切片信息的标准接口
  • 《C语言中的ASCII码表:解锁字符与数字的桥梁》
  • Netty基础—Netty实现消息推送服务
  • go语言中数组、map和切片的异同
  • Mobile-Agent-V:通过视频引导的多智体协作学习移动设备操作
  • PCDN 在去中心化互联网中的角色
  • 个人.clang-format配置,适合Linux C/C++
  • 韩顺平教育-家居网购
  • 搜广推校招面经五十四
  • 对谈|《对工作说不》,究竟是要对什么说不?
  • 龚惠民已任江西省司法厅党组书记
  • 李铁案二审今日宣判
  • 向总书记汇报具身智能发展的“稚辉君”:从期待到兴奋再到备受鼓舞
  • 北京发布今年第四轮拟供商品住宅用地清单,共计5宗22公顷
  • BNEF:亚洲、中东和非洲是电力基础设施投资的最大机会所在