当前位置: 首页 > news >正文

注意力机制从理论到实践:注意力提示、汇聚与评分函数

目录

注意力机制:是什么?有啥用?

一、注意力机制是什么?

二、注意力机制有啥用?

三、注意力机制的核心逻辑(“关注什么→怎么关联→如何融合”)

1. 明确 “关注目标”(Query 的定义)

2. 计算 “关联强度”(注意力权重的量化)

3. 归一化 “重要性”(注意力权重的归一化)

4. 聚焦融合关键信息(注意力汇聚)

详细介绍三个技术点

一、注意力提示(Attention Cues)

 理论

 实践

1.完整代码

2.实践结果

二、注意力汇聚(Attention Aggregation)

理论

1、核心步骤:

2、示例:

进一步解释

3、分类详解

一、平均汇聚(Average Pooling)

二、非参数注意力汇聚(Non-parametric Attention Pooling)

三、带参数注意力汇聚(Parametric Attention Pooling)

实践

1、完整代码

2、实践结果

1、平均汇聚

 2、非参数注意力汇聚

3、带参数注意力汇聚

三、 注意力评分函数(Attention Scoring Function)

理论

1、缩放点积(Scaled Dot-Product)

2、加性评分(Additive Score)

3、核心对比与总结

实践

1、完整代码

2、实验结果


注意力机制:是什么?有啥用?

一、注意力机制是什么?

注意力机制(Attention Mechanism)是一种模拟人类注意力分配方式的计算框架,核心思想是:在处理信息时,有选择地聚焦于关键信息,忽略次要信息,从而高效利用计算资源并提升任务性能。

它的本质是一种 “权重分配机制”:对于输入的一系列信息(如文本中的词语、图像中的像素、序列中的元素),通过计算每个信息的 “重要性权重”,让模型在输出时更关注权重高的关键信息,权重低的信息则被弱化。

二、注意力机制有啥用?

注意力机制的核心价值在于解决 “信息过载” 和 “长距离依赖” 问题,具体作用体现在以下几个方面:

  1. 聚焦关键信息,提升效率
    现实场景中,输入数据往往包含大量冗余信息(例如,一句话中可能只有几个关键词决定语义,一张图片中只有部分区域是核心目标)。注意力机制通过分配权重,让模型优先处理关键信息,减少对无关信息的计算,提升处理效率。

    示例:在机器翻译 “我爱自然语言处理” 时,模型会更关注 “爱”“自然语言处理” 等核心词,而非虚词 “我”。

  2. 捕捉长距离依赖关系
    在序列数据(如文本、时间序列)中,远距离元素可能存在重要关联(例如,句子中 “他” 可能指代前文的 “小明”)。传统模型(如 RNN)处理长序列时容易丢失远距离信息,而注意力机制通过直接计算任意两个元素的关联权重,能精准捕捉这种长距离依赖。

    示例:在文本理解中,注意力机制可让模型明确 “他” 与 “小明” 的指代关系,即使两者相隔多个句子。

  3. 增强模型可解释性
    注意力机制的权重分布可直观展示模型关注的信息,帮助人类理解模型决策逻辑。例如:

    • 机器翻译中,通过注意力权重可看到输出词对应输入的哪个词;
    • 图像分类中,权重热力图可显示模型关注的图像区域(如识别 “猫” 时,权重集中在猫的头部和身体)。
  4. 适应多模态任务
    在图文匹配、视频字幕等多模态任务中,注意力机制能动态关联不同模态的信息(如文本中的 “红色” 与图像中的红色区域),实现跨模态的信息融合。

  5. 成为 Transformer 等模型的核心组件
    注意力机制是 Transformer 架构的核心(如 BERT、GPT 等大语言模型均基于 Transformer),其并行计算能力克服了 RNN 的序列依赖限制,使得模型能高效处理超长序列,推动了自然语言处理等领域的突破性进展。

三、注意力机制的核心逻辑(“关注什么→怎么关联→如何融合”)

1. 明确 “关注目标”(Query 的定义)

核心技术:自顶向下 / 自底向上的注意力提示(Cues)

  • 自顶向下提示:由任务目标定义 “关注方向”(如翻译时,解码器当前状态作为 Query,代表 “需要翻译的词”)。
  • 自底向上提示:由输入数据的显著性特征驱动(如文本中高频词、图像中高对比度区域)。
  • 作用:确定 “Query”(当前任务的关注点),回答 “模型应该关注什么”。
2. 计算 “关联强度”(注意力权重的量化)

核心技术:注意力评分函数(Scoring Function)

  • 功能:计算 Query 与输入中每个信息单元(Key)的相似度,量化 “关联强度”。
  • 常见实现:
    • 缩放点积(Scaled Dot-Product):score(q,k) = \frac{q \cdot k}{\sqrt{d_k}}(Transformer 核心,高效并行)。
    • 加性评分(Additive Score):score(q,k) = v^T \tanh(W_q q + W_k k)(Bahdanau 注意力采用,适合维度不匹配场景)。
    • 余弦相似度:衡量向量方向的一致性(适合语义匹配)。
  • 作用:将 “关联强度” 转化为原始分数,为后续权重归一化做准备。
