当前位置: 首页 > news >正文

带有 Attention 机制的 Encoder-Decoder 架构模型分析

前言

模型是一个带有 注意力机制(Attention) 的编码器-解码器(Encoder-Decoder)架构。

  • 编码器 (Encoder): 负责读取并理解输入的新闻原文。它将原文压缩成一系列的“记忆”向量(encoder_outputs)和一个最终的“思想总结”向量(encoder_hidden)。
  • 解码器 (Decoder): 负责逐词生成摘要。在生成每个词时,它会参考编码器的“思想总结”,并利用注意力机制来“查看”原文中最相关的部分。
  • 注意力机制 (Attention): 这是连接编码器和解码器的桥梁。它允许解码器在生成摘要的每一步,都能动态地决定应该关注原文中的哪些词语,这对于长文本的摘要任务至关重要。

示例数据

为了方便理解,我们用一个极简的例子来贯穿整个流程,示例的目的是将新闻原文提取出摘要。

  • 源序列 (新闻原文 src): "The cat sat on the mat"
  • 目标序列 (摘要 tgt): "<sos> A cat rested <eos>" (这里的 <sos><eos> 分别是序列开始和结束的特殊标记)

假设经过分词和词典映射后,它们变成了数字ID:

  • src_tensor: [2, 3, 4, 5, 6, 7] (长度 src_len = 6)
  • tgt_tensor: [0, 8, 3, 9, 1] (长度 tgt_len = 5)

我们假设 batch_size = 1,即一次只处理一个样本。


1. 编码器 (Encoder)

目标: 阅读源序列 "The cat sat on the mat",并生成其语义表示。

  1. 输入 (src_tensor):

    • 形状: [src_len, batch_size] -> [6, 1]
  2. 词嵌入 (self.embedding):

    • 作用: 将每个词的数字ID转换为一个密集、连续的向量(词嵌入)。embedding_dim 参数定义了每个词向量的维度。
    • 过程: [2, 3, 4, 5, 6, 7] -> 一系列向量。
    • 形状变化: [6, 1] -> [6, 1, embedding_dim]
  3. 双向GRU (self.rnn):

    • 作用: 按顺序处理词向量序列,捕捉上下文信息。hidden_dim 是GRU隐藏状态的维度。num_layers 是GRU的层数。因为是双向的,它会从前到后和从后到前各扫描一遍序列。
    • 输出:
      • outputs: 包含了每个时间步(即每个词)的前向和后向隐藏状态的拼接。这可以看作是编码器对整个原文的“记忆”。
        • 形状: [src_len, batch_size, hidden_dim * 2] -> [6, 1, hidden_dim * 2]*2 是因为拼接了前向和后向的隐藏状态。
      • hidden: 最后时间步的隐藏状态。
        • 形状: [num_layers * 2, batch_size, hidden_dim] -> [2, 1, hidden_dim] (假设num_layers=1)。第0维是前向的最终状态,第1维是后向的最终状态。
  4. 隐藏状态处理:

    • 代码: hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1).unsqueeze(0)
    • 作用: 将双向GRU的两个最终隐藏状态(最后一个前向状态和最后一个后向状态)拼接起来,形成一个单一的上下文向量,作为对整个源句子的“思想总结”。
    • 过程:
      • hidden[-2,:,:]: 取出最后一个前向隐藏状态。形状: [1, hidden_dim]
      • hidden[-1,:,:]: 取出最后一个后向隐藏状态。形状: [1, hidden_dim]
      • torch.cat(...): 沿着维度1(特征维度)拼接。形状变为 [1, hidden_dim * 2]
      • .unsqueeze(0): 在最前面增加一个维度,以匹配解码器期望的输入格式。最终形状: [1, 1, hidden_dim * 2]

编码器完成! 我们得到了两个关键输出:

  • encoder_outputs: 形状 [6, 1, hidden_dim * 2] (对原文每个词的详细记忆)。
  • encoder_hidden: 形状 [1, 1, hidden_dim * 2] (对原文的整体总结)。

2. Seq2Seq 模块 (连接编码器和解码器)

在进入解码器之前,Seq2Seq模块会做一些准备工作。

  1. 获取编码器输出:

    • encoder_outputs, encoder_hidden = self.encoder(src)
  2. 转换隐藏状态:

    • 代码: decoder_hidden = torch.tanh(self.encoder_to_decoder_hidden(encoder_hidden.squeeze(0))).unsqueeze(0)
    • 作用: 编码器的最终隐藏状态维度是 hidden_dim * 2,但解码器的GRU是单向的,它期望的隐藏状态维度是 hidden_dim。这个全连接层 (encoder_to_decoder_hidden) 就是一个适配器,将编码器的总结向量转换为解码器可以理解的初始状态。
    • 过程:
      • encoder_hidden.squeeze(0): 形状从 [1, 1, hidden_dim * 2] -> [1, hidden_dim * 2]
      • self.encoder_to_decoder_hidden(...): 线性变换。形状 [1, hidden_dim * 2] -> [1, hidden_dim]
      • torch.tanh(...): 应用激活函数,增加非线性。
      • .unsqueeze(0): 加回维度,以匹配解码器GRU的输入。形状 [1, hidden_dim] -> [1, 1, hidden_dim]

