当前位置: 首页 > news >正文

Masked Attention 在 LLM 训练中的作用与原理

大语言模型(LLM)训练过程中,Masked Attention(掩码注意力) 是一个关键机制,它决定了 模型如何在训练时只利用过去的信息,而不会看到未来的 token。这篇文章将帮助你理解 Masked Attention 的作用、实现方式,以及为什么它能确保当前 token 只依赖于过去的 token,而不会泄露未来的信息。

1. Masked Attention 在 LLM 训练中的作用

在 LLM 训练时,我们通常使用 自回归(Autoregressive) 方式来让模型学习文本的生成。例如,给定输入序列:

"The cat is very"

模型需要预测下一个 token:

"cute"

但是,为了保证模型的生成方式符合自然语言流向,每个 token 只能看到它之前的 token,不能看到未来的 token

Masked Attention 的作用就是:

  • 屏蔽未来的 token,使当前 token 只能关注之前的 token
  • 保证训练阶段的注意力机制符合推理时的因果(causal)生成方式
  • 防止信息泄露,让模型学会自回归生成文本

如果没有 Masked Attention,模型在训练时可以“偷看”未来的 token,导致它学到的规律无法泛化到推理阶段,从而影响文本生成的效果。

举例说明

假设输入是 "The cat is cute",模型按 token 级别计算注意力:

(1) 没有 Mask(BERT 方式)
TokenThecatiscute
The
cat
is
cute

每个 token 都能看到整个句子,适用于 BERT 这种双向模型。

(2) 有 Mask(GPT 方式)
TokenThecatiscute
The
cat
is
cute

每个 token 只能看到它自己及之前的 token,保证训练和推理时的生成顺序一致。

2. Masked Attention 的工作原理

 在标准的 自注意力(Self-Attention) 机制中,注意力分数是这样计算的:

A = \text{softmax} \left( \frac{Q K^T}{\sqrt{d_k}} \right)

其中:

  • Q, K, V  是 Query(查询)、Key(键)和 Value(值)矩阵

  • Q K^T 计算所有 token 之间的相似度

  • 如果不做 Masking,每个 token 都能看到所有的 token

而在 Masked Attention 中,我们会使用一个 上三角掩码(Upper Triangular Mask),使得未来的 token 不能影响当前 token:

S' = \frac{Q K^T}{\sqrt{d_k}} + \text{mask}

Mask 是一个 上三角矩阵,其中:

  • 未来 token 的位置填充 -\infty,确保 softmax 之后它们的注意力权重为 0

  • 只允许关注当前 token 及之前的 token

例如,假设有 4 个 token:

\begin{bmatrix} s_{1,1} & -\infty & -\infty & -\infty \\ s_{2,1} & s_{2,2} & -\infty & -\infty \\ s_{3,1} & s_{3,2} & s_{3,3} & -\infty \\ s_{4,1} & s_{4,2} & s_{4,3} & s_{4,4} \end{bmatrix}

经过 softmax 之后:

A = \begin{bmatrix} 1 & 0 & 0 & 0 \\ \text{non-zero} & \text{non-zero} & 0 & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & \text{non-zero} \end{bmatrix}

最终,每个 token 只会关注它自己和它之前的 token,完全忽略未来的 token!

3. Masked Attention 计算下三角部分的值时,如何保证未来信息不会泄露?

换句话说,我们需要证明 Masked Attention 计算出的下三角部分的值(即历史 token 之间的注意力分数)不会受到未来 token 的影响

1. 问题重述

Masked Attention 的核心计算是:

\text{Attention}(Q, K, V) = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}} + \text{mask}) V

其中:

  • Q, K, V 是整个序列的矩阵。

  • QK^T计算的是所有 token 之间的注意力分数。

  • Mask 确保 softmax 后未来 token 的注意力分数变为 0。

这个问题可以分解成两个关键点:

  1. 未来 token 是否影响了下三角部分的 Q 或 K?

  2. 即使未来 token 参与了 Q, K 计算,为什么它们不会影响下三角的注意力分数?

2. 未来 token 是否影响了 Q 或 K?

我们先看 Transformer 计算 Q, K, V 的方式:

Q = X W_Q, \quad K = X W_K, \quad V = X W_V

这里:

  • X 是整个输入序列的表示。

  • W_Q, W_K, W_V是相同的投影矩阵,作用于所有 token。

由于 每个 token 的 Q, K, V 只取决于它自己,并不会在计算时使用未来 token 的信息,所以:

  • 计算第 i 个 token 的 Q_i, K_i, V_i时,并没有用到 X_{i+1}, X_{i+2}, \dots,所以未来 token 并不会影响当前 token 的 Q, K, V

结论 1未来 token 不会影响当前 token 的 Q 和 K。

3. Masked Attention 如何确保下三角部分不包含未来信息?

即使 Q, K 没有未来信息,我们仍然要证明 计算出的注意力分数不会受到未来信息影响

我们来看注意力计算:

\frac{Q K^T}{\sqrt{d_k}}

这是一个 所有 token 之间的相似度矩阵,即:

S = \begin{bmatrix} Q_1 \cdot K_1^T & Q_1 \cdot K_2^T & Q_1 \cdot K_3^T & Q_1 \cdot K_4^T \\ Q_2 \cdot K_1^T & Q_2 \cdot K_2^T & Q_2 \cdot K_3^T & Q_2 \cdot K_4^T \\ Q_3 \cdot K_1^T & Q_3 \cdot K_2^T & Q_3 \cdot K_3^T & Q_3 \cdot K_4^T \\ Q_4 \cdot K_1^T & Q_4 \cdot K_2^T & Q_4 \cdot K_3^T & Q_4 \cdot K_4^T \end{bmatrix}