3. 归一化 “重要性”(注意力权重的归一化)

核心技术:Softmax 归一化

  • 功能:将原始评分转化为总和为 1 的注意力权重(\(\alpha_i\)),确保权重可直接用于加权融合。
  • 公式:\alpha_i = \text{Softmax}(score(q,k_i)) = \frac{\exp(score(q,k_i))}{\sum_j \exp(score(q,k_j))}
  • 作用:让权重满足 “重要性占比” 的物理意义(权重越高,对应信息越重要)。
4. 聚焦融合关键信息(注意力汇聚)

核心技术:加权求和(Weighted Aggregation)

  • 功能:将输入信息(Value)按注意力权重加权求和,得到 “聚焦于 Query 的融合结果”。
  • 公式:\text{Attention}(q, K, V) = \sum_i \alpha_i \cdot v_iv_i为与k_i对应的信息内容)。
  • 作用:输出由关键信息主导的结果,实现 “关注重要信息、忽略次要信息”。

详细介绍三个技术点

一、注意力提示(Attention Cues)

 理论

注意力提示是引导模型 “关注什么” 的信号,本质是驱动注意力分配的线索。人类注意力的分配受两种提示影响,机器学习中也借鉴了这一机制:

  • 非自主性提示(自下而上的线索) 由输入数据本身的显著特征触发,无需外部目标引导。例如:

    • 图像中颜色鲜艳的物体(如红色苹果在绿色树叶中);
    • 文本中高频出现的关键词(如新闻中的 “地震”)。 模型通过感知输入的 “显著性” 自动聚焦于这些特征。
  • 自主性提示(自上而下的线索)任务目标或外部指令驱动,引导模型主动关注特定信息。例如:

    • 机器翻译中,翻译 “猫坐在垫子上” 时,“猫” 作为主语需重点关注;
    • 目标检测中,任务要求 “找到图像中的汽车”,模型会主动搜索汽车的特征(轮子、车窗等)。

作用注意力提示决定了模型 “应该关注输入的哪些部分”,是注意力机制的 “导向标”

 实践

1.完整代码
"""
文件名: 10.1 注意力机制
作者: 墨尘
日期: 2025/7/16
项目名: dl_env
备注: 注意力权重可视化工具 - 使用热图展示注意力分布
"""
# 导入必要的库
import matplotlib.pyplot as plt  # 用于绘图
import torch                     # PyTorch深度学习框架
from d2l import torch as d2l     # 动手学深度学习(Dive into Deep Learning)的辅助库#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),cmap='Reds'):"""显示矩阵热图 - 用于可视化注意力权重参数:matrices: 待可视化的矩阵张量,形状为(行数, 列数, 高度, 宽度)xlabel: x轴标签ylabel: y轴标签titles: 每个子图的标题列表(可选)figsize: 每个子图的大小,元组(宽, 高)cmap: 颜色映射,默认使用红色系('Reds')"""d2l.use_svg_display()  # 使用SVG格式显示图形,提高清晰度# 获取矩阵的行数和列数,用于创建子图网格num_rows, num_cols = matrices.shape[0], matrices.shape[1]# 创建子图网格,设置共享坐标轴以保证对齐fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,sharex=True, sharey=True, squeeze=False)# 遍历每个矩阵并绘制热图for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):# 使用imshow绘制热图,将PyTorch张量转为NumPy数组并分离计算图pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)# 只在最后一行设置x轴标签,避免重复if i == num_rows - 1:ax.set_xlabel(xlabel)# 只在第一列设置y轴标签,避免重复if j == 0:ax.set_ylabel(ylabel)# 如果提供了标题,则设置子图标题if titles:ax.set_title(titles[j])# 添加统一的颜色条,缩放到0.6倍以适应布局fig.colorbar(pcm, ax=axes, shrink=0.6);if __name__ == '__main__':# 创建一个10×10的单位矩阵作为注意力权重示例# 单位矩阵表示每个查询只关注对应的键(对角线元素为1,其余为0)attention_weights = torch.eye(10).reshape((1, 1, 10, 10))# 调用函数可视化注意力权重,x轴为"Keys",y轴为"Queries"show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')# 显示图形(在某些环境中可能不需要此行)plt.show()
2.实践结果

二、注意力汇聚(Attention Aggregation)

理论

注意力汇聚是根据注意力权重对输入信息进行聚合的过程。其核心逻辑是:对输入中 “更重要” 的部分赋予更高权重,再通过加权求和整合信息

1、核心步骤:
  • 确定 “查询(Query)”“键(Key)”“值(Value)”

    • query:当前任务的 “关注点”(如翻译时的解码器当前状态);
    • key:输入中可供匹配的 “特征标识”(如编码器的所有隐藏状态);
    • value:与 key 对应的 “具体信息内容”(如编码器隐藏状态对应的语义)。
  • 计算注意力权重:通过 “注意力评分函数”计算 query 与每个 key 的相似度,得到权重(权重之和为 1,满足归一化)。

  • 加权聚合:将 value 按权重求和,得到 “聚焦于 query 的聚合结果”(即注意力输出)。

2、示例:

假设输入序列为 “[猫,在,垫子,上]”,当前 query 是 “猫”,key 是每个词的特征,value 是每个词的语义向量。若 “猫” 与 “垫子” 的相似度高(权重 0.6),与 “在” 的相似度低(权重 0.1),则聚合结果更偏向 “猫” 和 “垫子” 的语义。

进一步解释

 1. Key(每个词的特征):用于计算相似度的 “标识性特征”

“特征” 在这里指的是每个词的低层次、结构化的表示,用于快速判断词与词之间的关联(即作为计算注意力权重的 “依据”)。

  • 具体含义:特征是词的 “标识属性”,可能包含词性、词形、基础语法功能等结构化信息。例如:

    • “猫” 的特征:[名词,动物,单数,主语]
    • “垫子” 的特征:[名词,物品,单数,宾语]
    • “在” 的特征:[介词,表位置,连接作用]
    • “上” 的特征:[方位词,表位置,补充作用]
  • 作用:特征(Key)的核心是 “可比较性”—— 通过这些结构化特征,模型能快速计算 “猫” 与其他词的相似度(例如 “猫” 和 “垫子” 都是名词,且在句中构成 “动作对象” 关系,因此特征相似度高,对应注意力权重 0.6;“猫” 和 “在” 词性差异大,特征相似度低,权重 0.1)。

2. Value(每个词的语义向量):用于聚合的 “深层语义表示”

“语义向量” 是每个词的高层次、抽象的语义表示,包含词的含义、上下文关联等深层信息,是最终参与 “注意力汇聚” 的核心内容。

  • 具体含义:语义向量是通过嵌入层(如 Word2Vec、BERT 嵌入)或编码器生成的稠密向量,编码了词的内涵和语境意义。例如:

    • “猫” 的语义向量:包含 “猫 = 小型哺乳动物 + 宠物 + 会跑跳” 等语义信息。
    • “垫子” 的语义向量:包含 “垫子 = 柔软物品 + 用于坐卧 + 放置于地面” 等语义信息。
    • “在” 和 “上” 的语义向量:包含 “表示位置关系” 的语义,但具体内涵较弱(因此权重低)。
  • 作用:语义向量(Value)是注意力机制最终 “聚焦” 的内容。当计算出 “猫” 对 “垫子” 的权重为 0.6、对自身权重假设为 0.3(剩余 0.1 给 “在” 和 “上”)时,聚合结果为: 聚合向量 = 0.3×猫的语义向量 + 0.6×垫子的语义向量 + 0.1×其他词的语义向量 这个向量会更偏向 “猫” 和 “垫子” 的核心语义,从而准确表达 “猫与垫子的关联”(如 “猫在垫子上” 的核心语义)。

总结

  • Key(特征):是词的 “标识性属性”,用于计算注意力权重(解决 “谁和谁相似”);
  • Value(语义向量):是词的 “深层语义内容”,用于最终的加权聚合(解决 “相似的词贡献什么语义”)。

公式抽象:设注意力权重为\alpha_i(对应第i个 value),则聚合结果为:

注意力汇聚的核心是通过权重分配实现 “重要信息强化,次要信息弱化”

3、分类详解
一、平均汇聚(Average Pooling)

平均汇聚是最基础的 “汇聚”(Pooling)操作,核心是对输入的多个元素取算术平均值,将 “多元素” 压缩为 “单元素”,实现信息的聚合

核心思想

忽略输入元素的 “重要性差异”,认为所有元素同等重要,通过平均操作保留整体趋势,过滤局部噪声。

公式与示例

设输入序列为x_1, x_2, ..., x_n(每个 x_i可以是标量或向量),平均汇聚的结果为:y = \frac{1}{n} \sum_{i=1}^n x_i

  • 示例:对序列 “[猫,在,垫子,上]” 的词向量(假设每个词向量为 v_1, v_2, v_3, v_4)进行平均汇聚,结果为:y = \frac{v_1 + v_2 + v_3 + v_4}{4},即 4 个词向量的平均。

特点与应用

  • 优点:简单、无参数(无需训练)、计算高效,适合快速聚合信息。
  • 缺点:平等对待所有元素,会丢失重要元素的权重(例如 “猫” 可能比 “在” 更重要,但平均后被稀释)。
  • 应用:早期神经网络的特征聚合(如 CNN 的池化层、简单序列模型的句子表示)。
二、非参数注意力汇聚(Non-parametric Attention Pooling)

非参数注意力汇聚是注意力机制的基础形式,核心是通过计算 “查询(query)与键(key)的相似度” 动态分配权重,再加权聚合值(value),且无需要学习的参数

核心思想

  • 重要性不是固定的(区别于平均汇聚),而是由 query 与每个 key 的 “相关性” 决定:相关性越高,对应 value 的权重越大。
  • 无参数:相似度计算方式固定(如高斯核、余弦相似度),不涉及可学习参数。

公式与步骤