准备完成! 我们得到了解码器的初始隐藏状态 decoder_hidden


3. 解码器 (Decoder) 与 注意力 (Attention)

目标: 生成摘要 <sos> A cat rested <eos>。解码器是自回归的,即生成一个词后,再把这个词作为输入来生成下一个词。

这个过程在一个循环中进行,我们以生成第一个词 "A" 为例。

  1. 初始输入:

    • decoder_input: <sos> 标记的ID。形状 [1] (因为batch_size=1)。
    • decoder_hidden: 上一步准备好的初始隐藏状态。形状 [1, 1, hidden_dim]
    • encoder_outputs: 编码器的记忆。形状 [6, 1, hidden_dim * 2]
  2. 词嵌入 (self.embedding):

    • <sos> ID 被转换为词向量。
    • 形状变化: [1] -> [1, 1, embedding_dim]
  3. 计算注意力权重 (self.attention): 这是核心步骤

    • 目标: 决定在生成当前词时,应该关注源序列 ("The cat sat on the mat")中的哪个词。
    • 过程:
      a. decoder_hidden (形状 [1, 1, hidden_dim]) 和 encoder_outputs (形状 [6, 1, hidden_dim * 2]) 被送入注意力模块。
      b. decoder_hidden 被复制6次 (src_len次),与 encoder_outputs 的每个时间步对齐。
      c. 将复制后的 decoder_hiddenencoder_outputs 拼接,然后通过一个线性层 (self.attn) 计算一个“对齐分数”或“能量值”。这衡量了解码器当前状态与原文每个词的匹配程度。
      d. 这些分数通过softmax函数转换为概率,即注意力权重 attention_weights
    • 输出 (attention_weights):
      • 形状: [batch_size, src_len] -> [1, 6]
      • 例子: 可能得到 [0.1, 0.6, 0.1, 0.05, 0.1, 0.05]。这意味着在生成第一个词时,模型认为原文中的 "cat" (权重0.6) 是最重要的参考信息。
  4. 计算上下文向量 (context):

    • 作用: 根据注意力权重,对编码器的输出 encoder_outputs 进行加权求和。
    • 过程: context = attention_weights.bmm(encoder_outputs) (简化表示)。这本质上是一个加权平均。
    • 输出 (context):
      • 形状: [batch_size, hidden_dim * 2] -> [1, hidden_dim * 2]
      • 这个向量融合了原文中此刻最重要的信息。由于 "cat" 的权重最高,context 向量会很像 encoder_outputs 中代表 "cat" 的那个向量。
  5. 准备GRU输入 (rnn_input):

    • 代码: rnn_input = torch.cat((embedded.squeeze(0), context), dim=1).unsqueeze(0)
    • 作用: 将当前输入词的信息 (embedded) 和从原文中提取的上下文信息 (context) 结合起来,作为解码器GRU的输入。
    • 过程:
      • embedded.squeeze(0): 形状 [1, 1, embedding_dim] -> [1, embedding_dim]
      • torch.cat(...): 拼接后形状为 [1, embedding_dim + hidden_dim * 2]
      • .unsqueeze(0): 加回时间步维度,形状 [1, 1, embedding_dim + hidden_dim * 2]
  6. 解码器GRU (self.rnn):

    • 作用: 接收 rnn_input 和上一个隐藏状态 decoder_hidden,更新状态并产生一个输出。
    • 输出:
      • output: GRU的输出。形状 [1, 1, hidden_dim]
      • hidden: 新的隐藏状态,将用于生成下一个词。形状 [1, 1, hidden_dim]
  7. 最终预测 (self.out):

    • 作用: 将GRU的输出 output 通过一个全连接层,映射到整个词汇表的大小 (vocab_size)。
    • 输出 (prediction):
      • 形状: [batch_size, vocab_size] -> [1, vocab_size]
      • 这是一个分数向量,每个分数对应词汇表中的一个词。分数最高的那个词就是模型当前的预测结果。假设ID为8的词 "A" 分数最高。

