当前位置: 首页 > news >正文

Meta新注意力机制给 Transformer 升了级!底层架构的革命!

1.导读

随着大模型训练日益受到“高质量 token 不足”的限制,如何在有限 token 预算下实现更高效的推理与学习,成为大语言模型架构演进的关键问题。传统 Transformer 注意力机制虽然强大,但其计算本质仍局限于二元关系(dot product)的建模。本文提出了一种结构性创新:将注意力从二阶拓展为三阶(三线性)关系——即所谓的 2-simplicial attention。通过 Triton 高效实现,该机制在相同模型参数与 token 预算下,在数学、逻辑推理与代码生成任务中,优于标准 Transformer,并首次实证表明它能改变 Scaling Law 中的指数项(exponent),为构建高效 Transformer 架构提供了新的理论与实践路径。


论文基本信息

  • 论文标题Fast and Simplex: 2-Simplicial Attention in Triton
  • 作者:Aurko Roy, Timothy Chou, Sai Surya Duvvuri, Sijia Chen, Jiecao Yu, Xiaodong Wang, Manzil Zaheer, Rohan Anil
  • 单位:Meta AI, University of Texas at Austin
  • 年份:2025
  • 链接:https://arxiv.org/abs/2507.02754v1

点击阅读原文,获取更多咨询

2. 研究背景与结构性突破:从点积到三线性注意力

2.1 从 Token 扩张到 Token 高效

自 Kaplan 等人提出神经网络的 Scaling Law 以来,大规模语言模型(LLMs)的发展遵循一个黄金法则:模型参数量和训练数据规模需同步增长,才能达到计算最优(compute-optimal)。例如,Chinchilla 模型通过以更高 token 数训练,成功在较小参数规模下击败了体量更大的 Gopher。

然而,随着高质量语料资源逐步耗尽,“无限 token 假设”日益失效。模型开发正从“计算受限”(compute-bound)转向“数据受限”(data-bound)阶段。这种变化对架构设计提出了新挑战:

如何在有限 token 预算下,提升模型的表达与学习能力?

2.2 方法简介:从点积注意力到三线性交互

为回应这一挑战,本文提出了 2-simplicial Transformer,其核心创新在于:

用三线性(trilinear)函数替代传统的点积注意力(dot-product attention)机制

Alt

标准 Transformer 中,每个 query 通过点积获取 key 的相关性。而 2-simplicial 注意力引入第三个向量(第二组 key),构成更高阶的三元交互。这一机制本质上从“一维边”跃迁为“二维面”的信息建模方式,显著增强了注意力机制对复杂关系的捕捉能力

结合高效的 Triton 核函数实现,该方法不仅在理论上具有更优的 scaling 指数,还在推理、数学、代码生成等任务中展现出实际性能优势。

3. 方法原理与模型结构设计:2-Simplicial Attention

3.1 从标准注意力到三线性交互

Transformer 模型的核心机制是自注意力(Self-Attention),其基本形式建立在输入序列 X∈Rn×dX \in \mathbb{R}^{n \times d}XRn×d 的三组线性投影之上:
请添加图片描述

其中 WQ,WK,WV∈Rd×dW_Q, W_K, W_V \in \mathbb{R}^{d \times d}WQ,WK,WVRd×d 分别表示 query、key、value 的投影矩阵。标准注意力机制通过计算 query 与 key 的点积作为注意力分数:

Alt

最终每个 token 的输出是对所有 value 的加权平均:
Alt

该机制本质上是建模序列中两两 token 的相似性关系,可视为1-simplex(边)结构,其表达能力受限于二元交互。


3.2 三线性注意力:2-Simplicial 扩展

为提升注意力机制在建模复杂逻辑与结构性任务中的表达能力,本文引入 2-simplicial attention,即从二元点积拓展为三元交互的三线性张量形式

Alt

其中,K′K'K 是对输入序列的另一组投影,表示第二组 key。三线性注意力以 (qi,kj,kk′)(q_i, k_j, k'_k)(qi,kj,kk) 三元组构造注意力得分张量,并通过 softmax 转换为概率分布:

Alt

输出则为两个 value 向量之间的 Hadamard 积的加权和:

Alt

  • 与传统 attention 相比,
    2-simplicial attention 本质上是建模2-simplex(三点面)之间的交互关系,理论上拥有更强的组合能力和推理表达力。

3.3 计算优化:滑动窗口与复杂度控制

由于 2-simplicial attention 的原始复杂度为 O(n3)O(n^3)O(n3),直接在全序列上应用不具现实性。为此,作者引入滑动窗口机制(windowed 2-simplicial attention),将三元注意力限制在局部范围:

  • 每个 query qiq_iqi 只关注 w1w_1w1 个 key 和 w2w_2w2k′k'k 向量;
  • 计算复杂度降为 O(nw1w2)O(n w_1 w_2)O(nw1w2),实现 token 局部交互的高效建模;
  • 在实验中使用 (w1,w2)=(512,32)(w_1, w_2) = (512, 32)(w1,w2)=(512,32) 获得较优的延迟与性能平衡。