设 query 为 q,键值对为 (k_1, v_1), (k_2, v_2), ..., (k_n, v_n),非参数注意力汇聚的步骤为:

  1. 计算相似度:用固定函数(如高斯核)计算 q 与每个k_i的相似度 s(q, k_i)。 常用高斯核(Gaussian kernel):s(q, k_i) = \exp\left(-\frac{1}{2\sigma^2} \| q - k_i \|^2 \right)\sigma是超参数,手动设定而非学习)

  2. 权重归一化:用 softmax 将相似度转换为和为 1 的权重 \alpha_i = \frac{\exp(s(q, k_i))}{\sum_{j=1}^n \exp(s(q, k_j))}

  3. 加权聚合:用权重加权求和 value,得到汇聚结果:y = \sum_{i=1}^n \alpha_i v_i

示例

以序列 “[猫,在,垫子,上]” 为例,设:

  • query q 是 “猫” 的特征向量;
  • k_1, k_2, k_3, k_4分别是 “猫”“在”“垫子”“上” 的特征向量;
  • v_1, v_2, v_3, v_4分别是这四个词的语义向量。

若计算得相似度:s(q,k_1)=1.2, s(q,k_2)=0.3, s(q,k_3)=1.0, s(q,k_4)=0.5,则:

  • 权重\alpha_1 = \exp(1.2)/(\exp(1.2)+\exp(0.3)+\exp(1.0)+\exp(0.5)) \approx 0.4
  • 同理得 \alpha_2 \approx 0.1\alpha_3 \approx 0.3\alpha_4 \approx 0.2
  • 汇聚结果 y = 0.4v_1 + 0.1v_2 + 0.3v_3 + 0.2v_4,更偏向 “猫” 和 “垫子” 的语义。

特点与应用

  • 优点:动态分配权重,比平均汇聚更灵活,能捕捉 query 与 key 的相关性。
  • 缺点:相似度函数固定(如高斯核的 \(\sigma\) 需手动调),无法适应任务自动优化。
  • 应用:早期注意力模型(如 Bahdanau 注意力的简化版)、无监督学习中的特征聚合。
三、带参数注意力汇聚(Parametric Attention Pooling)

带参数注意力汇聚在非参数版本的基础上引入可学习参数,让模型通过训练自动优化 “相似度计算方式”,更适应任务需求

核心思想

  • 相似度计算不再固定,而是由可学习参数(如权重矩阵、偏置)定义,模型通过训练学习 “如何衡量相关性”。
  • 权重 \alpha_i 不仅依赖 query 和 key,还依赖参数,能更精准地捕捉任务相关的重要性。

典型形式与公式

常用的带参数注意力包括加性注意力(Additive Attention) 和缩放点积注意力(Scaled Dot-Product Attention),以加性注意力为例:

  1. 相似度计算:用小型神经网络(含参数)计算相似度:s(q, k_i) = \mathbf{w}_v^T \tanh\left( \mathbf{W}_q q + \mathbf{W}_k k_i + b \right) (\mathbf{W}_q, \mathbf{W}_k是权重矩阵,\mathbf{w}_v是权重向量,b 是偏置,均为可学习参数)

  2. 权重归一化:同非参数版本,用 softmax 得\alpha_i\alpha_i = \frac{\exp(s(q, k_i))}{\sum_{j=1}^n \exp(s(q, k_j))}

  3. 加权聚合:同前,y = \sum_{i=1}^n \alpha_i v_i

与非参数的关键区别

  • 非参数:相似度函数固定(如高斯核),无参数,灵活性低。
  • 带参数:相似度函数由参数定义,可学习,能适应不同任务(如机器翻译中 “猫” 与 “垫子” 的相关性可能比 “在” 更高,参数会学习强化这一点)。

实践