然后,我们应用 因果 Mask(Causal Mask)

S' = S + \text{mask}

Mask 让右上角(未来 token 相关的部分)变成 -\infty

\begin{bmatrix} S_{1,1} & -\infty & -\infty & -\infty \\ S_{2,1} & S_{2,2} & -\infty & -\infty \\ S_{3,1} & S_{3,2} & S_{3,3} & -\infty \\ S_{4,1} & S_{4,2} & S_{4,3} & S_{4,4} \end{bmatrix}

然后计算 softmax:

A = \text{softmax}(S')

由于 e^{-\infty} = 0,所有未来 token 相关的注意力分数都变成 0

A = \begin{bmatrix} 1 & 0 & 0 & 0 \\ \text{non-zero} & \text{non-zero} & 0 & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & 0 \\ \text{non-zero} & \text{non-zero} & \text{non-zero} & \text{non-zero} \end{bmatrix}

最后,我们计算:

\text{Output} = A V

由于未来 token 的注意力权重是 0,它们的 V 在计算中被忽略。因此,下三角部分(历史 token 之间的注意力)完全不受未来 token 影响。

结论 2未来 token 的信息不会影响下三角部分的 Attention 计算。

4. 为什么 Masked Attention 能防止未来信息泄露?

你可能会问:

即使有 Mask,计算 Attention 之前,我们不是还是用到了整个序列的 Q, K, V 吗?未来 token 的 Q, K, V 不是已经算出来了吗?

的确,每个 token 的 Q, K, V 是独立计算的,但 Masked Attention 确保了:

  1. 计算 Q, K, V 时,每个 token 只依赖于它自己的输入

    • Q_i, K_i, V_i只来自 token i,不会用到未来的信息

    • 未来的 token 并不会影响当前 token 的 Q, K, V

  2. Masked Softmax 阻止了未来 token 的影响

    • 虽然 Q, K, V 都计算了,但 Masking 让未来 token 的注意力分数变为 0,确保计算出的 Attention 结果不包含未来信息。

最终,当前 token 只能看到过去的信息,未来的信息被完全屏蔽!

5. 训练时使用 Masked Attention 的必要性

Masked Attention 的一个关键作用是 让训练阶段和推理阶段保持一致

  • 训练时:模型学习如何根据 历史 token 预测 下一个 token,确保生成文本时符合自然语言流向。

  • 推理时:模型生成每个 token 后,仍然只能访问过去的 token,而不会看到未来的 token。

如果 训练时没有 Masked Attention,模型会学习到“作弊”策略,直接利用未来信息进行预测。但在推理时,模型无法“偷看”未来的信息,导致生成质量急剧下降。

6. 结论

Masked Attention 是 LLM 训练的核心机制之一,其作用在于:

  • 确保当前 token 只能访问过去的 token,不会泄露未来信息
  • 让训练阶段与推理阶段保持一致,避免模型在推理时“失效”
  • 利用因果 Mask 让 Transformer 具备自回归能力,学会按序生成文本

Masked Attention 本质上是 Transformer 训练过程中对信息流动的严格约束,它确保了 LLM 能够正确学习自回归生成任务,是大模型高质量文本生成的基础。

http://www.dtcms.com/a/99669.html

相关文章:

  • 408 计算机网络 知识点记忆(1)
  • 代码随想录刷题day53|(二叉树篇)106.从中序与后序遍历序列构造二叉树(▲
  • 如何使用 Bash 脚本自动化清理 Nacos 日志文件
  • Postman 集合如何快速分享给团队?
  • 树莓派5学习踩坑指南1--摄像头识别,SSH VNC远程连接,忘记密码重新登录
  • SHELL 三剑客
  • ModbusTCP协议报文详细分析
  • 安卓开发之LiveData与DataBinding
  • Next.js 项目生产构建优化
  • 【leetcode hot 100 45】跳跃游戏Ⅱ
  • 第三百八十九节 JavaFX教程 - JavaFX WebEngine
  • uniapp-小程序地图展示
  • C++的模板(十四):更多的自动内存管理
  • AI的未来在手机里!
  • Spring Data审计利器:@LastModifiedDate详解(依赖关系补充篇)!!!
  • springBoot与ElementUI配合上传文件
  • Vue2——常用指令总结、指令修饰符、v-model原理、computed计算属性、watch监听器、ref和$refs
  • Elasticsearch(ES)的经典面试题及其答案
  • 深度对比:DeepSeek vs OpenAI 核心技术指标
  • Matlab安装tdms插件
  • Numpy用法(三)
  • QT操作Excel
  • 【 <二> 丹方改良:Spring 时代的 JavaWeb】之 Spring Boot 中的缓存技术:使用 Redis 提升性能
  • NodeJs之http模块
  • 学成在线--day02
  • 深度学习篇---模型训练评估参数
  • Tabby二:使用笔记 - 保姆级教程
  • C#的CSV 在8859-1下中乱码和技巧
  • 猜猜我用的是哪个大模型?我的世界游戏界面简单的模拟效果
  • 网络华为HCIA+HCIP 策略路由,双点双向