当前位置: 首页 > news >正文

从代码学习深度学习 - RNN PyTorch版

文章目录

  • 前言
  • 一、数据预处理
  • 二、辅助训练工具函数
  • 三、绘图工具函数
  • 四、模型定义
  • 五、模型训练与预测
  • 六、实例化模型并训练
    • 训练结果可视化
  • 总结


前言

循环神经网络(RNN)是深度学习中处理序列数据的重要模型,尤其在自然语言处理和时间序列分析中有着广泛应用。本篇博客将通过一个基于 PyTorch 的 RNN 实现,结合《The Time Machine》数据集,带你从零开始理解 RNN 的构建、训练和预测过程。我们将逐步剖析代码,展示如何加载数据、定义工具函数、构建模型、绘制训练过程图表,并最终训练一个字符级别的 RNN 模型。代码中包含了数据预处理、模型定义、梯度裁剪、困惑度计算等关键步骤,适合希望深入理解 RNN 的初学者和进阶者。

本文基于 PyTorch 实现,所有代码均来自附件,并辅以详细注释和图表说明。让我们开始吧!


一、数据预处理

首先,我们需要加载和预处理《The Time Machine》数据集,将其转化为适合 RNN 输入的格式。以下是数据预处理的完整代码:

import random
import re
import torch
from collections import Counter

def read_time_machine():
    """将时间机器数据集加载到文本行的列表中"""
    with open('timemachine.txt', 'r') as f:
        lines = f.readlines()
    # 去除非字母字符并将每行转换为小写
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]

def tokenize(lines, token='word'):
    """将文本行拆分为单词或字符词元"""
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        print(f'错误:未知词元类型:{
     token}')

def count_corpus(tokens):
    """统计词元的频率"""
    if not tokens:
        return Counter()
    if isinstance(tokens[0], list):
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        flattened_tokens = tokens
    return Counter(flattened_tokens)

class Vocab:
    """文本词表类,用于管理词元及其索引的映射关系"""
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
        counter = self._count_corpus(self.tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {
   token: idx for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    @staticmethod
    def _count_corpus(tokens):
        if not tokens:
            return Counter()
        if isinstance(tokens[0], list):
            tokens = [token for sublist in tokens for token in sublist]
        return Counter(tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self[token] for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs

def load_corpus_time_machine(max_tokens=-1):
    """返回时光机器数据集的词元索引列表和词表"""
    lines = read_time_machine(
http://www.dtcms.com/a/109322.html

相关文章:

  • 浙江大学郑小林教授解读智能金融与AI的未来|附PPT下载方法
  • 电子电气架构 --- 面向服务的体系架构
  • Python垃圾回收:循环引用检测算法实现
  • 【面试题】如何用两个线程轮流输出0-200的值
  • 大模型应用初学指南
  • Linux 查找文本中控制字符所在的行
  • 线性欧拉筛
  • AF3 OpenFoldDataset类解读
  • 【面试篇】Kafka
  • 记录学习的第二十天
  • 【LeetCode 题解】数据库:626.换座位
  • Java基础:Logback日志框架
  • C# 与 相机连接
  • 接收灵敏度的基本概念与技术解析
  • 【计网】作业三
  • 2025年2月,美国发布了新版移动灯的安规标准:UL153标准如何办理?
  • MySQL:库表操作
  • CATIA装配体全自动存储解决方案开发实战——基于递归算法的产品结构树批量处理技术
  • 一款非常小的软件,操作起来非常丝滑!
  • 语音识别播报人工智能分类垃圾桶(论文+源码)
  • MySQL 基础使用指南-MySQL登录与远程登录
  • MySQL超全笔记
  • 快速掌握MCP——Spring AI MCP包教包会
  • Pyspark学习二:快速入门基本数据结构
  • 4月3号.
  • Python 函数知识梳理与经典编程题解析
  • FFmpeg录制屏幕和音频
  • 单片机学习之定时器
  • 嵌入式海思Hi3861连接华为物联网平台操作方法
  • Zapier MCP:重塑跨应用自动化协作的技术实践