1、完整代码
"""
文件名: 10.2
作者: 墨尘
日期: 2025/7/16
项目名: dl_env
备注: 注意力汇聚的核回归示例 - 对比平均汇聚、非参数注意力汇聚、带参数注意力汇聚的效果
"""
# 导入必要库
import torch  # PyTorch深度学习框架,用于张量计算和模型构建
from torch import nn  # PyTorch的神经网络模块,含激活函数、损失函数等
from d2l import torch as d2l  # 动手学深度学习的辅助库,含绘图和数据处理工具
import matplotlib.pyplot as plt  # 用于图形显示def plot_kernel_reg(y_hat):"""可视化核回归的预测结果与真实值对比参数:y_hat: 模型的预测值张量,形状为(测试样本数,)"""# 绘制真实值(Truth)和预测值(Pred)曲线d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],xlim=[0, 5], ylim=[-1, 5])  # x轴范围0-5,y轴范围-1-5# 绘制训练样本点(散点图),alpha=0.5设置透明度d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);class NWKernelRegression(nn.Module):"""带参数的注意力汇聚模型(核回归) - Nadaraya-Watson核回归的参数化版本引入可学习参数w,用于调整注意力权重的缩放程度"""def __init__(self, **kwargs):super().__init__(** kwargs)  # 继承nn.Module的初始化方法# 定义可学习参数w,初始值随机,需要计算梯度(requires_grad=True)self.w = nn.Parameter(torch.rand((1,), requires_grad=True))def forward(self, queries, keys, values):"""前向传播:计算带参数的注意力汇聚结果参数:queries: 查询张量,形状为(查询样本数,)keys: 键张量,形状为(查询样本数, 键值对数量)values: 值张量,形状为(查询样本数, 键值对数量)返回:注意力汇聚的预测结果,形状为(查询样本数,)"""# 将查询重复为与键同形状:(查询数, 键值对数量),每一行都是相同的查询queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))# 计算注意力权重:基于查询与键的距离,用softmax归一化# 公式:权重 = softmax( -((查询-键)*w)^2 / 2 ),w为可学习参数,控制距离的缩放self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w) **2 / 2, dim=1)  # dim=1表示按行归一化# 批量矩阵乘法:注意力权重(查询数,1,键值对数) × 值(查询数,键值对数,1) → 预测结果return torch.bmm(self.attention_weights.unsqueeze(1),  # 增加维度以满足bmm要求values.unsqueeze(-1)).reshape(-1)  # 压缩维度为(查询数,)if __name__ == '__main__':# -------------------------- 1. 生成训练数据和测试数据 --------------------------n_train = 50  # 训练样本数量# 生成训练输入x_train:0-5之间的随机数,排序后便于后续可视化x_train, _ = torch.sort(torch.rand(n_train) * 5)# 定义真实函数f(x) = 2*sin(x) + x^0.8(非线性函数,适合核回归测试)def f(x):return 2 * torch.sin(x) + x** 0.8# 生成训练输出y_train:真实值+噪声(模拟真实数据中的观测误差)y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 噪声均值0,标准差0.5# 生成测试输入x_test:0-5之间间隔0.1的均匀采样(共50个点)x_test = torch.arange(0, 5, 0.1)y_truth = f(x_test)  # 测试数据的真实输出(无噪声)n_test = len(x_test)  # 测试样本数量(50)print(f"测试样本数量: {n_test}")  # 输出:50# -------------------------- 2. 平均汇聚(基准模型) --------------------------# 平均汇聚:所有测试样本的预测值都等于训练数据的平均值(忽略输入特征,最简单的汇聚方式)y_hat = torch.repeat_interleave(y_train.mean(), n_test)  # 重复平均值n_test次# 可视化平均汇聚的预测结果与真实值对比plot_kernel_reg(y_hat)plt.show()  # 显示图形:平均汇聚效果差,无法捕捉非线性趋势# -------------------------- 3. 非参数注意力汇聚 --------------------------# 构造查询矩阵X_repeat:形状(n_test, n_train),每行都是相同的测试输入(查询)X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))# 计算注意力权重:基于查询与键(训练输入x_train)的欧式距离,用高斯核度量相似度# 公式:权重 = softmax( -(查询-键)^2 / 2 ),无参数,完全由数据驱动attention_weights = nn.functional.softmax(-(X_repeat - x_train) **2 / 2, dim=1)# 非参数注意力汇聚:预测值 = 注意力权重 × 训练输出(矩阵乘法实现加权求和)y_hat = torch.matmul(attention_weights, y_train)# 可视化非参数注意力汇聚的结果plot_kernel_reg(y_hat)plt.show()  # 显示图形:预测更贴近真实值,能捕捉局部趋势# 可视化非参数注意力权重的热图# 热图形状:(1,1,n_test,n_train),行=测试样本,列=训练样本,颜色越深权重越大d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs(键)',ylabel='Sorted testing inputs(查询)')plt.show()  # 显示热图:测试样本与接近的训练样本权重更高(局部注意力)# -------------------------- 4. 带参数注意力汇聚(模型训练与测试) --------------------------# 演示批量矩阵乘法(bmm)的维度要求:用于批量处理多个样本的矩阵乘法X = torch.ones((2, 1, 4))  # 批量=2,每个矩阵1×4Y = torch.ones((2, 4, 6))  # 批量=2,每个矩阵4×6print(torch.bmm(X, Y).shape)  # 输出:torch.Size([2, 1, 6]),符合矩阵乘法规则# 演示注意力权重与值的批量乘法:权重(2,1,10) × 值(2,10,1) → 结果(2,1,1)weights = torch.ones((2, 10)) * 0.1  # 均匀权重,每行和为1values = torch.arange(20.0).reshape((2, 10))  # 示例值:0-19print(torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1)))  # 输出每行平均值# 构造训练用的键值对(排除自身样本,避免过拟合)# X_tile:形状(n_train, n_train),每行都是相同的训练输入(作为查询)X_tile = x_train.repeat((n_train, 1))# Y_tile:形状(n_train, n_train),每行都是相同的训练输出(作为值)Y_tile = y_train.repeat((n_train, 1))# 键:排除对角线元素(自身样本),形状(n_train, n_train-1)keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))# 值:对应排除自身样本的训练输出,形状(n_train, n_train-1)values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))# 初始化带参数的注意力模型、损失函数和优化器net = NWKernelRegression()  # 带可学习参数w的模型loss = nn.MSELoss(reduction='none')  # 均方误差损失,保留每个样本的损失trainer = torch.optim.SGD(net.parameters(), lr=0.5)  # 随机梯度下降优化器,学习率0.5animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])  # 可视化训练损失# 训练模型(5个epoch)for epoch in range(5):trainer.zero_grad()  # 清除梯度# 前向传播:用训练数据的查询、键、值计算预测l = loss(net(x_train, keys, values), y_train)l.sum().backward()  # 损失求和并反向传播trainer.step()  # 更新参数wprint(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')  # 输出当前损失animator.add(epoch + 1, float(l.sum()))  # 记录损失用于绘图plt.show()  # 显示训练损失曲线:损失逐渐下降,参数w收敛# 用训练好的模型预测测试数据# 构造测试用的键和值:键=训练输入x_train(形状(n_test, n_train)),值=训练输出y_trainkeys = x_train.repeat((n_test, 1))values = y_train.repeat((n_test, 1))# 带参数注意力汇聚的预测结果(detach()脱离计算图,仅用于可视化)y_hat = net(x_test, keys, values).unsqueeze(1).detach()# 可视化带参数注意力汇聚的预测结果plot_kernel_reg(y_hat)plt.show()  # 显示图形:预测精度进一步提升,更贴近真实值# 可视化带参数模型的注意力权重热图d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs(键)',ylabel='Sorted testing inputs(查询)')plt.show()  # 显示热图:权重更集中,参数w优化了距离的缩放,增强了局部关联性
2、实践结果
1、平均汇聚

 2、非参数注意力汇聚

