当前位置: 首页 > news >正文

使用注意力机制的seq2seq

使用注意力机制的seq2seq

Motivation

机器翻译中,每个生成的词可能相关于源句子中不同的词
在这里插入图片描述
但是,seq2seq 模型中不能对此直接建模

加入注意力

在这里插入图片描述

编码器对每次词的输出作为key和value(它们是同样的)
解码器RNN对上一个词的输出是query
注意力的输出和下一个词的词嵌入合并进入RNN

总结

Seq2seq中通过隐状态在编码器和解码器中传递信息
注意力机制可以根据解码器RNN的输出来匹配到合适的编码器RNN的输出来更有效的传递信息

代码实现

Bahdanau 注意力

首先导入必要的环境

import torch
from torch import nn
from d2l import torch as d2l

带有注意力机制的解码器基本接口

# @save
class AttentionDecoder(d2l.Decoder):"""带有注意力机制解码器的基本接口"""def __init__(self, **kwargs):# 调用父类 d2l.Decoder 的初始化方法# **kwargs 允许传递任意数量的关键字参数super(AttentionDecoder, self).__init__(**kwargs)@propertydef attention_weights(self):  # 有没有都不重要,主要是用来画图用的# 定义一个抽象属性,用于获取注意力权重# 子类需要实现此方法以返回注意力权重raise NotImplementedError

在接下来的Seq2SeqAttentionDecoder类中实现带有Bahdanau注意力的循环神经网络解码器。首先,初始化解码器的状态,需要下面的输入:

  1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;
  2. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;
  3. 编码器有效长度(排除在注意力池中填充词元)。
    在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。
class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):# 调用父类 AttentionDecoder 的初始化方法super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)# 定义加性注意力机制self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens,num_hiddens, dropout)# 定义嵌入层,将词元索引映射到指定维度的嵌入向量self.embedding = nn.Embedding(vocab_size, embed_size)# 定义循环神经网络(GRU),输入维度为嵌入维度+注意力上下文向量维度self.rnn = nn.GRU(embed_size + num_hiddens,num_hiddens, num_layers,dropout=dropout)# 定义全连接层,用于将GRU的输出映射到词汇表大小的向量self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# 初始化解码器的状态# enc_outputs 是编码器的输出,包括所有时间步的隐状态和最终层隐状态# enc_valid_lens 是编码器的有效长度,用于掩码填充词元# outputs 的形状为 (batch_size, num_steps, num_hiddens)# hidden_state 的形状为 (num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputs# 将 outputs 的时间步维度和批量维度交换,方便后续操作return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# 前向传播# X 是输入序列,形状为 (batch_size, num_steps)# state 是解码器的状态,包括编码器输出、隐状态和有效长度enc_outputs, hidden_state, enc_valid_lens = state# 将输入序列通过嵌入层,输出形状为 (num_steps, batch_size, embed_size)X = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []  # 存储输出和注意力权重for x in X:  # 遍历每个时间步的输入# query 是解码器当前时间步的隐状态,形状为 (batch_size, 1, num_hiddens)query = torch.unsqueeze(hidden_state[-1], dim=1)# 计算注意力上下文向量,形状为 (batch_size, 1, num_hiddens)context = self.attention(query, enc_outputs, enc_outputs,enc_valid_lens)# 将当前时间步的输入和上下文向量在特征维度上拼接# 拼接后的形状为 (batch_size, 1, embed_size + num_hiddens)x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)# 将拼接后的输入传入GRU,输出形状为 (1, batch_size, num_hiddens)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)  # 将当前时间步的输出添加到列表中# 保存当前时间步的注意力权重self._attention_weights.append(self.attention.attention_weights)# 将所有时间步的输出拼接,并通过全连接层映射到词汇表大小# 输出形状为 (num_steps, batch_size, vocab_size)outputs = self.dense(torch.cat(outputs, dim=0))# 调整输出形状为 (batch_size, num_steps, vocab_size)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,enc_valid_lens]@propertydef attention_weights(self):# 返回注意力权重,用于可视化或其他分析return self._attention_weights

测试 Bahdanau 注意力解码器

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape# (torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

训练

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

在这里插入图片描述
将几个英语句子翻译成法语并计算其 BLEU 分数

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq = d2l.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)print(f'{eng} => {translation}, ',f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

在这里插入图片描述

attention_weights = torch.cat(
[step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((1, 1, -1, num_steps))

训练结束后,下面通过可视化注意力权重。发现,每个查询都会在键值对上分配不同的权重,这说明在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。

# 加上一个包含序列结束词元
d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),xlabel='Key positions', ylabel='Query positions')

在这里插入图片描述

小结

  • 在预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。
  • 在循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步的解码器隐状态视为查询,在所有时间步的编码器隐状态同时视为键和值。

QA 思考

Q1:attention在搜索的时候是在当前句子搜索,还是所有的文本搜索?
A1:在当前的句子进行搜索。

Q2:─般都是在decoder加入注意力吗,不可以在encoder加入吗?
A2:可以在 Encoder中加入注意力,bert就是只有Encoder,这样就只在Encoder中加入注意力了。

http://www.dtcms.com/a/172598.html

相关文章:

  • 【SaaS多租架构】数据隔离与性能平衡
  • 【2025最新】AI绘画终极提示词库|MidjourneyStable Diffusion通用公式大全
  • Cisco Packet Tracer 选项卡的使用
  • 【神经网络与深度学习】普通自编码器和变分自编码器的区别
  • JavaScript 实现输入框的撤销功能
  • Spring Boot多模块划分设计
  • # 机器学习实操 第二部分 神经网络和深度学习 第12章 自定义模型和训练循环
  • 15届蓝桥杯国赛 立定跳远
  • 两次解析格式化字符串 + 使用SQLAlchemy的relationship执行任意命令 -- link-shortener b01lersCTF 2025
  • 【数据治理】数据架构设计
  • 时间同步服务核心知识笔记:原理、配置与故障排除
  • 详解RabbitMQ工作模式之发布订阅模式
  • Multi Agents Collaboration OS:专属多智能体构建—基于业务场景流程构建专属多智能体
  • 网络安全自动化:精准把握自动化边界,筑牢企业安全防
  • Redis的过期设置和策略
  • Java后端程序员学习前端之CSS
  • 深入理解 Redis 的主从、哨兵与集群架构
  • 基于EFISH-SCB-RK3576工控机/SAIL-RK3576核心板的网络安全防火墙技术方案‌(国产化替代J1900的全栈技术解析)
  • DeepSeek-Prover-V2,DeepSeek推出的开源数学推理大模型
  • 【Leetcode 每日一题 - 补卡】1128. 等价多米诺骨牌对的数量
  • 旋转图像(中等)
  • 一套SaaS ERP系统源码,ERP成品系统源代码,基于SpringBoot框架
  • 1.CFD 计算过程概述:有限元仿真与CFD介绍
  • Sim Studio 是一个开源的代理工作流程构建器。Sim Studio 的界面是一种轻量级、直观的方式,可快速构建和部署LLMs与您最喜欢的工具连接
  • Android学习总结之GetX库篇(优缺点)
  • 网络延时 第四次CCF-CSP计算机软件能力认证
  • 10.施工测量
  • 基于SpringBoot + Vue 的火车票订票系统
  • opencv+opencv_contrib+cuda和VS2022编译
  • JavaScript学习教程,从入门到精通, jQuery的高亮显示图像、留言板、元素内容操作知识点及案例代码(36)