当前位置: 首页 > news >正文

深度学习处理文本(13)

我们使用基于GRU的编码器和解码器来在Keras中实现这一方法。选择GRU而不是LSTM,会让事情变得简单一些,因为GRU只有一个状态向量,而LSTM有多个状态向量。首先是编码器,如代码清单11-28所示。

代码清单11-28 基于GRU的编码器

from tensorflow import keras
from tensorflow.keras import layers

embed_dim = 256
latent_dim = 1024

source = keras.Input(shape=(None,), dtype="int64", name="english")----不要忘记掩码,它对这种方法来说很重要
x = layers.Embedding(vocab_size, embed_dim, mask_zero=True)(source)----这是英语源句子。指定输入名称,我们就可以用输入组成的字典来拟合模型
encoded_source = layers.Bidirectional(
    layers.GRU(latent_dim), merge_mode="sum")(x)----编码后的源句子即为双向GRU的最后一个输出

接下来,我们来添加解码器——一个简单的GRU层,其初始状态为编码后的源句子。我们再添加一个Dense层,为每个输出时间步生成一个在西班牙语词表上的概率分布,如代码清单11-29所示。

代码清单11-29 基于GRU的解码器与端到端模型

past_target = keras.Input(shape=(None,), dtype="int64", name="spanish")----这是西班牙语目标句子
x = layers.Embedding(vocab_size, embed_dim, mask_zero=True)(past_target)----不要忘记使用掩码
decoder_gru = layers.GRU(latent_dim, return_sequences=True)
x = decoder_gru(x, initial_state=encoded_source)----编码后的源句子作为解码器GRU的初始状态
x = layers.Dropout(0.5)(x)
target_next_step = layers.Dense(vocab_size, activation="softmax")(x)----预测下一个词元
seq2seq_rnn = keras.Model([source, past_target], target_next_step)----端到端模型:将源句子和目标句子映射为偏移一个时间步的目标句子

训练过程中,解码器接收整个目标序列作为输入,但由于RNN逐步处理的性质,它将仅通过查看输入中第0~N个词元来预测输出的第N个词元(对应于句子的下一个词元,因为输出需要偏移一个时间步)​。这意味着我们只能使用过去的信息来预测未来——我们也应该这样做,否则就是在作弊,这样生成模型在推断过程中将不会生效。下面开始训练模型,如代码清单11-30所示。

代码清单11-30 训练序列到序列循环模型

seq2seq_rnn.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"])
seq2seq_rnn.fit(train_ds, epochs=15, validation_data=val_ds)

我们选择精度来粗略监控训练过程中的验证集性能。模型精度为64%,也就是说,平均而言,该模型在64%的时间里正确预测了西班牙语句子的下一个单词。然而在实践中,对于机器翻译模型而言,下一个词元精度并不是一个很好的指标,因为它会假设:在预测第N+1个词元时,已经知道了从0到N的正确的目标词元。实际上,在推断过程中,你需要从头开始生成目标句子,不能认为前面生成的词元都是100%正确的。现实世界中的机器翻译系统可能会使用“BLEU分数”来评估模型。这个指标会评估整个生成序列,并且看起来与人类对翻译质量的评估密切相关。最后,我们使用模型进行推断,如代码清单11-31所示。我们从测试集中挑选几个句子,并观察模型如何翻译它们。我们首先将种子词元"[start]“与编码后的英文源句子一起输入解码器模型。我们得到下一个词元的预测结果,并不断将其重新输入解码器,每次迭代都采样一个新的目标词元,直到遇到”[end]"或达到句子的最大长度。

代码清单11-31 利用RNN编码器和RNN解码器来翻译新句子

import numpy as np
spa_vocab = target_vectorization.get_vocabulary()---- (本行及以下1)准备一个字典,将词元索引预测值映射为字符串词元
spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))
max_decoded_sentence_length = 20