注意力权重

3、带参数注意力汇聚

注意力权重

三、 注意力评分函数(Attention Scoring Function)

理论

注意力评分函数是计算 query 与 key 相似度的工具,其输出经归一化(如 Softmax)后即为注意力权重\alpha_i。常见的评分函数有两类:

1、缩放点积(Scaled Dot-Product)
  • 原理:直接计算 query 与 key 的内积,再除以向量维度的平方根(避免维度过高导致内积值过大,影响 Softmax 稳定性)。
  • 注意:缩放点积注意力是 Transformer 模型采用的核心注意力机制,通过向量内积计算相似度,并引入缩放因子解决维度灾难问题
  • 公式\text{score}(q, k) = \frac{q \cdot k}{\sqrt{d}} 其中d是 query/key 的维度,q \cdot k是向量内积(衡量相似度)。
  • 优点:计算效率极高(仅需矩阵乘法和 softmax,适合 GPU 并行);无额外参数(仅依赖内积)。
  • 缺点:当查询和键的维度 d 差异较大时,内积可能无法有效衡量相似度。
  • 适用场景:高维向量、大规模批量数据(如 Transformer、BERT、GPT 等大语言模型)。
2、加性评分(Additive Score)
  • 原理:通过线性变换将 query 和 key 映射到同一空间,再用 tanh 激活后与权重向量相乘,输出标量评分。
  • 注意:加性注意力通过多层感知机(MLP) 计算查询与键的相似度,适用于查询和键维度不同的场景,由 Bahdanau 在机器翻译模型中首次提出
  • 公式\text{score}(q, k) = v \cdot \tanh(W_q q + W_k k)其中W_q, W_k是可学习的权重矩阵,v是可学习的向量。
  • 特点:适用于 query 与 key 维度不同的场景,计算复杂度高于点积,但灵活性更强。

评分函数的选择需根据任务场景(如维度是否一致、效率要求)决定,其核心是量化 query 与 key 的关联程度

3、核心对比与总结
维度缩放点积注意力加性注意力
评分函数内积 + 缩放(\frac{\mathbf{q} \cdot \mathbf{k}}{\sqrt{d}}MLP(\mathbf{w}_v^T \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k})
参数数量0(无额外参数)h(d_q + d_k + 1)(依赖隐藏层维度 h)
计算复杂度O(d)(高效,矩阵乘法并行)O(h)(h 通常为 d 或 2d,复杂度更高)
适用维度要求 \mathbf{q}, \mathbf{k} 维度相同支持 \mathbf{q}, \mathbf{k} 维度不同
典型应用Transformer、大语言模型(GPT/BERT)Bahdanau 机器翻译、跨模态注意力
核心优势效率极高,适合大规模数据灵活性高,适合复杂相似度建模

实践

