当前位置: 首页 > news >正文

【现代深度学习技术】注意力机制04:Bahdanau注意力

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、模型
    • 二、定义注意力解码器
    • 三、训练
    • 小结


  序列到序列学习(seq2seq)中探讨了机器翻译问题:通过设计一个基于两个循环神经网络的编码器-解码器架构,用于序列到序列学习。具体来说,循环神经网络编码器将长度可变的序列转换为固定形状的上下文变量,然后循环神经网络解码器根据生成的词元和上下文变量按词元生成输出(目标)序列词元。然而,即使并非所有输入(源)词元都对解码某个词元都有用,在每个解码步骤中仍使用编码相同的上下文变量。有什么方法能改变上下文变量呢?

  我们试着找到灵感:在为给定文本序列生成手写的挑战中,Graves设计了一种可微注意力模型,将文本字符与更长的笔迹对齐,其中对齐方式仅向一个方向移动。受学习对齐想法的启发,Bahdanau等人提出了一个没有严格单向对齐限制的可微注意力模型。在预测词元时,如果不是所有输入词元都相关,模型将仅对齐(或参与)输入序列中与当前预测相关的部分。这是通过将上下文变量视为注意力集中的输出来实现的。

一、模型

  下面描述的Bahdanau注意力模型将遵循序列到序列学习(seq2seq)中的相同符号表达。这个新的基于注意力的模型与序列到序列学习(seq2seq)中的模型相同,只不过其中式(3)中的上下文变量 c \mathbf{c} c在任何解码时间步 t ′ t' t都会被 c t ′ \mathbf{c}_{t'} ct替换。假设输入序列中有 T T T个词元,解码时间步 t ′ t' t的上下文变量是注意力集中的输出:
c t ′ = ∑ t = 1 T α ( s t ′ − 1 , h t ) h t (1) \mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t \tag{1} ct=t=1Tα(st1,ht)ht(1) 其中,时间步 t ′ − 1 t' - 1 t1时的解码器隐状态 s t ′ − 1 \mathbf{s}_{t' - 1} st1是查询,编码器隐状态 h t \mathbf{h}_t ht既是键,也是值,注意力权重 α \alpha α是使用加性注意力打分函数计算的。

  与循环神经网络编码器-解码器架构略有不同,图1描述了Bahdanau注意力的架构。

在这里插入图片描述

图1 一个带有Bahdanau注意力的循环神经网络编码器-解码器模型

import torch
from torch import nn
from d2l import torch as d2l

二、定义注意力解码器

  下面看看如何定义Bahdanau注意力,实现循环神经网络编码器-解码器。其实,我们只需重新定义解码器即可。为了更方便地显示学习的注意力权重,以下AttentionDecoder类定义了带有注意力机制解码器的基本接口。

#@save
class AttentionDecoder(d2l.Decoder):"""带有注意力机制解码器的基本接口"""def __init__(self, **kwargs):super(AttentionDecoder, self).__init__(**kwargs)@propertydef attention_weights(self):raise NotImplementedError

  接下来,让我们在接下来的Seq2SeqAttentionDecoder类中实现带有Bahdanau注意力的循环神经网络解码器。首先,初始化解码器的状态,需要下面的输入:

  1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;
  2. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;
  3. 编码器有效长度(排除在注意力池中填充词元)。

  在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。

class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,# num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# 输出X的形状为(num_steps,batch_size,embed_size)X = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:# query的形状为(batch_size,1,num_hiddens)query = torch.unsqueeze(hidden_state[-1], dim=1)# context的形状为(batch_size,1,num_hiddens)context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)# 在特征维度上连结x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)# 将x变形为(1,batch_size,embed_size+num_hiddens)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)# 全连接层变换后,outputs的形状为# (num_steps,batch_size,vocab_size)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights

  接下来,使用包含7个时间步的4个序列输入的小批量测试Bahdanau注意力解码器。

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

在这里插入图片描述

三、训练

  与序列到序列学习(seq2seq)类似,我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器,并对这个模型进行机器翻译训练。由于新增的注意力机制,训练要序列到序列学习(seq2seq)比没有注意力机制的慢得多。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

在这里插入图片描述
在这里插入图片描述

  模型训练后,我们用它将几个英语句子翻译成法语并计算它们的BLEU分数。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq = d2l.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

在这里插入图片描述

attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((1, 1, -1, num_steps))

  训练结束后,下面通过可视化注意力权重会发现,每个查询都会在键值对上分配不同的权重,这说明在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。

# 加上一个包含序列结束词元
d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),xlabel='Key positions', ylabel='Query positions')

在这里插入图片描述

小结

  • 在预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。
  • 在循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步的解码器隐状态视为查询,在所有时间步的编码器隐状态同时视为键和值。

相关文章:

  • 17.【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--ELK
  • 数据集-目标检测系列- 冥想 检测数据集 close_eye>> DataBall
  • 引言:Client Hello 为何是 HTTPS 安全的核心?
  • 【Linux实践系列】:进程间通信:万字详解共享内存实现通信
  • # Java List完全指南:从入门到高阶应用
  • [面试]SoC验证工程师面试常见问题(五)TLM通信篇
  • Vue v-model 深度解析:实现原理与高级用法
  • uniapp-商城-48-后台 分类数据添加修改弹窗bug
  • 【含文档+源码】基于SpringBoot的新能源充电桩管理系统的设计与实现
  • 最小生成树
  • 《C++探幽:模板从初阶到进阶》
  • 【Rust】枚举和模式匹配
  • 计算机大类专业数据结构下半期实验练习题
  • 《用MATLAB玩转游戏开发:从零开始打造你的数字乐园》基础篇(2D图形交互)-俄罗斯方块:用旋转矩阵打造经典
  • python-django项目启动寻找静态页面html顺序
  • C++GO语言微服务之gorm框架操作MySQL
  • 无法更新Google Chrome的解决问题
  • [Pandas]数据处理
  • Dify使用总结
  • JVM对象创建内存分配
  • “拼好假”的年轻人,今年有哪些旅游新玩法?
  • 《中国人民银行业务领域数据安全管理办法》发布,6月30日起施行
  • 习近平同瑞典国王卡尔十六世·古斯塔夫就中瑞建交75周年互致贺电
  • 保证断电、碰撞等事故中车门系统能够开启!隐藏式门把手将迎来强制性国家标准
  • 家庭相册㉙在沪打拼25年,我理解了父母清晨去卖蜜饯的辛苦
  • 上海市委政法委召开会议传达学习总书记重要讲话精神