这种结构类似于 FlashAttention 的 Tile 机制,但在三线性张量的构造与并行调度中加入了 Triton 实现的细粒度优化。


3.4 位置编码的拓展与旋转不变性

标准 RoPE(Rotary Position Embedding)利用旋转矩阵保证点积在加上位置编码后仍保持相对位置信息。然而,对于三线性内积:
Alt

其对同构旋转不具不变性,因此需要重新设计可旋转等价的三线性形式。作者提出一种基于行列式的替代方案:
Alt

这种结构等价于三个向量构成体积张量的有符号体积,满足旋转不变性(Rotation Invariance),理论上可用于构建位置敏感且保持结构一致性的注意力机制。


3.5 模型结构集成与实现概览

最终的 2-simplicial Transformer 架构采用模块化设计:

  • 每 4 层 Transformer 中插入一层 2-simplicial 注意力层;
  • 搭配 Grouped Query Attention(GQA)以降低并行通信开销;
  • 使用 Triton 编写高性能核函数,并在 CUDA Tensor Core 上实现融合计算;
  • 支持前向与反向的高效 Tile 并行,避免原子操作瓶颈。

整体设计兼顾结构表达力与工程可扩展性,是目前实现可落地的高阶注意力结构之一。

4. 实验设计与结果分析

4.1 实验设置

为了评估 2-simplicial Transformer 的实际效果,作者设计了多个规模的 MoE(Mixture-of-Experts)模型,活跃参数规模从 1B 到 3.5B,插入方式为:每四层 Transformer 中插入一层 2-simplicial 注意力层

训练设置如下:

  • Optimizer: AdamW(学习率 4e-3,cosine decay,warmup 4000 steps)
  • Token 数量:固定
  • Triton 编写 attention 核函数,支持 forward 与 backward 高效 tile 化计算

4.2 模型性能对比

下图展示了不同参数规模下,2-simplicial Transformer 与标准 Transformer 在四个基准任务上的 NLL(负对数似然)对比:

Alt

观察:

  • 当模型参数增大时,2-simplicial Transformer 的优势越发显著;
  • 在 MMLU-pro 和 GSM8k 等复杂任务上,NLL 降低明显,说明其更擅长处理推理与逻辑任务。

4.3 Scaling Law 分析

作者进一步拟合了 Scaling Law 曲线,形式为:

请添加图片描述

Alt

α 表示参数效率指数,越大表示每单位参数带来的性能提升越强。


4.4 拟合效果验证

为了验证上述 Scaling Law 拟合的稳定性,作者报告了每个模型的 R2R^2R2 和残差(residual):

Alt

在 MMLU-pro 和 MBPP 上,2-simplicial 的拟合效果更优,表明其扩展能力与泛化表现更加平稳。


4.5 计算效率分析

此外,作者评估了 2-simplicial 注意力在不同窗口大小配置下的延迟与 FLOPs:

  • 使用 Triton 实现的核函数,在窗口设置 (512, 32) 时达到最佳计算效率;
  • 达到接近 FlashAttention v3 的算力表现(520 TFLOPS);
  • 对大序列长度(例如 48k tokens)仍保持高吞吐。

小结

  • 2-simplicial Transformer 在 token 固定条件下实现更优 scaling;
  • 在高难度任务上呈现出更强的逻辑推理与结构建模能力;
  • 并通过高效的 Triton 实现保证工程落地的可行性。

5. 总结与展望

本文围绕 2-simplicial Transformer 架构展开,系统性地探讨了其在有限 token 条件下的表达能力、结构优势与实际性能表现。与标准基于点积的注意力机制不同,2-simplicial 注意力通过引入三线性交互结构,实现了从“边”到“面”的关系建模,具备更强的组合表达力,尤其适用于逻辑推理、数学推导与代码生成等结构化任务。

在理论层面,作者通过 Scaling Law 分析发现:2-simplicial Transformer 在固定训练 token 数下,呈现出更高的 scaling 指数 α,即更强的参数效率。这一突破打破了传统观点——即模型结构改进只能带来损失偏移,而无法提升 scaling 斜率。

在工程实现上,论文提供了基于 Triton 的高效内核优化策略,使得 2-simplicial attention 在大序列长度下仍具备接近 FlashAttention v3 的吞吐性能,证明其在硬件亲和性与可扩展性上的现实潜力。

展望未来,作者指出两条值得继续探索的方向:

  • 一是针对不同硬件(如 TPU、ASIC)设计更深度优化的 2-simplicial 实现;
  • 二是将该结构迁移至下游任务,如长文本理解、多跳推理、多模态融合等复杂场景。