1、完整代码
"""
文件名: 10.3
作者: 墨尘
日期: 2025/7/16
项目名: dl_env
备注: 实现加性注意力(Additive Attention)和缩放点积注意力(Scaled Dot-Product Attention),包含掩码softmax功能,用于处理序列长度不一致的场景
"""
import math  # 用于数学运算(如缩放点积中的平方根)
import torch  # PyTorch框架,用于张量计算和模型构建
from torch import nn  # 神经网络模块,含线性层、Dropout等组件
from d2l import torch as d2l  # 动手学深度学习辅助库,含绘图和序列处理工具
import matplotlib.pyplot as plt  # 用于图形显示#@save  # 标记为可保存的函数,便于后续调用
def masked_softmax(X, valid_lens):"""带掩码的softmax函数:对序列中超出有效长度的部分进行掩码(权重置为0)参数:X: 输入张量,形状通常为(batch_size, num_queries, num_keys),表示查询与键的相似度分数矩阵valid_lens: 有效长度张量,形状为(batch_size,)或(batch_size, num_queries)表示每个查询/样本的有效键值对数量(超出部分需掩码)返回:掩码后的softmax结果,形状与X一致,无效位置的权重被置为0"""if valid_lens is None:# 若没有有效长度限制,直接对最后一维做softmaxreturn nn.functional.softmax(X, dim=-1)else:shape = X.shape  # 获取输入形状,用于后续恢复维度# 处理有效长度维度:若为1D(每个样本一个有效长度),则重复为与查询数匹配的2Dif valid_lens.dim() == 1:# 例如:valid_lens = [2,6],重复后变为[2,2,...,6,6,...](每个查询对应相同的样本有效长度)valid_lens = torch.repeat_interleave(valid_lens, shape[1])else:# 若为2D(每个查询有独立有效长度),则展平为1D便于处理valid_lens = valid_lens.reshape(-1)# 1. 将X展平为(batch_size * num_queries, num_keys),便于统一处理掩码# 2. 使用d2l.sequence_mask对超出有效长度的位置填充-1e6(softmax后接近0)X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)# 恢复原形状并对最后一维做softmax,确保无效位置权重为0return nn.functional.softmax(X.reshape(shape), dim=-1)#@save
class AdditiveAttention(nn.Module):"""加性注意力模型(Bahdanau注意力):通过MLP计算查询与键的相似度"""def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):"""初始化加性注意力参数参数:key_size: 键的特征维度query_size: 查询的特征维度num_hiddens: 隐藏层维度(MLP的中间维度)dropout: Dropout概率,用于防止过拟合**kwargs: 其他可选参数"""super(AdditiveAttention, self).__init__(** kwargs)# 键的线性投影层:将键从key_size映射到num_hiddens,无偏置self.W_k = nn.Linear(key_size, num_hiddens, bias=False)# 查询的线性投影层:将查询从query_size映射到num_hiddens,无偏置self.W_q = nn.Linear(query_size, num_hiddens, bias=False)# 输出层:将隐藏层特征映射到标量(相似度分数),无偏置self.w_v = nn.Linear(num_hiddens, 1, bias=False)# Dropout层:训练时随机丢弃部分注意力权重,降低过拟合风险self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens):"""前向传播:计算加性注意力的输出参数:queries: 查询张量,形状为(batch_size, num_queries, query_size)keys: 键张量,形状为(batch_size, num_kv_pairs, key_size)(num_kv_pairs为键值对数量)values: 值张量,形状为(batch_size, num_kv_pairs, value_size)valid_lens: 有效长度张量,形状为(batch_size,)或(batch_size, num_queries)返回:注意力加权后的输出,形状为(batch_size, num_queries, value_size)"""# 1. 线性投影:将查询和键映射到隐藏空间queries, keys = self.W_q(queries), self.W_k(keys)  # 形状分别为:(batch_size, num_queries, num_hiddens)、(batch_size, num_kv_pairs, num_hiddens)# 2. 扩展维度并广播相加:模拟加性注意力的特征融合# queries.unsqueeze(2):在第2维增加维度 → (batch_size, num_queries, 1, num_hiddens)# keys.unsqueeze(1):在第1维增加维度 → (batch_size, 1, num_kv_pairs, num_hiddens)# 广播机制:相加后形状为(batch_size, num_queries, num_kv_pairs, num_hiddens)features = queries.unsqueeze(2) + keys.unsqueeze(1)features = torch.tanh(features)  # 激活函数,增强非线性表达能力# 3. 计算相似度分数:将隐藏特征映射为标量分数# self.w_v(features):形状为(batch_size, num_queries, num_kv_pairs, 1)# squeeze(-1):移除最后一维 → (batch_size, num_queries, num_kv_pairs)scores = self.w_v(features).squeeze(-1)# 4. 计算带掩码的注意力权重:超出有效长度的位置权重为0self.attention_weights = masked_softmax(scores, valid_lens)# 5. 加权聚合值:注意力权重 × 值,应用Dropout防止过拟合# 批量矩阵乘法:(batch_size, num_queries, num_kv_pairs) × (batch_size, num_kv_pairs, value_size)# → 输出形状:(batch_size, num_queries, value_size)return torch.bmm(self.dropout(self.attention_weights), values)#@save
class DotProductAttention(nn.Module):"""缩放点积注意力:通过向量内积计算相似度,引入缩放因子缓解维度问题"""def __init__(self, dropout, **kwargs):"""初始化缩放点积注意力参数:dropout: Dropout概率,用于防止过拟合**kwargs: 其他可选参数"""super(DotProductAttention, self).__init__(** kwargs)self.dropout = nn.Dropout(dropout)  # Dropout层def forward(self, queries, keys, values, valid_lens=None):"""前向传播:计算缩放点积注意力的输出参数:queries: 查询张量,形状为(batch_size, num_queries, d)(d为特征维度,需与键相同)keys: 键张量,形状为(batch_size, num_kv_pairs, d)values: 值张量,形状为(batch_size, num_kv_pairs, value_size)valid_lens: 有效长度张量,形状为(batch_size,)或(batch_size, num_queries)(可选)返回:注意力加权后的输出,形状为(batch_size, num_queries, value_size)"""d = queries.shape[-1]  # 获取特征维度d,用于缩放因子计算# 1. 计算相似度分数:查询 × 键的转置,除以√d(缩放因子)# keys.transpose(1,2):交换键的后两维 → (batch_size, d, num_kv_pairs)# 批量矩阵乘法:(batch_size, num_queries, d) × (batch_size, d, num_kv_pairs) → (batch_size, num_queries, num_kv_pairs)# 缩放:除以√d,避免高维时内积值过大导致softmax梯度消失scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)# 2. 计算带掩码的注意力权重self.attention_weights = masked_softmax(scores, valid_lens)# 3. 加权聚合值:应用Dropout后与值做批量矩阵乘法return torch.bmm(self.dropout(self.attention_weights), values)if __name__ == '__main__':# -------------------------- 1. 测试掩码softmax函数 --------------------------# 测试1:valid_lens为1D张量(每个样本一个有效长度)# 输入X形状:(2,2,4) → 2个样本,每个样本2个查询,4个键# valid_lens = [2,3] → 第1个样本有效键长度为2,第2个为3print("掩码softmax测试1(1D valid_lens):")print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))# 输出解读:每行(查询)中,超出有效长度的位置softmax结果接近0# 测试2:valid_lens为2D张量(每个查询独立有效长度)# valid_lens = [[1,3], [2,4]] → 第1个样本第1个查询有效长度1,第2个查询3;第2个样本类似print("\n掩码softmax测试2(2D valid_lens):")print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]])))# 输出解读:每个查询的有效长度独立,超出部分权重为0# -------------------------- 2. 测试加性注意力 --------------------------# 构造输入数据# queries:2个样本,每个样本1个查询,维度20 → (2,1,20)# keys:2个样本,每个样本10个键,维度2 → (2,10,2)queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))# values:2个样本,每个样本10个值,维度4(值=0~39,重复2次)→ (2,10,4)values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)valid_lens = torch.tensor([2, 6])  # 有效长度:第1个样本只关注前2个键,第2个关注前6个# 初始化加性注意力模型attention = AdditiveAttention(key_size=2,  # 键的维度query_size=20,  # 查询的维度num_hiddens=8,  # 隐藏层维度dropout=0.1  # Dropout概率)attention.eval()  # 切换到评估模式(不启用Dropout)# 前向传播计算注意力输出output = attention(queries, keys, values, valid_lens)print("\n加性注意力输出形状:", output.shape)  # 应为(2,1,4)# 可视化注意力权重热图:形状(1,1,2,10) → 2个样本,每个样本1个查询,10个键d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')# 热图解读:颜色越深表示权重越大,第1个样本仅前2个键有权重,第2个样本前6个键有权重plt.show()# -------------------------- 3. 测试缩放点积注意力 --------------------------# 构造查询:2个样本,每个样本1个查询,维度2(与键维度一致)→ (2,1,2)queries = torch.normal(0, 1, (2, 1, 2))# 初始化缩放点积注意力模型attention = DotProductAttention(dropout=0.5)attention.eval()  # 评估模式# 前向传播计算输出output = attention(queries, keys, values, valid_lens)print("\n缩放点积注意力输出形状:", output.shape)  # 应为(2,1,4)# 可视化注意力权重热图d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),xlabel='Keys', ylabel='Queries')# 热图解读:与加性注意力类似,有效长度内的键有权重,但具体分布因评分函数不同而有差异plt.show()
2、实验结果