4. 循环与 Teacher Forcing

  • 存储预测: 上一步得到的 prediction (形状[1, vocab_size]) 被存入 outputs 张量。
  • 决定下一个输入:
    • Teacher Forcing: 在训练时,我们有一定概率(由teacher_forcing_ratio控制)直接使用真实的目标词tgt[t],即 "cat")作为下一个时间步的输入。这有助于模型更快地学习。
    • 无 Teacher Forcing (推理时): 我们使用模型自己刚刚预测出的词top1,即 "A")作为下一个时间步的输入。
  • 循环: 这个过程(从第3步开始)会一直重复,直到生成摘要中的所有词。在每一步,新的decoder_hidden和新的decoder_input都会被用来生成下一个词,并且注意力机制会重新计算,关注原文中不同的部分。例如,在生成 “rested” 时,注意力权重可能会更高地集中在原文的 “sat on the mat” 上。

总结

整个流程就像一个模拟人类翻译或总结的过程:

  1. 编码器完整地读一遍原文,形成一个整体印象和对每个词的记忆。
  2. 解码器开始写摘要。在写第一个词时,它会回想原文的整体印象。
  3. 注意力机制帮助解码器在写每个词时,回头去“看”原文中最相关的几个词,然后把这些信息和它正要写的词结合起来。
  4. 解码器写完一个词,更新自己的“思路”(隐藏状态),然后继续写下一个词,周而复始,直到写完整个摘要。

文章转载自:

http://zZTu25eY.bgzgq.cn
http://5Qn17Djc.bgzgq.cn
http://i6Bs5BSM.bgzgq.cn
http://sbYkDeNw.bgzgq.cn
http://RCZ38MT4.bgzgq.cn
http://3nly9owE.bgzgq.cn
http://dBBNzaKg.bgzgq.cn
http://oWveiUcm.bgzgq.cn
http://VH2iP9Pd.bgzgq.cn
http://DSaN2ktI.bgzgq.cn
http://KSR79nUp.bgzgq.cn
http://VTpAxFcv.bgzgq.cn
http://1wVu2C1Z.bgzgq.cn
http://y0I6vMGE.bgzgq.cn
http://FAl1m0NT.bgzgq.cn
http://PZCKmQPp.bgzgq.cn
http://Iueiom97.bgzgq.cn
http://TdBOactz.bgzgq.cn
http://dg3OXocm.bgzgq.cn
http://L6D8jiEp.bgzgq.cn
http://7mBnL0kp.bgzgq.cn
http://bkKAx2R5.bgzgq.cn
http://DvzKs2aU.bgzgq.cn
http://mwR1dA7E.bgzgq.cn
http://OBRJG2Tj.bgzgq.cn
http://dIFDcTTC.bgzgq.cn
http://jp6sbl8t.bgzgq.cn
http://4ccewvy4.bgzgq.cn
http://HL0PKdQ2.bgzgq.cn
http://mmrOANrx.bgzgq.cn
http://www.dtcms.com/a/374382.html

相关文章:

  • 利用易语言编写,逻辑为按照数字越大抽取率越前
  • leetcode 219 存在重复元素II
  • Redis(缓存)
  • ARP 协议
  • 169.在Vue3中使用OpenLayers + D3实现地图区块呈现不同颜色的效果
  • 【C++】递归与迭代:两种编程范式的对比与实践
  • 【Java】设计模式——单例、工厂、代理模式
  • C++ ——一文读懂:Valgrind 检测内存泄漏
  • 代码随想录算法训练营第三十一天 | 合并区间、单调递增的数字
  • Redis核心通用命令深度解析:结合C++ redis-plus-plus 实战指南
  • 三防手机的三防是指什么?推荐一款实用机型
  • 请求库-axios
  • Python 2025:AI工程化与智能代理开发实战
  • 聚铭网络入选数世咨询《中国数字安全价值图谱》“日志审计”推荐企业
  • 【56页PPT】数字化智能工厂总体设计SRMWCSWMSMESEMS系统建设方案(附下载方式)
  • 高性价比云手机挑选指南
  • 分布式IP代理集群架构与智能调度系统
  • 构造函数和析构函数中的多态陷阱:C++的隐秘角落
  • 使用 Altair RapidMiner 将机器学习引入您的 Mendix 应用程序
  • 从IFA再出发:中国制造与海信三筒洗衣机的“答案”
  • SQLite 数据库核心知识与 C 语言编程
  • unity中通过拖拽,自定义scroll view中子物体顺序
  • 最长上升子序列的长度最短连续字段和(动态规划)
  • 2025年最新AI大模型原理和应用面试题
  • Docker 轻量级管理Portainer
  • Aider AI Coding 智能上下文管理深度分析
  • 【Vue3】02-Vue3工程目录分析
  • JavaSE 集合从入门到面试:全面解析与实战指南
  • 《AI大模型应知应会100篇》第70篇:大模型驱动的自动化工具开发(国产化实战版)
  • 电机控制(四)-级联PID控制器与参数整定(MATLABSimulink)