当前位置: 首页 > news >正文

从代码学习深度学习 - 用于预训练词嵌入的数据集 PyTorch版

文章目录

  • 前言
  • 辅助工具代码
    • 绘图工具 (utils\_for\_huitu.py)
    • 数据处理工具 (utils\_for\_data.py)
  • 读取数据集 (PTB)
  • 构建词表
  • 下采样高频词
  • 中心词和上下文词的提取
  • 负采样
  • 小批量加载训练实例
  • 整合代码:构建数据加载器
  • 总结


前言

词嵌入(Word Embedding)是将词语映射到低维连续向量空间的技术,它能够捕捉词语间的语义和语法关系。预训练词嵌入模型,如 Word2Vec(包括 Skip-gram 和 CBOW)和 GloVe,已经在自然语言处理 (NLP) 领域取得了巨大成功。这些模型通常在大型语料库上进行训练,学习到的词向量可以作为下游 NLP 任务的优秀特征输入。

本文将重点关注如何为预训练词嵌入模型(以 Skip-gram 和负采样为例)准备数据集。我们将使用 Penn Tree Bank (PTB) 数据集,并详细介绍从原始文本数据到可供 PyTorch 模型训练的小批量数据的完整处理流程。这个过程包括读取数据、构建词表、下采样高频词、提取中心词和上下文词、以及进行负采样。通过理解这些步骤,我们可以更好地掌握词嵌入模型训练的基础。

让我们开始吧!

完整代码:下载链接

辅助工具代码

在正式开始数据处理之前,我们先介绍两个辅助 Python 文件,它们分别提供了绘图和数据处理相关的功能。

绘图工具 (utils_for_huitu.py)

这个文件包含了一些使用 Matplotlib 进行绘图的辅助函数,例如设置图像大小、使用 SVG 格式显示以及绘制特定类型的直方图。

# --- START OF FILE utils_for_huitu.py ---# 导入必要的包
import matplotlib.pyplot as plt  # 用于创建和操作 Matplotlib 图表
from matplotlib_inline import backend_inline  # 用于在Jupyter中设置Matplotlib输出格式
from IPython import display  # 用于后续动态显示(如 Animator)
import torch  # 导入PyTorch库,用于处理张量类型的图像
import numpy as np  # 导入NumPy,可能用于数据处理
import matplotlib as mpl  # 导入Matplotlib主模块,用于设置图像属性def set_figsize(figsize=(3.5, 2.5)):"""设置matplotlib图形的大小参数:figsize: tuple[float, float] - 图形大小,形状为 (宽度, 高度),单位为英寸输出:无返回值"""plt.rcParams['figure.figsize'] = figsize  # 设置图形默认大小def use_svg_display():"""使用 SVG 格式在 Jupyter 中显示绘图输入:无输出:无返回值"""backend_inline.set_matplotlib_formats('svg')  # 设置 Matplotlib 使用 SVG 格式def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):"""绘制列表长度对的直方图,用于比较两组列表中元素长度的分布参数:legend: list[str] - 图例标签,形状为 (2,),分别对应xlist和ylist的标签xlabel: str - x轴标签ylabel: str - y轴标签xlist: list[list] - 第一组列表,形状为 (样本数量, 每个样本的元素数)ylist: list[list] - 第二组列表,形状为 (样本数量, 每个样本的元素数)输出:无返回值,但会显示生成的直方图"""set_figsize()  # 设置图形大小# plt.hist返回的三个值:# n: list[array] - 每个bin中的样本数量,形状为 (2, bin数量)# bins: array - bin的边界值,形状为 (bin数量+1,)# patches: list[list[Rectangle]] - 直方图的矩形对象,形状为 (2, bin数量)_, _, patches = plt.hist([[len(l) for l in xlist], [len(l) for l in ylist]])  # 绘制两组数据长度的直方图plt.xlabel(xlabel)  # 设置x轴标签plt.ylabel(ylabel)  # 设置y轴标签# 为第二组数据(ylist)的直方图添加斜线图案,以区分两组数据for patch in patches[1].patches:  # patches[1]是ylist对应的矩形对象列表patch.set_hatch('/')  # 设置填充图案为斜线plt.legend(legend)  # 添加图例
# --- END OF FILE utils_for_huitu.py ---

数据处理工具 (utils_for_data.py)

这个文件包含了一个用于统计词频的函数 count_corpus 和一个核心的 Vocab 类,后者用于构建词表,管理词元到索引以及索引到词元的映射。

