当前位置: 首页 > news >正文

机器翻译:Bahdanau注意力和Luong注意力详解

文章目录

    • 一、 Bahdanau 注意力 (Additive / Concat Attention)
      • 1.1 核心思想
      • 1.2 计算步骤
      • 1.3 特点总结
    • 二、 Luong 注意力 (General / Multiplicative Attention)
      • 2.1 核心思想
      • 2.2 计算步骤
      • 2.3 特点总结
    • 三、 对比与选择
      • 3.1 两种注意力对比
      • 3.2 两种注意力的优缺点
      • 3.3 如何选择?

注意力机制(Attention Mechanism) 是自然语言处理(NLP)和深度学习中的核心技术之一,尤其在序列到序列(Seq2Seq)任务(如机器翻译、文本摘要等)中表现突出。Bahdanau注意力(又称“加性注意力”)和Luong注意力(又称“乘性注意力”)是两种经典的注意力模型,它们在计算方式和应用场景上有所不同。

一、 Bahdanau 注意力 (Additive / Concat Attention)

Bahdanau注意力是最早被提出的注意力机制之一,由Dzmitry Bahdanau等人在2014年的论文《Neural Machine Translation by Jointly Learning to Align and Translate》中提出。

1.1 核心思想

它通过一个可学习的神经网络来计算查询和键之间的“兼容性分数”。这个网络将查询和键拼接起来,然后通过一个带有tanh激活函数的前馈网络进行处理。因此,它也被称为加性注意力拼接注意力

1.2 计算步骤

假设我们有:

  • 输入序列: X=(x1,x2,...,xT)X = (x_1, x_2, ..., x_T)X=(x1,x2,...,xT)
  • 编码器隐藏状态: H=(h1,h2,...,hT)H = (h_1, h_2, ..., h_T)H=(h1,h2,...,hT), 其中 hih_ihi 是输入第 iii 个词后的隐藏状态。
  • 解码器在时间步 ttt 的隐藏状态: sts_tst

Bahdanau注意力的计算流程如下:
步骤一:计算注意力分数
对于解码器的每一个状态 sts_tst,我们需要计算它与编码器每一个隐藏状态 hih_ihi 的相关性分数 etie_{ti}eti。这个分数衡量了输入序列中第 iii 个词对生成输出序列第 ttt 个词有多重要。
公式:
eti=score(st−1,hi)=vaTtanh⁡(Wast−1+Uahi)e_{ti} = \text{score}(s_{t-1}, h_i) = v_a^T \tanh(W_a s_{t-1} + U_a h_i) eti=score(st1,hi)=vaTtanh(Wast1+Uahi)

  • st−1s_{t-1}st1: 解码器在前一个时间步的隐藏状态(作为Query)。
  • hih_ihi: 编码器在第 iii 个时间步的隐藏状态(作为KeyValue)。
  • Wa,UaW_a, U_aWa,Ua: 可学习的权重矩阵。
  • vav_ava: 可学习的权重向量。
  • tanh⁡\tanhtanh: 激活函数。
  • 解释:这个公式将 st−1s_{t-1}st1hih_ihi 线性变换后拼接,再通过tanh函数,最后与向量 vav_ava 做点积,得到一个标量分数 etie_{ti}eti

步骤二:分数归一化(计算注意力权重)
将所有分数通过Softmax函数进行归一化,得到注意力权重 αti\alpha_{ti}αti。这些权重之和为1,表示每个输入词被关注的“概率”。
公式:
αti=exp⁡(eti)∑j=1Texp⁡(etj)\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_{j=1}^{T} \exp(e_{tj})} αti=j=1Texp(etj)exp(eti)

步骤三:计算上下文向量
使用注意力权重对编码器的隐藏状态进行加权求和,得到当前时间步的上下文向量 ctc_tct
公式:
ct=∑i=1Tαtihic_t = \sum_{i=1}^{T} \alpha_{ti} h_i ct=i=1Tαtihi

步骤四:生成输出
将上下文向量 ctc_tct 与解码器的上一隐藏状态 st−1s_{t-1}st1 拼接,然后输入到解码器网络(通常是另一个RNN)中,生成当前时间步的隐藏状态 sts_tst 和输出词 y^t\hat{y}_ty^t
公式:
st=RNN_decoder(concat(ct,st−1))s_t = \text{RNN\_decoder}(\text{concat}(c_t, s_{t-1})) st=RNN_decoder(concat(ct,st1))
y^t=softmax(Wsst+bs)\hat{y}_t = \text{softmax}(W_s s_t + b_s) y^t=softmax(Wsst+bs)