总而言之,2-simplicial Transformer 提供了一个兼具结构创新与工程落地的新范式,或将在后 token 时代引领下一轮高效模型设计潮流。

6. 快速实现

为了支持 2-simplicial attention 的高效训练与推理,本文在 Triton 中实现了完整的前向与反向核函数。相比标准点积注意力,该实现显著提升了三线性交互的吞吐效率,在长序列场景下可达到 FlashAttention v3 级别的计算性能。


6.1 前向传播:三线性注意力计算

以下为 2-simplicial attention 的核心前向逻辑,已在 Triton 上实现 tile 并行、数值稳定 softmax、并行积累:

@triton.jit
def two_simplicial_attn_fwd_kernel(Q, K1, K2, V1, V2, O, M, ..., BLOCK_SIZE_Q, BLOCK_SIZE_KV):# Tile Q, K1, K2, V1, V2q_tile = load_tile(Q)for kv1 in kv1_range:k1_tile = load_tile(K1)v1_tile = load_tile(V1)for kv2 in kv2_range:k2_tile = load_tile(K2)v2_tile = load_tile(V2)# Compute trilinear logits: A[i,j,k] = q * k1 * k2scores = dot(q_tile * k1_tile, k2_tile)# Apply local mask and numerically stable softmaxprobs = softmax_stable(scores)# Compute output using Hadamard(V1, V2)v12 = v1_tile * v2_tileO += probs @ v12

6.2 反向传播:两阶段梯度计算

由于三线性结构涉及三组参数,本文将反向传播拆分为两个 Triton kernel,实现无冲突的高效并行。

6.2.1 第一阶段:更新 dQ、dK1、dV1
 @triton.jit
def bwd_kv1_kernel(Q, K1, K2, V1, V2, dO, ...):for kv2 in kv2_range:k2_tile, v2_tile = load(K2), load(V2)for kv1 in kv1_range:k1_tile, v1_tile = load(K1), load(V1)probs = compute_softmax(...)dV1 += probs @ (dO * v2_tile)dK1 += (probs_grad @ Q) * k2_tiledQ  += probs_grad.T @ (k1_tile * k2_tile)
6.2.2第二阶段:更新 dK2、dV2 与补充 dQ
@triton.jit
def bwd_kv2q_kernel(Q, K1, K2, V1, V2, dO, ...):for kv1 in kv1_range:k1_tile, v1_tile = load(K1), load(V1)for kv2 in kv2_range:k2_tile, v2_tile = load(K2), load(V2)probs = compute_softmax(...)dV2 += probs @ (dO * v1_tile)dK2 += probs_grad @ (Q * k1_tile)dQ  += probs_grad.T @ (k1_tile * k2_tile)

关注下方《AI前沿速递》🚀🚀🚀
各种重磅干货,第一时间送达
码字不易,欢迎大家点赞评论收藏

http://www.dtcms.com/a/272696.html

相关文章:

  • JAVA JVM对象的创建
  • 水陆联防智能升级:AI入侵检测系统守护零死角安全
  • 介绍 cnpm exec electron-packager
  • x86汇编语言入门基础(三)汇编指令篇3 位移运算
  • 【threejs】第一人称视角之八叉树碰撞检测
  • 蜻蜓I即时通讯系统重构宣言:破茧重生的技术革命-长痛不如短痛卓伊凡|麻子|果果
  • 大健康IP如何借“合规创新”抢占行业新风口|创客匠人
  • 解读 Go 中的 constraints包
  • 【TCP/IP】7. IP 路由
  • xml 知识总结: xsd,xsi:schemaLocation,xmlns,xmlns:xsi
  • SpringBoot系列—MyBatis(xml使用)
  • codeforeces Round1032 - Round 1036
  • 【node后端】搭建项目(Express+Ts+Typeorm+Mysql一步到位)
  • 深入浅出 Python Asynchronous I/O:从 asyncio 入门到实战
  • Arc Institute提出首个AIVC虚拟细胞模型STATE
  • 上海交大医学院张维拓老师赴同济医院做R语言训练营培训
  • 从Debug中学习MiniGPT4
  • 在Vue中如何对组件进行销毁在进行挂载
  • 模型训练之数据标注-Labelme的使用教程
  • 5款工具高效制作插图,PPT设计新选择!
  • 货车车架和悬架设计cad【7张】+设计说明书
  • leetcode 3440. 重新安排会议得到最多空余时间 II 中等
  • 《PyQt6-3D:开启Python 3D编程新世界 2》
  • 【TCP/IP】8. 传输层协议
  • hive小文件问题
  • 二层环路避免-STP技术
  • Linux【大数据运维】下制作Redis绿色免安装包(一)
  • 企业网络安全的“金字塔”策略:构建全方位防护体系的核心思路
  • upload-labs靶场通关详解:第20关 /.绕过
  • 以下哪种类型在Golang中不是内置类型?