从代码学习深度学习 - 序列到序列学习 GRU编解码器 PyTorch 版
文章目录
- 前言
- 一、数据加载与预处理
-
- 1.1 读取数据
- 1.2 预处理数据
- 1.3 词元化
- 1.4 词频统计
- 1.5 构建词汇表
- 1.6 截断与填充
- 1.7 转换为张量
- 1.8 创建数据迭代器
- 1.9 整合数据加载
- 二、训练辅助工具
-
- 2.1 时间记录器
- 2.2 累加器
- 2.3 准确率计算
- 2.4 GPU 上的准确率评估
- 2.5 GPU 设备选择
- 2.6 梯度裁剪
- 三、可视化工具
-
- 3.1 SVG 显示设置
- 3.2 坐标轴设置
- 3.3 动态绘图
- 四、网络架构
-
- 4.1 编码器接口
- 4.2 解码器接口
- 4.3 编码器-解码器组合
- 4.4 模型定义
-
- GRU 编码器
- GRU 解码器
- 模型说明
- 五、数据加载与训练
-
- 5.1 数据加载
- 5.2 模型训练
-
- 损失函数定义
- 训练函数
- 训练代码
- 说明
- 六、BLEU 评估
-
- 6.1 BLEU 计算函数
- 6.2 BLEU 概念与取值范围
- 七、预测功能
-
- 7.1 预测代码
- 7.2 测试预测效果
- 总结
前言
Seq2Seq 模型的核心思想是将一个输入序列(例如英语句子)通过编码器(Encoder)转化为一个固定长度的上下文向量,再由解码器(Decoder)根据该向量生成目标序列(例如法语句子)。这种编码-解码的架构最初由 RNN 实现,后来发展出 LSTM 和 Transformer 等变种。在本文中,我们将聚焦于基于 RNN 的经典实现,并通过 PyTorch 代码逐步拆解其关键组件。
本文的代码来源于一个完整的机器翻译任务示例,数据集为英语-法语翻译对。我们将从数据加载与预处理开始,逐步构建编码器和解码器,最后通过 BLEU 分数评估翻译效果。所有代码都经过注释,确保易于理解,同时保留了附件中的完整性。
让我们开始吧!
一、数据加载与预处理
Seq2Seq 模型的第一步是准备数据。我们需要将原始的英语-法语翻译对数据加载到内存中,并对其进行预处理和词元化(tokenization),以便后续输入到模型中。以下是相关代码及其解释:
1.1 读取数据
from collections import Counter # 用于词频统计
import torch # PyTorch 核心库
from torch.utils import data # PyTorch 数据加载工具
import numpy as np # NumPy 用于数组操作
def read_data_nmt():
"""
载入“英语-法语”数据集
返回值:
str: 文件内容的完整字符串
"""
with open('fra.txt', 'r', encoding='utf-8') as f:
return f.read()
read_data_nmt
函数简单地读取名为 fra.txt
的文件,该文件包含英语和法语的翻译对,每行以制表符分隔。它返回整个文件的字符串内容,为后续处理奠定基础。
1.2 预处理数据
def preprocess_nmt(text):
"""
预处理“英语-法语”数据集
参数:
text (str): 输入的原始文本字符串
返回值:
str: 处理后的文本字符串
"""
def no_space(char, prev_char):
"""
判断当前字符是否需要前置空格
"""
return char in set(',.!?') and prev_char != ' '
# 使用空格替换不间断空格(\u202f)和非断行空格(\xa0),并转换为小写
text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
# 在单词和标点符号之间插入空格
out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
for i, char in enumerate(text)]
return ''.join(out)
preprocess_nmt
函数对文本进行标准化处理:
- 将特殊空格字符替换为普通空格,并将所有字符转换为小写。
- 在标点符号(如逗号、句号)前插入空格,便于后续按空格分割词元。这种处理确保标点符号被视为独立的词元,而不是粘附在单词上。
1.3 词元化
def tokenize_nmt(text, num_examples=None):
"""
词元化“英语-法语”数据集
参数:
text (str): 输入的文本字符串,每行包含英语和法语句子,用制表符分隔
num_examples (int, optional): 最大处理样本数,默认值为 None 表示处理全部
返回值:
tuple: 包含两个列表的元组
- source (list): 英语句子词元列表
- target (list): 法语句子词元列表
"""
source, target = [], []
for i, line in enumerate(text.split('\n')):
if num_examples and i > num_examples:
break
parts = line.split('\t')
if len(parts) == 2:
source.append(parts[0].split(' '))
target.append(parts[1].split(' '))
return source, target
tokenize_nmt
函数将预处理后的文本按行分割,并进一步将每行按制表符分为英语和法语部分,然后按空格分割成词元列表。它返回两个列表:source
(英语词元列表)和 target
(法语词元列表)。
1.4 词频统计
def count_corpus(tokens):
"""
统计词元的频率
参数:
tokens: 词元列表,可以是一维或二维列表
返回值:
Counter: Counter 对象,统计每个词元的出现次数
"""
if not tokens:
return Counter()
if isinstance(tokens[0], list):
flattened_tokens = [token for sublist in tokens for token in sublist]
else:
flattened_tokens = tokens
return Counter(flattened_tokens)
count_corpus
函数使用 Counter
类统计词元的出现频率,支持一维和二维列表输入。它是构建词汇表的基础工具。
1.5 构建词汇表
class Vocab:
"""文本词表类,用于管理词元及其索引的映射关系"""
def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
"""初始化词表"""
self.tokens = tokens if tokens is not None else []
self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
counter = self._count_corpus(self.tokens)
self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
self.idx_to_token = ['<unk>'] + self.reserved_tokens
self.token_to_idx = {
token: idx for idx, token in enumerate(self.idx_to_token)}
for token, freq in self._token_freqs:
if freq < min_freq:
break
if token not in self.token_to_idx:
self.idx_to_token.append(token)
self.token_to_idx[token] = len(self.idx_to_token) - 1
@staticmethod
def _count_corpus(tokens):
"""统计词元频率"""
if not tokens:
return Counter()
if isinstance(tokens[0], list):
tokens = [token for sublist in tokens for token in sublist]
return Counter(tokens)
def __len__(self):
return len(self.idx_to_token)
def __getitem__(self, tokens):
if not isinstance(tokens, (list, tuple)):
return self.token_to_idx.get(tokens, self.unk)
return [self[token] for token in tokens]
def to_tokens(self, indices):
if not isinstance(indices, (list, tuple)):
return self.idx_to_token[indices]
return [self.idx_to_token[index] for index in indices]
@property
def unk(self):
return 0
@property
def token_freqs(self):
return self._token_freqs
Vocab
类用于构建词汇表并管理词元与索引之间的映射:
- 初始化时接受词元列表、最小频率阈值和预留特殊词元(如
<pad>
、<bos>
、<eos>
)。 - 内部使用
count_corpus
统计词频,并按频率排序。 - 提供
__getitem__
和to_tokens
方法,分别用于词元到索引和索引到词元的转换。 <unk>
表示未知词元,默认索引为 0。
1.6 截断与填充
def truncate_pad(line, num_steps, padding_token):
"""
截断或填充文本序列
参数:
line (list): 输入的文本序列(词元列表)
num_steps (int): 目标序列长度
padding_token (str): 用于填充的标记
返回值:
list: 截断或填充后的序列,长度为 num_steps
"""
if len(line) > num_steps:
return line[:num_steps]
return line + [padding_token] * (num_steps - len(line))
truncate_pad
函数确保所有序列长度一致:
- 如果序列长度超过
num_steps
,则截断。 - 如果不足,则用
padding_token
(通常是<pad>
)填充。
1.7 转换为张量
def build_array_nmt(lines, vocab, num_steps):
"""
将机器翻译的文本序列转换为小批量
参数:
lines (list): 文本序列列表,每个元素是一个词元列表
vocab (dict): 词汇表,将词元映射为索引
num_steps (int): 目标序列长度
返回值:
tuple: 包含两个元素的元组
- array (torch.Tensor): 转换后的张量,形状为 (样本数, num_steps)
- valid_len (np.ndarray): 每个序列的有效长度,形状为 (样本数,)
"""
lines =