从代码学习深度学习 - 用于预训练词嵌入的数据集 PyTorch版
文章目录
- 前言
- 辅助工具代码
- 绘图工具 (utils\_for\_huitu.py)
- 数据处理工具 (utils\_for\_data.py)
- 读取数据集 (PTB)
- 构建词表
- 下采样高频词
- 中心词和上下文词的提取
- 负采样
- 小批量加载训练实例
- 整合代码:构建数据加载器
- 总结
前言
词嵌入(Word Embedding)是将词语映射到低维连续向量空间的技术,它能够捕捉词语间的语义和语法关系。预训练词嵌入模型,如 Word2Vec(包括 Skip-gram 和 CBOW)和 GloVe,已经在自然语言处理 (NLP) 领域取得了巨大成功。这些模型通常在大型语料库上进行训练,学习到的词向量可以作为下游 NLP 任务的优秀特征输入。
本文将重点关注如何为预训练词嵌入模型(以 Skip-gram 和负采样为例)准备数据集。我们将使用 Penn Tree Bank (PTB) 数据集,并详细介绍从原始文本数据到可供 PyTorch 模型训练的小批量数据的完整处理流程。这个过程包括读取数据、构建词表、下采样高频词、提取中心词和上下文词、以及进行负采样。通过理解这些步骤,我们可以更好地掌握词嵌入模型训练的基础。
让我们开始吧!
完整代码:下载链接
辅助工具代码
在正式开始数据处理之前,我们先介绍两个辅助 Python 文件,它们分别提供了绘图和数据处理相关的功能。
绘图工具 (utils_for_huitu.py)
这个文件包含了一些使用 Matplotlib 进行绘图的辅助函数,例如设置图像大小、使用 SVG 格式显示以及绘制特定类型的直方图。
# --- START OF FILE utils_for_huitu.py ---# 导入必要的包
import matplotlib.pyplot as plt # 用于创建和操作 Matplotlib 图表
from matplotlib_inline import backend_inline # 用于在Jupyter中设置Matplotlib输出格式
from IPython import display # 用于后续动态显示(如 Animator)
import torch # 导入PyTorch库,用于处理张量类型的图像
import numpy as np # 导入NumPy,可能用于数据处理
import matplotlib as mpl # 导入Matplotlib主模块,用于设置图像属性def set_figsize(figsize=(3.5, 2.5)):"""设置matplotlib图形的大小参数:figsize: tuple[float, float] - 图形大小,形状为 (宽度, 高度),单位为英寸输出:无返回值"""plt.rcParams['figure.figsize'] = figsize # 设置图形默认大小def use_svg_display():"""使用 SVG 格式在 Jupyter 中显示绘图输入:无输出:无返回值"""backend_inline.set_matplotlib_formats('svg') # 设置 Matplotlib 使用 SVG 格式def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):"""绘制列表长度对的直方图,用于比较两组列表中元素长度的分布参数:legend: list[str] - 图例标签,形状为 (2,),分别对应xlist和ylist的标签xlabel: str - x轴标签ylabel: str - y轴标签xlist: list[list] - 第一组列表,形状为 (样本数量, 每个样本的元素数)ylist: list[list] - 第二组列表,形状为 (样本数量, 每个样本的元素数)输出:无返回值,但会显示生成的直方图"""set_figsize() # 设置图形大小# plt.hist返回的三个值:# n: list[array] - 每个bin中的样本数量,形状为 (2, bin数量)# bins: array - bin的边界值,形状为 (bin数量+1,)# patches: list[list[Rectangle]] - 直方图的矩形对象,形状为 (2, bin数量)_, _, patches = plt.hist([[len(l) for l in xlist], [len(l) for l in ylist]]) # 绘制两组数据长度的直方图plt.xlabel(xlabel) # 设置x轴标签plt.ylabel(ylabel) # 设置y轴标签# 为第二组数据(ylist)的直方图添加斜线图案,以区分两组数据for patch in patches[1].patches: # patches[1]是ylist对应的矩形对象列表patch.set_hatch('/') # 设置填充图案为斜线plt.legend(legend) # 添加图例
# --- END OF FILE utils_for_huitu.py ---
数据处理工具 (utils_for_data.py)
这个文件包含了一个用于统计词频的函数 count_corpus
和一个核心的 Vocab
类,后者用于构建词表,管理词元到索引以及索引到词元的映射。
# --- START OF FILE utils_for_data.py ---from collections import Counter # 导入 Counter 类
# from collections import Counter # 用于词频统计 (此行重复,已注释)
import torch # PyTorch 核心库
from torch.utils import data # PyTorch 数据加载工具
import numpy as np # NumPy 用于数组操作def count_corpus(tokens):"""统计词元的频率参数:tokens: 词元列表,可以是:- 一维列表,例如 ['a', 'b']- 二维列表,例如 [['a', 'b'], ['c']]返回值:Counter: Counter 对象,统计每个词元的出现次数"""# 如果输入为空列表,直接返回空计数器if not tokens: # 等价于 len(tokens) == 0return Counter()# 检查输入是否为二维列表if isinstance(tokens[0], list):# 将二维列表展平为一维列表flattened_tokens = [token for sublist in tokens for token in sublist]else:# 如果是一维列表,直接使用原列表flattened_tokens = tokens# 使用 Counter 统计词频并返回return Counter(flattened_tokens)class Vocab:"""文本词表类,用于管理词元及其索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):"""初始化词表Args:tokens: 输入的词元列表,可以是1D或2D列表,默认为空列表min_freq: 词元最小出现频率,小于此频率的词元将被忽略,默认为0reserved_tokens: 预留的特殊词元列表(如'<pad>'),默认为空列表"""# 处理默认参数self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []# 统计词元频率并按频率降序排序# 注意:这里应该调用类自身的 _count_corpus 方法counter = self._count_corpus(self.tokens) self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化词表,'<unk>'为未知词元,索引为0self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 添加满足最小频率要求的词元到词表for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1"""将方法标记为静态方法,无需绑定实例或类,可用类名直接调用"""@staticmethoddef _count_corpus(tokens):"""统计词元频率Args:tokens: 词元列表,可以是1D或2D列表Returns:Counter对象,统计每个词元的出现次数"""if not tokens:return Counter()if isinstance(tokens[0], list):tokens = [token for sublist in tokens for token in sublist]return Counter(tokens)def __len__(self):"""返回词表的大小"""return len(self.idx_to_token)def __getitem__(self, tokens):"""通过词元获取索引,或通过索引获取词元Args:tokens: 单个词元或词元列表/元组Returns:单个索引或索引列表"""if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)return [self[token] for token in tokens]def to_tokens(self, indices):"""通过索引获取词元Args:indices: 单个索引或索引列表/元组Returns:单个词元或词元列表"""if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]"""用于将类中的方法伪装成属性(property),从而让开发者可以用访问属性的方式(而不是调用方法的方式)来获取或操作类的内部数据"""@propertydef unk(self):"""未知词元的索引"""return 0@propertydef token_freqs(self):"""词元及其频率的列表"""return self._token_freqs
# --- END OF FILE utils_for_data.py ---
注意:在 Vocab
类的 __init__
方法中,原代码中 counter = count_corpus(self.tokens)
应该为 counter = self._count_corpus(self.tokens)
或 counter = Vocab._count_corpus(self.tokens)
以调用类自身的静态方法。上述代码已做此修正。
读取数据集 (PTB)
我们使用的数据集是 Penn Tree Bank (PTB)。该语料库取自“华尔街日报”的文章,分为训练集、验证集和测试集。在原始格式中,文本文件的每一行表示由空格分隔的一句话。在这里,我们将每个单词视为一个词元。
下面的 read_ptb
函数用于将PTB训练集加载到文本行的列表中。
import math
import os
import random
import torch
import numpy as np # 补充可能需要的数值计算库def read_ptb():"""将PTB数据集加载到文本行的列表中返回:list[list[str]]: 句子列表,每个句子是词语的列表形状为 (句子数量, 每句话的词数),其中每句话的词数不固定"""data_dir = 'ptb' # 数据集目录 (str)# 读取训练集文件# 假设 'ptb/ptb.train.txt' 文件存在且包含数据# 为确保代码可运行,如果文件不存在,可以创建一个虚拟文件或处理异常if not os.path.exists