当前位置: 首页 > news >正文

从代码学习深度学习 - 序列到序列学习 GRU编解码器 PyTorch 版

文章目录

  • 前言
  • 一、数据加载与预处理
    • 1.1 读取数据
    • 1.2 预处理数据
    • 1.3 词元化
    • 1.4 词频统计
    • 1.5 构建词汇表
    • 1.6 截断与填充
    • 1.7 转换为张量
    • 1.8 创建数据迭代器
    • 1.9 整合数据加载
  • 二、训练辅助工具
    • 2.1 时间记录器
    • 2.2 累加器
    • 2.3 准确率计算
    • 2.4 GPU 上的准确率评估
    • 2.5 GPU 设备选择
    • 2.6 梯度裁剪
  • 三、可视化工具
    • 3.1 SVG 显示设置
    • 3.2 坐标轴设置
    • 3.3 动态绘图
  • 四、网络架构
    • 4.1 编码器接口
    • 4.2 解码器接口
    • 4.3 编码器-解码器组合
    • 4.4 模型定义
      • GRU 编码器
      • GRU 解码器
      • 模型说明
  • 五、数据加载与训练
    • 5.1 数据加载
    • 5.2 模型训练
      • 损失函数定义
      • 训练函数
      • 训练代码
      • 说明
  • 六、BLEU 评估
    • 6.1 BLEU 计算函数
    • 6.2 BLEU 概念与取值范围
  • 七、预测功能
    • 7.1 预测代码
    • 7.2 测试预测效果
  • 总结


前言

Seq2Seq 模型的核心思想是将一个输入序列(例如英语句子)通过编码器(Encoder)转化为一个固定长度的上下文向量,再由解码器(Decoder)根据该向量生成目标序列(例如法语句子)。这种编码-解码的架构最初由 RNN 实现,后来发展出 LSTM 和 Transformer 等变种。在本文中,我们将聚焦于基于 RNN 的经典实现,并通过 PyTorch 代码逐步拆解其关键组件。

本文的代码来源于一个完整的机器翻译任务示例,数据集为英语-法语翻译对。我们将从数据加载与预处理开始,逐步构建编码器和解码器,最后通过 BLEU 分数评估翻译效果。所有代码都经过注释,确保易于理解,同时保留了附件中的完整性。

让我们开始吧!


一、数据加载与预处理

Seq2Seq 模型的第一步是准备数据。我们需要将原始的英语-法语翻译对数据加载到内存中,并对其进行预处理和词元化(tokenization),以便后续输入到模型中。以下是相关代码及其解释:

1.1 读取数据

from collections import Counter  # 用于词频统计
import torch  # PyTorch 核心库
from torch.utils import data  # PyTorch 数据加载工具
import numpy as np  # NumPy 用于数组操作

def read_data_nmt():
    """
    载入“英语-法语”数据集
    
    返回值:
        str: 文件内容的完整字符串
    """
    with open('fra.txt', 'r', encoding='utf-8') as f:
        return f.read()

read_data_nmt 函数简单地读取名为 fra.txt 的文件,该文件包含英语和法语的翻译对,每行以制表符分隔。它返回整个文件的字符串内容,为后续处理奠定基础。

1.2 预处理数据

def preprocess_nmt(text):
    """
    预处理“英语-法语”数据集
    
    参数:
        text (str): 输入的原始文本字符串
    
    返回值:
        str: 处理后的文本字符串
    """
    def no_space(char, prev_char):
        """
        判断当前字符是否需要前置空格
        """
        return char in set(',.!?') and prev_char != ' '

    # 使用空格替换不间断空格(\u202f)和非断行空格(\xa0),并转换为小写
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    
    # 在单词和标点符号之间插入空格
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
           
    return ''.join(out)

preprocess_nmt 函数对文本进行标准化处理:

  1. 将特殊空格字符替换为普通空格,并将所有字符转换为小写。
  2. 在标点符号(如逗号、句号)前插入空格,便于后续按空格分割词元。这种处理确保标点符号被视为独立的词元,而不是粘附在单词上。

1.3 词元化

def tokenize_nmt(text, num_examples=None):
    """
    词元化“英语-法语”数据集
    
    参数:
        text (str): 输入的文本字符串,每行包含英语和法语句子,用制表符分隔
        num_examples (int, optional): 最大处理样本数,默认值为 None 表示处理全部
    
    返回值:
        tuple: 包含两个列表的元组
            - source (list): 英语句子词元列表
            - target (list): 法语句子词元列表
    """
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if num_examples and i > num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source, target

