当前位置: 首页 > news >正文

从代码学习深度学习 - 近似训练 PyTorch版

文章目录

  • 前言
  • 负采样 (Negative Sampling)
  • 层序Softmax (Hierarchical Softmax)
    • 代码示例
  • 总结


前言

在自然语言处理(NLP)领域,词嵌入(Word Embeddings)技术如Word2Vec(包括Skip-gram和CBOW模型)已经成为一项基础且强大的工具。它们能够将词语映射到低维稠密向量空间,使得语义相近的词在向量空间中的距离也相近。然而,这些模型在训练过程中,尤其是在计算输出层softmax时,会面临一个巨大的挑战:词汇表通常非常庞大(几十万甚至数百万个词)。对整个词典进行求和并计算梯度,其计算成本是巨大的。

为了解决这个问题,研究者们提出了多种近似训练方法,旨在降低计算复杂度,同时保持模型性能。本篇将重点介绍两种在Word2Vec中广泛应用的近似训练方法:负采样(Negative Sampling)分层Softmax(Hierarchical Softmax)。我们将以跳元模型(Skip-gram)为例来阐述这两种方法的核心思想。

虽然本文标题带有"PyTorch版",但所提供的笔记主要集中在理论层面。在实际的PyTorch应用中,这些近似训练方法通常会通过专门的损失函数或者自定义神经网络层来实现。

完整代码:下载链接

负采样 (Negative Sampling)

负采样通过修改原始目标函数来降低计算复杂度。其核心思想是,对于每个训练样本(中心词和其上下文中的一个真实目标词),我们不再尝试预测整个词汇表中哪个词是正确的上下文词,而是将其转化为一个二分类问题:区分真实的目标词和一些随机采样的“噪声”词(负样本)。

给定中心词 w c w_c wc 的上下文窗口,任意上下文词 w o w_o wo 来自该上下文窗口的事件被认为是由下式建模概率的事件:

P ( D = 1 ∣ w c , w o ) = σ ( u o ⊤ v c ) P(D=1 \mid w_c, w_o) = \sigma(\mathbf{u}_o^\top \mathbf{v}_c) P(D=1wc,wo)=σ(uovc)

其中 σ \sigma σ 使用了sigmoid激活函数的定义:

σ ( x ) = 1 1 + exp ⁡ ( − x ) \sigma(x) = \frac{1}{1 + \exp(-x)} σ(x)=1+exp(x)1

u o \mathbf{u}_o uo 是上下文词 w o w_o wo 的输出向量(或称为上下文向量), v c \mathbf{v}_c vc 是中心词 w c w_c wc 的输入向量(或称为词向量)。

原始的Word2Vec模型旨在最大化文本序列中所有这些正样本事件的联合概率。具体而言,给定长度为 T T T 的文本序列,以 w ( t ) w^{(t)} w(t) 表示时间步 t t t 的词,并使上下文窗口为 m m m,考虑最大化联合概率:

∏ t = 1 T ∏ − m ≤ j ≤ m , j ≠ 0 P ( D = 1 ∣ w ( t ) , w ( t + j ) ) \prod_{t=1}^T \prod_{-m \leq j \leq m, j \neq 0} P(D=1 \mid w^{(t)}, w^{(t+j)}) t=1Tmjm,j=0P(D=1w(t),w(t+j))

然而,这个目标函数只考虑了正样本。如果仅最大化这个概率,模型可能会学到将所有词向量都变得非常大,导致 σ ( u o ⊤ v c ) \sigma(\mathbf{u}_o^\top \mathbf{v}_c) σ(uovc) 接近1,但这并没有实际意义。

为了使目标函数更有意义,负采样引入了负样本。

S S S 表示上下文词 w o w_o wo 来自中心词 w c w_c wc 的上下文窗口的事件。对于这个涉及 w o w_o wo 的事件,我们从一个预定义的分布 P ( w ) P(w) P(w)(通常是词频的3/4次方)中采样 K K K 个不是来自这个上下文窗口的“噪声词”(负样本)。用 N k N_k Nk 表示噪声词 w k ( k = 1 , … , K ) w_k (k=1, \ldots, K) wk(k=1,,K) 不是来自 w c w_c wc 的上下文窗口的事件(即它们是负样本, D = 0 D=0 D=0)。

假设正例和负例 S , N 1 , … , N K S, N_1, \ldots, N_K S,N1,,NK 的这些事件是相互独立的。负采样将上述联合概率(仅涉及正例)修改为,对于每个中心词-上下文词对 ( w ( t ) , w ( t + j ) ) (w^{(t)}, w^{(t+j)}) (w(t),w(t+j)),最大化以下概率࿱

相关文章:

  • [强化学习的数学原理—赵世钰老师]学习笔记02-贝尔曼方程-下
  • 【AWS】从 0 基础直觉性地理解 IAM(Identity and Access Management)
  • CudaMemCpy returns cudaErrorInvalidValue
  • 《Vite 报错》ReferenceError: module is not defined in ES module scope
  • 学习黑客Active Directory入门
  • 重读《人件》Peopleware -(10-2)Ⅱ 办公环境 Ⅲ 节省办公空间的费用(下)
  • 多头自注意力机制—Transformer模型的并行特征捕获引擎
  • 打卡Day29
  • Vue百日学习计划Day24-28天详细计划-Gemini版
  • C++中的容器
  • Spring Boot JWT认证示例项目
  • 怎样免费开发部署自己的网站?
  • react深入2 - react-redux
  • MySQL——6、内置函数
  • 2025年- H31-Lc139- 242.回文链表(快慢指针)---java版--需2刷
  • c++编写中遇见的错误
  • 如何利用DeepSeek提升工作效率
  • LaTeX OCR - 数学公式识别系统
  • matlab分段函数
  • 大模型解析:AI技术的现状、原理与应用前景
  • 刘小涛任江苏省委副书记
  • 广西壮族自治区党委副书记、自治区政府主席蓝天立接受审查调查
  • 政企共同发力:多地密集部署外贸企业抢抓90天政策窗口期
  • 现场丨在胡适施蛰存等手札与文献间,再看百年光华
  • 时隔3年俄乌直接谈判今日有望重启:谁参加,谈什么
  • 3年多来俄乌要首次直接对话?能谈得拢吗?