1.3 特点总结

  • Query: 解码器的前一个隐藏状态 st−1s_{t-1}st1
  • Key/Value: 编码器的隐藏状态 hih_ihi
  • 打分函数: 一个前馈神经网络,计算成本相对较高。
  • 对齐: Bahdanau注意力是双向的,解码器在时间步 ttt 可以“看到”自己已经生成的所有词(通过 st−1s_{t-1}st1)以及整个输入序列。

二、 Luong 注意力 (General / Multiplicative Attention)

Luong注意力由Thang Luong等人在2015年的论文《Effective Approaches to Attention-based Neural Machine Translation》中提出。它对Bahdanau注意力进行了简化和改进。也称为 乘性注意力(Multiplicative Attention)

2.1 核心思想

Luong注意力认为,计算查询和键的兼容性分数可以通过更简单的方式实现,比如直接使用向量间的点积。这使得计算更加高效。

2.2 计算步骤

与Bahdanau类似,但其打分函数和使用的Query不同。

步骤一:计算注意力分数
Luong注意力提供了几种不同的打分方法,其中最常用的是GeneralDot

  • General (点积变体):
    eti=stTWhhie_{ti} = s_t^T W_h h_i eti=stTWhhi
    • 与Bahdanau最大的不同是,这里使用的是当前解码器隐藏状态 sts_tst 作为Query,而不是前一个状态 st−1s_{t-1}st1
    • WhW_hWh 是一个可学习的权重矩阵,用于对键 hih_ihi 进行线性变换。
  • Dot (纯点积):
    eti=stThie_{ti} = s_t^T h_i eti=stThi
    • 这是General的特例,省略了权重矩阵 WhW_hWh,计算速度最快。
  • Concat (与Bahdanau类似):
    eti=vaTtanh⁡(Wast+Uahi)e_{ti} = v_a^T \tanh(W_a s_t + U_a h_i) eti=vaTtanh(Wast+Uahi)
    • 形式上与Bahdanau类似,但同样使用当前状态 sts_tst 作为Query。

步骤二:分数归一化
与Bahdanau完全相同,使用Softmax函数。
αti=exp⁡(eti)∑j=1Texp⁡(etj)\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_{j=1}^{T} \exp(e_{tj})} αti=j=1Texp(etj)exp(eti)

步骤三:计算上下文向量
与Bahdanau完全相同。
ct=∑i=1Tαtihic_t = \sum_{i=1}^{T} \alpha_{ti} h_i ct=i=1Tαtihi

步骤四:生成输出
Luong注意力在生成输出时有两种不同的结构:

  • Concat:
    st=RNN_decoder(concat(ct,st))s_t = \text{RNN\_decoder}(\text{concat}(c_t, s_t)) st=RNN_decoder(concat(ct,st))
    • 这里将上下文向量 ctc_tct当前生成的隐藏状态 sts_tst 拼接。
  • General:
    y^t=softmax(Wsconcat(st,ct))\hat{y}_t = \text{softmax}(W_s \text{concat}(s_t, c_t)) y^t=softmax(Wsconcat(st,ct))
    • 直接将 sts_tstctc_tct 拼接后通过一个线性层和Softmax生成输出。

2.3 特点总结

  • Query: 解码器的当前隐藏状态 sts_tst
  • Key/Value: 编码器的隐藏状态 hih_ihi
  • 打分函数: 主要使用点积或其变体,计算效率高。
  • 对齐: 由于使用当前状态 sts_tst 作为Query,Luong注意力更偏向于单向对齐,即输出词只关注输入序列中它之前或对应的部分,这更符合翻译任务中“从左到右”的特性。

三、 对比与选择

3.1 两种注意力对比

特性Bahdanau 注意力Luong 注意力
提出时间2014年2015年
核心思想通过FFN计算Query-Key兼容性通过点积或其变体计算兼容性
Query解码器前一个隐藏状态 st−1s_{t-1}st1解码器当前隐藏状态 sts_tst
打分函数score(s, h) = v_a^T * tanh(W_a*s + U_a*h)score(s, h) = s^T * W_h * h (General)
score(s, h) = s^T * h (Dot)
计算效率较低(涉及矩阵乘法和FFN)较高(主要是矩阵乘法和点积)
对齐方向双向 (可以看到自己之前生成的所有词)单向 (更符合从左到右的生成过程)
参数量较多 (Wa,Ua,vaW_a, U_a, v_aWa,Ua,va)较少 (只有General变体有WhW_hWh)
影响力开创性的工作,首次引入注意力机制更简洁高效,成为后续许多模型的基础