http://www.dtcms.com/a/283266.html

相关文章:

  • HertzBeat 监控 SpringBoot 使用案例
  • elf、axf、bin的区别与转换
  • freetds 解决连接SQL SERVER报错Unexpected EOF from the server
  • 基于组学数据的药物敏感性预测模型构建与验证
  • AI时代基础入门
  • 卷积神经网络(CNN)最本质的技术
  • 离线环境中将现有的 WSL 1 升级到 WSL 2
  • list类的常用接口实现及迭代器
  • [BJDCTF2020]Cookie is so stable
  • Mybatis07-缓存
  • 正确选择光伏方案设计软件:人力成本优化的关键一步
  • 聊聊自己的新书吧
  • lustre设置用户配额
  • 同态加密赋能大模型医疗文本分析:可验证延迟压缩的融合之道
  • xss-labs靶场前八关
  • C语言基础:循环练习题
  • Linux切换到Jenkins用户解决Jenkins Host key verification failed
  • Electron实现“仅首次运行时创建SQLite数据库”
  • 大语言模型幻觉检测:语义熵揭秘
  • [Mysql] Connector / C++ 使用
  • AutoMQ 正式通过 SOC 2 Type II 认证
  • 尚庭公寓-----day1 业务功能实现
  • 八、DMSP/OLS、NPP/VIIRS等夜间灯光数据能源碳排放空间化——碳排放空间分级、空间自相关
  • PyCharm高效入门指南
  • docker+小皮面板
  • TCP 三次握手与四次挥手笔记
  • 光伏气象数据驱动设计方案优化
  • Spider的用法
  • 搭建云途YTM32B1MD1芯片VSCODE+GCC + Nijia + Cmake+Jlink开发环境
  • Python设计模式深度解析:单例模式(Singleton Pattern)完全指南