当前位置: 首页 > news >正文

论文阅读笔记——Multi-Token Attention

MTA 论文
在 Transformer 中计算注意力权重时,仅依赖单个 Q 和 K 的相似度,无法有效捕捉多标记组合信息。(对于 A、B 两个词,单标记注意力需要分别计算两个词的注意力分数,再通过后处理定位共同出现的位置或通过多层隐式堆叠,增加模型深度和容量)。MTA 显示建模多标记依赖,同时不牺牲全局交互和额外参数。(通过卷积运算让他能够看到邻近的Q、K 以及其他注意力头的信息)

在 Transformer 其他部分,如 FFN 的输入/输出加卷积,主要是为了捕捉词元表示之间的局部依赖关系,不直接改变注意力机制本身如何计算相关性
MTA 的卷积直接作用在 Q K T / A QK^T/A QKT/A,意味着卷积直接参与了决定哪些上下文位置应该被关注的过程,在处理词元间的关系强度
在这里插入图片描述
提出两种方式:pre-softmax convolution 和 post-softmax convolution,MTA 默认采用 Pre-softmax Q-K Convolution 和 Post-softmax Head Mixing Convolution。二者区别在于是在 softmax 之前还是之后进行。

Q-K convolution

a i j = S o f t m a x ( ∑ i ′ = 0 c q − 1 ∑ j ′ = − ⌊ c k / 2 ⌋ ⌈ c k / 2 ⌉ − 1 1 i ≥ j − j ′ θ i ′ , j ′ q i − i ′ k j − j ′ ⊤ / d ) ( 1 ) a_{ij}=\mathrm{Softmax}\left(\sum_{i^{\prime}=0}^{c_{q}-1}\sum_{j^{\prime}=-\lfloor c_{k}/2\rfloor}^{\lceil c_{k}/2\rceil-1}\mathbf{1}_{i\geq j- j^{\prime}}\theta_{i^{\prime},j^{\prime}}q_{i-i^{\prime}}k_{j-j^{\prime}}^{\top}/\sqrt{d}\right) \qquad \qquad(1) aij=Softmax i=0cq1j=ck/2ck/211ijjθi,jqiikjj/d (1)
在卷积中,为防止未来信息泄露,需要做 Masking。理想的 Masking 比较复杂(见式(1)),采用一种简化形式:用 0 Mask 掉未来的 Q K T QK^T QKT 值,做卷积,再用 − ∞ -\infty Mask 掉结果中非法位置,再做 Softmax。
A = S o f t m a x ( M a s k − ∞ ( C o n v 2 d θ ( M a s k 0 ( A ^ ) ) ) ) . A=\mathrm{Softmax}\left(\mathrm{Mask}_{-\infty}\left(\mathrm{Conv}2\mathrm{d}_\theta\left(\mathrm{Mask}_0(\hat{A})\right)\right)\right). A=Softmax(Mask(Conv2dθ(Mask0(A^)))).

Head Mixing Convolution

允许不同注意力头之间共享信息,放大重要信号。将 M 个头分成 M / c h M/c_h M/ch 个组,每组 c h c_h ch 个头。在每组的头内左 1D 卷积。同样可以在 softmax 之前或之后进行。

Group Normalization with depth scaling

改善梯度流,对抗深层网络中残差连接可能带来的主导效应(让模型更关注注意力部分输出,而不是仅仅传递上一层信息)。
在每个头的输出上独立应用组归一化,并结合一个随层数变化的缩放因子。

核心矛盾:在「增强注意力精度」和「保持计算效率」之间尚未找到完美平衡,当前更适合对计算资源不敏感的高精度场景。

实验结果

1.找字母块任务,验证 MTA 能够解决 [多条件匹配] 问题。
在这里插入图片描述
MTA 错误率接近 0%,而 Transformer 失败率超 50%

2.LLM,在 105B 词元数据上训练 880M 参数模型

  • MTA 仅在 1/4 的层 使用 Key-Query 卷积(核大小: c q = 6 , c k = 11 c_q=6,c_k=11 cq=6,ck=11)。
  • 所有层使用 Head 卷积(核大小 c h = 2 c_h=2 ch=2)。
    在这里插入图片描述

相关文章:

  • 华为机试—最大最小路
  • 为什么在删除数据库存在‘if exists‘语句
  • 判断两个 IP 地址是否在同一子网 C
  • Redis实现分布式定时任务
  • 畅游Diffusion数字人(23):字节最新表情+动作模仿视频生成DreamActor-M1
  • dfs和bfs算法
  • PyTorch DataLoader 参数详解
  • Autoware源码总结
  • 路由策略/策略路由之Filter-Policy
  • 基础层数据从kafka读取写入hbase的优化方案
  • 摄像头解析
  • 第一期:[特殊字符] 深入理解MyBatis[特殊字符]从JDBC到MyBatis——持久层开发的转折点[特殊字符]
  • Vue 3 和 Vue 2 的区别及优点
  • Flask+Plotly结合动态加载图形页面实践
  • cluster、update、delete在死元组清理上的作用
  • boss zp_stoken补环境
  • TQ15EG开发板教程:AD9361观测adc采集波形
  • Elasticsearch 系列专题 - 第七篇:实战项目
  • Ubuntu 22.04 完美安装 ABAQUS 教程:从零到上手,解决兼容问题
  • 数据结构(1)
  • 威海哪家网站做的好/销售网络平台
  • 个人设计师为什么做网站/总裁班课程培训
  • 建设网站专业/指数平台
  • 机关网站建设创新/seo优化排名
  • 网站开发需要做什么工作/seo优化易下拉霸屏
  • 网站设计与开发怎么做/花钱推广的网络平台