从代码学习深度学习 - RNN PyTorch版
文章目录
- 前言
- 一、数据预处理
- 二、辅助训练工具函数
- 三、绘图工具函数
- 四、模型定义
- 五、模型训练与预测
- 六、实例化模型并训练
-
- 训练结果可视化
- 总结
前言
循环神经网络(RNN)是深度学习中处理序列数据的重要模型,尤其在自然语言处理和时间序列分析中有着广泛应用。本篇博客将通过一个基于 PyTorch 的 RNN 实现,结合《The Time Machine》数据集,带你从零开始理解 RNN 的构建、训练和预测过程。我们将逐步剖析代码,展示如何加载数据、定义工具函数、构建模型、绘制训练过程图表,并最终训练一个字符级别的 RNN 模型。代码中包含了数据预处理、模型定义、梯度裁剪、困惑度计算等关键步骤,适合希望深入理解 RNN 的初学者和进阶者。
本文基于 PyTorch 实现,所有代码均来自附件,并辅以详细注释和图表说明。让我们开始吧!
一、数据预处理
首先,我们需要加载和预处理《The Time Machine》数据集,将其转化为适合 RNN 输入的格式。以下是数据预处理的完整代码:
import random
import re
import torch
from collections import Counter
def read_time_machine():
"""将时间机器数据集加载到文本行的列表中"""
with open('timemachine.txt', 'r') as f:
lines = f.readlines()
# 去除非字母字符并将每行转换为小写
return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]
def tokenize(lines, token='word'):
"""将文本行拆分为单词或字符词元"""
if token == 'word':
return [line.split() for line in lines]
elif token == 'char':
return [list(line) for line in lines]
else:
print(f'错误:未知词元类型:{
token}')
def count_corpus(tokens):
"""统计词元的频率"""
if not tokens:
return Counter()
if isinstance(tokens[0], list):
flattened_tokens = [token for sublist in tokens for token in sublist]
else:
flattened_tokens = tokens
return Counter(flattened_tokens)
class Vocab:
"""文本词表类,用于管理词元及其索引的映射关系"""
def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
self.tokens = tokens if tokens is not None else []
self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
counter = self._count_corpus(self.tokens)
self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
self.idx_to_token = ['<unk>'] + self.reserved_tokens
self.token_to_idx = {
token: idx for idx, token in enumerate(self.idx_to_token)}
for token, freq in self._token_freqs:
if freq < min_freq:
break
if token not in self.token_to_idx:
self.idx_to_token.append(token)
self.token_to_idx[token] = len(self.idx_to_token) - 1
@staticmethod
def _count_corpus(tokens):
if not tokens:
return Counter()
if isinstance(tokens[0], list):
tokens = [token for sublist in tokens for token in sublist]
return Counter(tokens)
def __len__(self):
return len(self.idx_to_token)
def __getitem__(self, tokens):
if not isinstance(tokens, (list, tuple)):
return self.token_to_idx.get(tokens, self.unk)
return [self[token] for token in tokens]
def to_tokens(self, indices):
if not isinstance(indices, (list, tuple)):
return self.idx_to_token[indices]
return [self.idx_to_token[index] for index in indices]
@property
def unk(self):
return 0
@property
def token_freqs(self):
return self._token_freqs
def load_corpus_time_machine(max_tokens=-1):
"""返回时光机器数据集的词元索引列表和词表"""
lines = read_time_machine(