tokenize_nmt 函数将预处理后的文本按行分割,并进一步将每行按制表符分为英语和法语部分,然后按空格分割成词元列表。它返回两个列表:source(英语词元列表)和 target(法语词元列表)。

1.4 词频统计

def count_corpus(tokens):
    """
    统计词元的频率
    
    参数:
        tokens: 词元列表,可以是一维或二维列表
    
    返回值:
        Counter: Counter 对象,统计每个词元的出现次数
    """
    if not tokens:
        return Counter()
    if isinstance(tokens[0], list):
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        flattened_tokens = tokens
    return Counter(flattened_tokens)

count_corpus 函数使用 Counter 类统计词元的出现频率,支持一维和二维列表输入。它是构建词汇表的基础工具。

1.5 构建词汇表

class Vocab:
    """文本词表类,用于管理词元及其索引的映射关系"""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """初始化词表"""
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
        counter = self._count_corpus(self.tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {
   token: idx for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    @staticmethod
    def _count_corpus(tokens):
        """统计词元频率"""
        if not tokens:
            return Counter()
        if isinstance(tokens[0], list):
            tokens = [token for sublist in tokens for token in sublist]
        return Counter(tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self[token] for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs

Vocab 类用于构建词汇表并管理词元与索引之间的映射:

  • 初始化时接受词元列表、最小频率阈值和预留特殊词元(如 <pad><bos><eos>)。
  • 内部使用 count_corpus 统计词频,并按频率排序。
  • 提供 __getitem__to_tokens 方法,分别用于词元到索引和索引到词元的转换。
  • <unk> 表示未知词元,默认索引为 0。

1.6 截断与填充

def truncate_pad(line, num_steps, padding_token):
    """
    截断或填充文本序列
    
    参数:
        line (list): 输入的文本序列(词元列表)
        num_steps (int): 目标序列长度
        padding_token (str): 用于填充的标记
    
    返回值:
        list: 截断或填充后的序列,长度为 num_steps
    """
    if len(line) > num_steps:
        return line[:num_steps]
    return line + [padding_token] * (num_steps - len(line))

truncate_pad 函数确保所有序列长度一致:

  • 如果序列长度超过 num_steps,则截断。
  • 如果不足,则用 padding_token(通常是 <pad>)填充。

1.7 转换为张量

def build_array_nmt(lines, vocab, num_steps):
    """
    将机器翻译的文本序列转换为小批量
    
    参数:
        lines (list): 文本序列列表,每个元素是一个词元列表
        vocab (dict): 词汇表,将词元映射为索引
        num_steps (int): 目标序列长度
    
    返回值:
        tuple: 包含两个元素的元组
            - array (torch.Tensor): 转换后的张量,形状为 (样本数, num_steps)
            - valid_len (np.ndarray): 每个序列的有效长度,形状为 (样本数,)
    """
    lines =

相关文章:

  • C# 常量
  • QScrcpy源码解析(1)
  • MOP数据库中的EXPLAIN用法
  • 初识 rsync:高效同步文件的利器(含 rsync -av 详解)
  • 【GESP】C++二级练习 luogu-B3721 [语言月赛202303] Stone Gambling S
  • VR体验馆如何用小程序高效引流?3步打造线上预约+团购裂变系统
  • LeetCode 解题思路 33(Hot 100)
  • Spring集成asyncTool:实现复杂任务的优雅编排与高效执行
  • 学习需要回看笔记
  • C语言 数据结构【双向链表】动态模拟实现
  • 11. grafana的table表使用
  • [随记] 安装 docker 报错排查
  • Docker 入门指南:基础知识解析
  • 【C++初学】C++实现通讯录管理系统:从零开始的详细教程
  • 道路坑洼目标检测数据集-665-labelme
  • Linux系统学习Day1——虚拟机间的讲话
  • 五子棋游戏开发:静态资源的重要性与设计思路
  • WPF 资源加载问题:真是 XAML 的锅吗?
  • [MySQL数据库] InnoDB存储引擎(二) : 磁盘结构详解
  • 智慧景区能源管理解决方案,为旅游“升温”保驾护航
  • 网站建设汇报评估/太原网络营销公司
  • 宝安沙井天气/济南seo公司
  • 网站备案需要什么资料/百度推广竞价排名技巧
  • 广告网站建设制作设计服务商/seo优化推荐
  • 网站背景如何做/seo快速排名优化方法
  • 政府网站建设参考书/四川seo平台