3.2 两种注意力的优缺点

1、Bahdanau 注意力

  • 优点
    • 对齐灵活:双向对齐特性使其在某些复杂任务中可能表现更好。
    • 兼容性强:即使Query和Key的维度不同,也能通过权重矩阵进行匹配。
  • 缺点
    • 计算量大:FFN结构增加了计算开销,训练和推理速度较慢。
    • 参数多:更多的参数意味着需要更多的数据来训练,否则容易过拟合。

2、Luong 注意力

  • 优点
    • 计算高效:点积操作非常快,显著提升了模型性能。
    • 参数精简:模型更简单,泛化能力可能更强。
    • 对齐更合理:单向对齐更符合语言生成的直觉。
  • 缺点
    • 维度依赖:Dot变体要求数据的维度必须匹配。
    • 梯度问题:当维度很高时,点积结果的方差会很大,可能导致Softmax函数梯度消失或爆炸。为此,Luong也提出了Scaled General Attention,即在点积后除以一个缩放因子 dk\sqrt{d_k}dkdkd_kdk是键的维度),这与Transformer中的做法一致。

3.3 如何选择?

在现代深度学习实践中,Luong注意力,特别是其Scaled General变体,通常是首选

  1. 性能与效率的权衡:在绝大多数NLP任务中,Luong注意力的表现优于或至少持平于Bahdanau注意力,同时计算效率更高。
  2. 简洁性:公式更简洁,实现起来更直接,参数更少,降低了过拟合的风险。
  3. 影响力:Luong注意力的点积思想,特别是经过缩放后的版本,是Transformer模型中多头注意力机制的基石。理解了Luong注意力,就更容易理解更现代的Transformer模型。

那何时会考虑Bahdanau?: 在一些非常特殊的场景下,比如输入和隐藏状态的维度差异巨大,或者你发现点积方式确实导致了训练不稳定时,Bahdanau的加性方式可能是一个备选方案。但在绝大多数情况下,Luong的Scaled General注意力是更优、更现代的选择。

总而言之,Bahdanau注意力为我们打开了动态对齐的大门,而Luong注意力则通过更高效的数学工具,让这条路走得更远、更快。理解它们的区别,对于深入掌握现代序列模型至关重要。

http://www.dtcms.com/a/326113.html

相关文章:

  • 【浮点数存储】double类型注意点
  • 理解LangChain — Part 3:链式工作流与输出解析器
  • Notepad--:国产跨平台文本编辑器,Notepad++ 的理想替代方案
  • 写一篇Ping32和IP-Guard的对比,重点突出Ping32
  • 循环控制:break和continue用法
  • 鸿蒙flutter项目接入极光推送
  • Java项目基本流程(三)
  • Orange的运维学习日记--38.MariaDB详解与服务部署
  • linux安装和使用git
  • Elasticsearch 官方 Node.js 从零到生产
  • docker部署elasticsearch-8.11.1
  • 网络的基本概念、通信原理以及网络安全问题
  • YOLOv6深度解析:实时目标检测的新突破
  • 时序数据库为什么选IoTDB?
  • 爬虫与数据分析结合案例
  • STM32 HAL驱动MPU6050传感器
  • p6spy和p6spy-spring-boot-starter的SpringBoot3集成配置
  • 高性能Web服务器
  • java基础概念(二)----变量(附练习题)
  • Go 语言三大核心数据结构深度解析:数组、切片(Slice)与映射(Map)
  • Unity插件DOTween使用
  • 【GPT入门】第45课 无梯子,linux/win下载huggingface模型方法
  • 如何避免团队文件同步过程中版本信息的丢失?
  • GAI 与 Tesla 机器人的具体联动机制
  • 变频器与伺服系统的工作原理,干扰来源及治理方式
  • 软件测试关于搜索方面的测试用例
  • [AI 生成] kafka 面试题
  • 是否有必要使用 Oracle 向量数据库?
  • 【图像处理基石】UE输出渲染视频,有哪些画质相关的维度和标准可以参考?
  • OmniHuman:字节推出的AI项目,支持单张照片生成逼真全身动态视频