def decode_sequence(input_sentence):
    tokenized_input_sentence = source_vectorization([input_sentence])
    decoded_sentence = "[start]"----种子词元
    for i in range(max_decoded_sentence_length):
        tokenized_target_sentence = target_vectorization([decoded_sentence])
        next_token_predictions = seq2seq_rnn.predict(---- (本行及以下2)对下一个词元进行采样
            [tokenized_input_sentence, tokenized_target_sentence])
        sampled_token_index = np.argmax(next_token_predictions[0, i, :])
        sampled_token = spa_index_lookup[sampled_token_index]---- (本行及以下1)将下一个词元预测值转换为字符串,并添加到生成的句子中
        decoded_sentence += " " + sampled_token
        if sampled_token == "[end]":----退出条件:达到最大长度或遇到停止词元
            break
    return decoded_sentence

test_eng_texts = [pair[0] for pair in test_pairs]
for _ in range(20):
    input_sentence = random.choice(test_eng_texts)
    print("-")
    print(input_sentence)
    print(decode_sequence(input_sentence))

请注意,这种推断方法虽然非常简单,但效率很低,因为每次采样新词时,都需要重新处理整个源句子和生成的整个目标句子。在实际应用中,你会将编码器和解码器分成两个独立的模型,在每次采样词元时,解码器只运行一步,并重新使用之前的内部状态。翻译结果如代码清单11-32所示。对于一个玩具模型而言,这个模型的效果相当好,尽管它仍然会犯许多低级错误。

代码清单11-32 循环翻译模型的一些结果示例

Who is in this room?
[start] quién está en esta habitación [end]
-
That doesn't sound too dangerous.
[start] eso no es muy difícil [end]
-
No one will stop me.
[start] nadie me va a hacer [end]
-
Tom is friendly.
[start] tom es un buen [UNK] [end]

有很多方法可以改进这个玩具模型。编码器和解码器可以使用多个循环层堆叠(请注意,对于解码器来说,这会使状态管理变得更加复杂)​,我们还可以使用LSTM代替GRU,诸如此类。然而,除了这些调整,RNN序列到序列学习方法还受到一些根本性的限制。源序列表示必须完整保存在编码器状态向量中,这极大地限制了待翻译句子的长度和复杂度。这有点像一个人完全凭记忆翻译一句话,并且在翻译时只能看一次源句子。RNN很难处理非常长的序列,因为它会逐渐忘记过去。等到处理序列中的第100个词元时,模型关于序列开始的信息已经几乎没有了。这意味着基于RNN的模型无法保存长期上下文,而这对于翻译长文档而言至关重要。正是由于这些限制,机器学习领域才采用Transformer架构来解决序列到序列问题。我们来看一下。

http://www.dtcms.com/a/113062.html

相关文章:

  • SSL证书过期会有什么影响
  • 奈氏准则和 香农定理
  • netty中的ServerBootstrap详解
  • thinkphp8.0上传图片到阿里云对象存储(oss)
  • 2025全新开源双端系统源码:获取通讯录、相册、短信、定位及已装应用信息
  • 程序环境和预处理
  • 第二章日志分析-redis应急响应笔记
  • 贪心算法的使用条件
  • 通义灵码:引领 AI 驱动的编程革命
  • 趣味逆商测试:了解你的逆境应对能力
  • 系统思考:思考的快与慢
  • 二叉树的前序中序后序遍历
  • DeFi漏洞利用与安全防护
  • Oracle数据库数据编程SQL<8 文本编辑器Notepad++和UltraEdit(UE)对比>
  • Python 变量
  • JVM虚拟机篇(二):深入剖析Java与元空间(MetaSpace)
  • 31信号和槽_信号和槽存在的意义(1)
  • bge-m3+deepseek-v2-16b+离线语音能力实现离线文档向量化问答语音版
  • AI绘画中的LoRa是什么?
  • Maven 远程仓库推送方法
  • Redis内存碎片详解!
  • samba共享配置
  • CodeCraft-22 and Codeforces Round 795 (Div. 2) D
  • 【网络安全论文】筑牢局域网安全防线:策略、技术与实战分析
  • Nginx介绍及使用
  • 美团滑块 分析
  • 【问题记录】C语言一个程序bug定位记录?(定义指针数组忘记[])
  • Pgvector的安装
  • 为什么AI需要向量数据库?
  • Redis数据结构之Hash