当前位置: 首页 > news >正文

文本嵌入层

1、代码演示

embedding = nn.Embedding(10,3)
print(embedding)
input = torch.LongTensor([[1,2,3,4],[4,3,2,9]])
embedding(input)

2、构建Embeddings类来实现文本嵌入层

# 构建Embedding类来实现文本嵌入层
class Embeddings(nn.Module):
    def __init__(self,d_model,vocab):
        """
        :param d_model: 词嵌入的维度
        :param vocab: 词表的大小
        """
        super(Embeddings,self).__init__()
        self.lut = nn.Embedding(vocab,d_model)
        self.d_model = d_model
    def forward(self,x):
        """
        :param x: 因为Embedding层是首层,所以代表输入给模型的文本通过词汇映射后的张量
        :return:
        """
        return self.lut(x) * math.sqrt(self.d_model)
x = Variable(torch.LongTensor([[100,2,42,508],[491,998,1,221]]))
emb = Embeddings(512,1000)
embr = emb(x)
print(embr.shape)             # torch.Size([2, 4, 512])
print(embr)
print(embr[0][0].shape)       # torch.Size([512])

http://www.dtcms.com/a/1829.html

相关文章:

  • Qt raise()问题
  • 【QT】使用toBase64方法将.txt文件的明文变为非明文(类似加密)
  • Mysql生产随笔
  • vue下载在前端存放的pdf文件
  • 玩碎Java之CompletableFuture的例子
  • Java初始化大量数据到Neo4j中(二)
  • lambda的使用案例(1)
  • 探索视听新纪元: ChatGPT的最新语音和图像功能全解析
  • Flutter笔记:AnimationMean、AnimationMax 和 AnimationMin 三个类的用法
  • 朴素贝叶斯分类(下):数据挖掘十大算法之一
  • 了解ActiveMQ、RabbitMQ、RocketMQ和Kafka的特点
  • 嵌入式开源库之libmodbus学习笔记
  • 27、Flink 的SQL之SELECT (Pattern Recognition 模式检测)介绍及详细示例(7)
  • Linux网络编程- struct ifreq ioctl() 系统调用
  • Android 13 - Media框架(8)- MediaExtractor(2)
  • 机器学习第十四课--神经网络
  • stream对list数据进行多字段去重
  • 问答区混赏金的集合贴
  • 华为杯数学建模比赛经验分享
  • $nextTick解决echarts宽度固定为100%的问题
  • Armv9 Cortex-A720的L2 memory system 和 L2 Cache
  • Leetcode 297. 二叉树的序列化与反序列化
  • 【LeetCode】滑动窗口妙解无重复字符的最长子串
  • 华为智能高校出口安全解决方案(2)
  • Ubuntu Qt 5.15.2 支持 aarch64
  • 【李沐深度学习笔记】损失函数
  • C++与数据结构面经(重中之重)
  • 83、SpringBoot --- 下载和安装 MSYS2、 Redis
  • 【ARMv8 SIMD和浮点指令编程】NEON 加载指令——如何将数据从内存搬到寄存器(LDxLDxR)?
  • 数据响应式原理