当前位置: 首页 > news >正文

模拟注意力:少量参数放大 Attention 表征能力

论文标题

SAS: Simulated Attention Score

论文地址

https://arxiv.org/pdf/2507.07694

代码

见论文附录

作者背景

摩根士丹利,斯坦福大学,微软研究院,新加坡国立大学,得克萨斯大学奥斯汀分校,香港大学

动机

多头注意力是 Transformer 的核心组件,它通过引入多组 QKV 投影来捕获不同的特征子空间,从而在机器翻译、问答等任务中取得巨大成功。研究表明,注意力头的数量对 Transformer 性能至关重要:在保证每个头的隐藏维度充分大的前提下,注意力头数越多可以使模型效果越好。但问题在于,直接增加头数或维度往往伴随着模型参数量和计算开销的剧增,这在训练和部署中代价高昂

在这里插入图片描述

目前也有一些注意力架构旨在提高计算效率,例如共享部分 K 和 V 的 MQA、GQA;使用矩阵分解的 MLA、MFA、TPA 等。但这些方法主要关注降低内存/计算成本,而非提升注意力的表达能力

于是作者希望在不显著增加参数的前提下,设计一种新的注意力架构,实现近似于使用了更多注意力头和更高每头维度的性能提升

本文方法

本文提出 SAS(Simulated Attention Score,模拟注意力分数),核心思想是在注意力计算中引入额外的映射层,将低维的头表示投射到更高维空间,以此“虚拟地”增大注意力头数和每头的隐藏维度

一、扩展注意力头

对于查询Q,其特征维度为 [B, T, H, D],分别表示 batch_size,序列长度,头数和隐藏维度。为了扩充 H,需要把其他维度拉平,得到张量 Q_0,维度为 [B * T * D, H] ;然后使用一个 H * H’ 的线性变换得到 Q_1,维度为 [B * T * D, H’],其中 H’ > H;Q_1 过一个 ReLU 引入非线性;最后再过一个 H’ * H’ 的线性层,并加上 Q_1 的残差连接

在这里插入图片描述

于是我们获得了更多的注意力头,其中残差连接的引入可以稳定训练;值得注意的是,原始头数 H 和扩展后的头数 H’ 都远小于每头的特征维度 D,所以这个两层 MLP 的参数开销相对整模型来说可以忽略不计

除了使用 MLP 来扩展维度,作者还尝试了卷积方案。具体地,将查询 Q 的维度整理成 [B * T, H, D],类似于多通道特征图,然后使用卷积变换将 H 扩展成 H’,同样地,H’ > H,最后再过第二层卷积以及残差连接

在这里插入图片描述

类似地,在 K、V 中都应用上述扩展流程

二、扩展注意力维度

直觉上,每个注意力头内部特征维度 D 越大,其能够捕获的子空间信息越丰富。因此作者进一步在 Q 和 K 上也引入了类似的维度扩展映射。这里之所以不对 V 进行扩展,是因为 V

直接决定了注意力模块的输出张量隐藏维度,扩大 V 的每头维度到 D 会导致后续前馈层的参数量大幅增加,违背了不显著增加计算量的初衷

在这里插入图片描述

三、注意力聚合

在标准多头注意力中,会将所有头的输出向量拼接,再通过一个输出投影矩阵 O 映射回模型的隐藏维度。然而,由于 SAS 对注意力头数进行了扩增,若仍按传统方式拼接势必导致输出维度变大,进而导致 O 的参数量大大增加(H * hidden 变为 H’ * hidden)。为此,作者提出了参数高效注意力聚合机制,旨在不增加输出层参数规模的情况下完成对多头输出的整合

实现过程非常简单:假设注意力头数扩展了 r 倍,即 r * H = H’,那么便把所有头划分成 r 组,每组都按照原本的计算流程与 O 相乘,得到 r 组输出结果,最后取平均作为注意力模块的最终输出传向前馈层

在这里插入图片描述

实验结果

作者在多种基准任务和数据集上对SAS进行了验证,包括语言模型预训练及下游任务评估,全面展示了SAS在准确率和效率方面的优势

一、预训练效果

下图对比了SAS与标准MHA、MQA、GQA、MLA、TPA等方法在ArXiv和Books3数据集上的表现。结果表明,无论是短序列训练(长度512)还是长序列训练(长度1024),SAS均取得了最低的验证困惑度

在这里插入图片描述

除了取得更好的性能,SAS还加速了模型的收敛。作者报告,在 Books3 数据集、序列长度512的训练中,MHA模型在5万步时达到29.86的验证困惑度,而SAS模型在3万步时就达到了相近的30.49,即 SAS 可以节约 40% 左右的计算资源

此外,作者还在更大的训练长度、更大的模型尺寸上做了验证,结果表明相比于其他注意力机制 SAS 具备稳定的优势

二、下游任务效果

作者评测了在多个下游任务基准(ARC、HellaSwag、PIQA、ScIQ、SocialIQA、WinoGrande)上 SAS 与其他注意力模型的效果,可见在多种参数量、训练数据量的实验设置下,SAS 大部分情况下都表现出了最优性能

在这里插入图片描述

http://www.dtcms.com/a/276249.html

相关文章:

  • hiredis: 一个轻量级、高性能的 C 语言 Redis 客户端库
  • 深入解析C#接口实现的两种核心技术:派生继承 vs 显式实现
  • Java 21 虚拟线程
  • 浏览器宏任务的最小延时:揭开setTimeout 4ms的神话
  • java中的main方法
  • window7,windows10,windows11种系统之间实现打印机共享
  • 创客匠人:从定位逻辑看创始人 IP 如何驱动 IP 变现
  • CompareFace使用
  • Kimi K2万亿参数开源模型原理介绍
  • 【读书笔记】《C++ Software Design》第二章:The Art of Building Abstractions
  • Ruby如何采集直播数据源地址
  • OpenEuler操作系统中检测插入的USB设备并自动挂载
  • 【数据结构】反射、枚举 和 lambda表达式
  • Golang 面向对象(封装、继承、多态)
  • 【C语言】指针进阶:指针和数组
  • 手把手教你用YOLOv10打造智能垃圾检测系统
  • 第七章应用题
  • Geant4 安装---Ubuntu
  • 一篇博客学习Lua_安装使用+语法详解
  • Lua ADB 接口文档
  • RMSNorm实现
  • 2.单例模式
  • Vim的magic模式
  • blender uv小技巧
  • Python 包管理新时代:深入了解 `uv` 的使用与实践
  • OpenVela之模拟器调试
  • 【kubernetes】--Controller(StatefulSet)
  • 【PTA数据结构 | C语言版】链式队列的3个操作
  • Git常用命令一览
  • pyqt5界面开发学习