# --- START OF FILE utils_for_data.py ---from collections import Counter  # 导入 Counter 类
# from collections import Counter  # 用于词频统计 (此行重复,已注释)
import torch  # PyTorch 核心库
from torch.utils import data  # PyTorch 数据加载工具
import numpy as np  # NumPy 用于数组操作def count_corpus(tokens):"""统计词元的频率参数:tokens: 词元列表,可以是:- 一维列表,例如 ['a', 'b']- 二维列表,例如 [['a', 'b'], ['c']]返回值:Counter: Counter 对象,统计每个词元的出现次数"""# 如果输入为空列表,直接返回空计数器if not tokens:  # 等价于 len(tokens) == 0return Counter()# 检查输入是否为二维列表if isinstance(tokens[0], list):# 将二维列表展平为一维列表flattened_tokens = [token for sublist in tokens for token in sublist]else:# 如果是一维列表,直接使用原列表flattened_tokens = tokens# 使用 Counter 统计词频并返回return Counter(flattened_tokens)class Vocab:"""文本词表类,用于管理词元及其索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):"""初始化词表Args:tokens: 输入的词元列表,可以是1D或2D列表,默认为空列表min_freq: 词元最小出现频率,小于此频率的词元将被忽略,默认为0reserved_tokens: 预留的特殊词元列表(如'<pad>'),默认为空列表"""# 处理默认参数self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []# 统计词元频率并按频率降序排序# 注意:这里应该调用类自身的 _count_corpus 方法counter = self._count_corpus(self.tokens) self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化词表,'<unk>'为未知词元,索引为0self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 添加满足最小频率要求的词元到词表for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1"""将方法标记为静态方法,无需绑定实例或类,可用类名直接调用"""@staticmethoddef _count_corpus(tokens):"""统计词元频率Args:tokens: 词元列表,可以是1D或2D列表Returns:Counter对象,统计每个词元的出现次数"""if not tokens:return Counter()if isinstance(tokens[0], list):tokens = [token for sublist in tokens for token in sublist]return Counter(tokens)def __len__(self):"""返回词表的大小"""return len(self.idx_to_token)def __getitem__(self, tokens):"""通过词元获取索引,或通过索引获取词元Args:tokens: 单个词元或词元列表/元组Returns:单个索引或索引列表"""if not isinstance(tokens, (list, tuple)):return self.token_to_idx.get(tokens, self.unk)return [self[token] for token in tokens]def to_tokens(self, indices):"""通过索引获取词元Args:indices: 单个索引或索引列表/元组Returns:单个词元或词元列表"""if not isinstance(indices, (list, tuple)):return self.idx_to_token[indices]return [self.idx_to_token[index] for index in indices]"""用于将类中的方法伪装成属性(property),从而让开发者可以用访问属性的方式(而不是调用方法的方式)来获取或操作类的内部数据"""@propertydef unk(self):"""未知词元的索引"""return 0@propertydef token_freqs(self):"""词元及其频率的列表"""return self._token_freqs
# --- END OF FILE utils_for_data.py ---

注意:在 Vocab 类的 __init__ 方法中,原代码中 counter = count_corpus(self.tokens) 应该为 counter = self._count_corpus(self.tokens)counter = Vocab._count_corpus(self.tokens) 以调用类自身的静态方法。上述代码已做此修正。

读取数据集 (PTB)

我们使用的数据集是 Penn Tree Bank (PTB)。该语料库取自“华尔街日报”的文章,分为训练集、验证集和测试集。在原始格式中,文本文件的每一行表示由空格分隔的一句话。在这里,我们将每个单词视为一个词元。

下面的 read_ptb 函数用于将PTB训练集加载到文本行的列表中。

import math
import os
import random
import torch
import numpy as np  # 补充可能需要的数值计算库def read_ptb():"""将PTB数据集加载到文本行的列表中返回:list[list[str]]: 句子列表,每个句子是词语的列表形状为 (句子数量, 每句话的词数),其中每句话的词数不固定"""data_dir = 'ptb'  # 数据集目录 (str)# 读取训练集文件# 假设 'ptb/ptb.train.txt' 文件存在且包含数据# 为确保代码可运行,如果文件不存在,可以创建一个虚拟文件或处理异常if not os.path.exists

相关文章:

  • docker默认存储迁移
  • 【Nuxt3】安装 Naive UI 按需自动引入组件
  • 【QT】一个界面中嵌入其它界面(一)
  • PyQt5绘图全攻略:QPainter、QPen、QBrush与QPixmap详解
  • 第十六届蓝桥杯复盘
  • P2P最佳网络类型
  • Fiddler无法抓包的问题分析
  • C语言学习笔记之条件编译
  • # idea 中如何将 java 项目打包成 jar 包?
  • 国家互联网信息办公室关于发布第十一批深度合成服务算法备案信息的公告
  • [架构之美]从PDMan一键生成数据库设计文档:Word导出全流程详解(二十)
  • GO语言学习(五)
  • vue3自适应高度超出折叠功能
  • 【操作系统面经】持续更新ing
  • FART 主动调用组件设计和源码分析
  • 程序化 SEO 全攻略:如何高效提升网站排名?
  • Linux 文件(2)
  • 电子电路:什么是静态工作点Q点?
  • 【QT】QT6添加现有.c .h文件
  • QT之绘图模块和双缓冲技术
  • 世卫大会连续9年拒绝涉台提案
  • 从《缶翁的世界》看吴昌硕等湖州籍书画家对海派的影响
  • 北京韩美林艺术馆党支部书记郭莹病逝,终年40岁
  • 2000多年前的“新衣”长这样!马王堆文物研究新成果上新
  • 中国人民银行等四部门联合召开科技金融工作交流推进会
  • 刘晓庆被实名举报涉嫌偷税漏税,税务部门启动调查