GQA:从多头检查点训练广义多查询Transformer模型
摘要
多查询注意力(MQA)仅使用单个键-值头,能大幅加速解码器推理。然而,MQA可能导致质量下降,而且专门为更快的推理训练单独的模型可能并不可取。我们:(1) 提出了一种从现有多头语言模型检查点Uptraining具有MQA的模型的方案,仅需原始预训练计算量的5%;(2) 引入了分组查询注意力(GQA),这是多查询注意力的一种泛化,它使用中间数量(多于一个但少于查询头数量)的键-值头。我们表明,Uptraining的GQA在接近多头注意力质量的同时,速度与MQA相当。
1 引言
自回归解码器推理是Transformer模型的一个严重瓶颈,因为在每个解码步骤中都需要加载解码器权重以及所有注意力键和值,这带来了内存带宽开销(Shazeer, 2019; Pope et al., 2022; de Jong et al., 2022)。通过多查询注意力(Shazeer, 2019)可以显著减少加载键和值的内存带宽,该方法使用多个查询头但仅使用单个键和值头。
然而,多查询注意力(MQA)可能导致质量下降和训练不稳定,并且训练专门针对质量和推理优化的单独模型可能不可行。此外,虽然一些语言模型已经使用多查询注意力,如PaLM(Chowdhery et al., 2022),但许多模型并未使用,包括公开可用的语言模型如T5(Raffel et al., 2020)和LLaMA(Touvron et al., 2023)。
本工作包含两个贡献,用于加速大型语言模型的推理。首先,我们证明具有多头注意力(MHA)的语言模型检查点可以通过原始训练计算量的一小部分Uptraining为使用MQA的模型。这为获得快速多查询以及高质量MHA检查点提供了一种成本效益高的方法。
其次,我们提出了分组查询注意力(GQA),这是多头和多查询注意力之间的插值,每个查询头子组共享单个键和值头。我们表明,Uptraining的GQA在接近多头注意力质量的同时,速度几乎与多查询注意力一样快。
2 方法
2.1 Uptraining
从多头模型生成多查询模型分为两个步骤:首先,转换检查点;其次,额外的预训练以允许模型适应其新结构。图1展示了将多头检查点转换为多查询检查点的过程。键和值头的投影矩阵被平均池化为单个投影矩阵,我们发现这比选择单个键和值头或从头随机初始化新的键和值头效果更好。
转换后的检查点然后在原始训练步骤的一小部分比例α\alphaα上,使用相同的预训练方案进行预训练。
2.2 分组查询注意力
分组查询注意力将查询头分为GGG组,每组共享单个键头和值头。GQA-GGG表示具有GGG组的分组查询。GQA-1(具有单个组,因此单个键和值头)等同于MQA,而GQA-HHH(组数等于头数)等同于MHA。图2展示了分组查询注意力与多头/多查询注意力的比较。当将多头检查点转换为GQA检查点时,我们通过平均池化该组内所有原始头来构造每个组的键和值头。
中间数量的组导致一个插值模型,其质量高于MQA但速度快于MHA,并且如我们将展示的,代表了一个有利的权衡。从MHA到MQA将HHH个键和值头减少为单个键和值头,将键-值缓存的大小减少了一个因子HHH,因此需要加载的数据量也相应减少。然而,较大的模型通常会扩展头的数量,使得多查询注意力在内存带宽和容量方面都进行了更激进的削减。GQA使我们能够在模型尺寸增加时保持相同的比例带宽和容量减少。
此外,较大的模型受注意力带来的内存带宽开销影响相对较小,因为KV缓存随模型维度缩放,而模型FLOPs和参数随模型维度的平方缩放。最后,大型模型的标准分片会将单个键和值头复制到模型分区数量(Pope et al., 2022);GQA消除了这种分片带来的浪费。因此,我们预计GQA对于大型模型尤其能提供良好的权衡。
我们注意到GQA不应用于编码器自注意力层;编码器表示是并行计算的,因此内存带宽通常不是主要瓶颈。
3 实验
3.1 实验设置
配置 所有模型都基于T5.1.1架构(Raffel et al., 2020),使用JAX(Bradbury et al., 2018)、Flax(Heek et al., 2020)和Flaxformer实现。对于我们的主要实验,我们考虑具有多头注意力的T5 Large和XXL,以及T5 XXL的Uptraining版本,具有多查询和分组查询注意力。我们使用Adafactor优化器,超参数和学习率计划与T5(Raffel et al., 2020)相同。我们将MQA和GQA应用于解码器自注意力和交叉注意力,但不应用于编码器自注意力。
Uptraining Uptraining模型从公开的T5.1.1检查点初始化。键和值头被平均池化为适当的MQA或GQA结构,然后使用原始预训练设置和数据集(Raffel et al., 2020)进一步预训练原始预训练步骤的α\alphaα比例。对于α=0.05\alpha=0.05α=0.05,训练大约需要600个TPUv3芯片日。
数据 我们在摘要数据集CNN/Daily Mail(Nallapati et al., 2016)、arXiv和PubMed(Cohan et al., 2018)、MediaSum(Zhu et al., 2021)和Multi-News(Fabbri et al. 2019)上进行评估;翻译数据集WMT 2014英语到德语;以及问答数据集TriviaQA(Joshi et al., 2017)。我们不在GLUE(Wang et al., 2019)等流行的分类基准上进行评估,因为自回归推理对于这些任务不太适用。
微调 对于微调,我们对所有任务使用0.001的恒定学习率、128的批大小和0.1的dropout率。CNN/Daily Mail和WMT使用512的输入长度和256的输出长度。其他摘要数据集使用2048的输入长度和512的输出长度。最后,TriviaQA使用2048的输入长度和32的输出长度。我们训练直至收敛,并选择具有最高开发集性能的检查点。我们使用贪婪解码进行推理。
计时 我们报告每个样本每TPUv4芯片的时间,由xprof(Google, 2020)测量。对于计时实验,我们使用8个TPU,最大批大小适合每个TPU最多32个,并为每个模型单独优化并行化。
3.2 主要结果
图3显示了MHA T5-Large和T5-XXL以及Uptraining的MQA和GQA-8 XXL模型(Uptraining比例α=0.05\alpha=0.05α=0.05)在所有数据集上的平均性能与平均推理时间的函数关系。我们看到,更大的UptrainingMQA模型相对于MHA模型提供了有利的权衡,在质量和推理速度上都优于MHA-Large。此外,GQA实现了显著的额外质量提升,在速度接近MQA的同时达到了接近MHA-XXL的性能。表1包含所有数据集的完整结果。
3.3 消融实验
本节介绍调查不同建模选择影响的实验。我们在代表性任务子集上评估性能:CNN/Daily Mail(短形式摘要)、MultiNews(长形式摘要)和TriviaQA(问答)。
检查点转换 图4比较了不同检查点转换方法的性能。平均池化似乎效果最好,其次是选择单个头,然后是随机初始化。直观上,结果按从预训练模型中保留信息的程度排序。
Uptraining步骤 图5显示了T5-XXL与MQA和GQA的Uptraining比例如何影响性能。首先,我们注意到GQA在转换后已经达到了合理的性能,而MQA需要Uptraining才能有用。MQA和GQA在5%的Uptraining中都获得了收益,10%后收益递减。
组数 图6展示了GQA组数对推理速度的影响。对于较大的模型,KV缓存带来的内存带宽开销不那么具有约束性(Shazeer, 2019),而键-值大小的减少由于头数增加而更为显著。因此,从MQA增加组数最初只会导致适度的减速,随着我们接近MHA,成本逐渐增加。我们选择了8组作为有利的中间点。
4 相关工作
本工作专注于通过减少加载键和值带来的内存带宽开销(Williams et al., 2009)来实现解码器质量和推理时间之间的更好权衡。Shazeer(2019)首次提出通过多查询注意力来减少这种开销。后续工作表明,多查询注意力对于长输入特别有帮助(Pope et al., 2022; de Jong et al., 2022)。Rabe(2023)独立开发了GQA并公开实现。其他工作探索了为计算效率而分组注意力头(Park et al., 2020; Luo et al., 2022; Ni et al., 2023),但没有特别关注决定内存带宽开销的键-值头。
已经提出了许多其他方法来减少来自键和值的内存带宽开销以及参数。Flash attention(Dao et al., 2022)通过结构化注意力计算来避免实例化二次方注意力分数,减少内存并加速训练。量化(Dettmers et al., 2022; Frantar et al., 2022)通过降低精度来减小权重和激活(包括键和值)的大小。模型蒸馏(Hinton et al., 2015; Gou et al., 2021)则在给定精度下减小模型大小,使用从较大模型生成的数据来微调较小模型。层稀疏交叉注意力(de Jong et al., 2022)消除了大多数交叉注意力层,这些层是长输入的主要开销。推测性采样(Chen et al., 2023; Leviathan et al., 2022)通过用较小模型提出多个标记然后由较大模型并行评分来缓解内存带宽瓶颈。
最后,我们提出的Uptraining过程受到Komatsuzaki et al.(2022)的启发,该研究将标准T5检查点Uptraining为稀疏激活的专家混合模型。
5 结论
语言模型推理昂贵主要是由于加载键和值带来的内存带宽开销。多查询注意力以降低模型容量和质量为代价减少了这种开销。我们提出将多头注意力模型转换为多查询模型,仅需原始预训练计算的一小部分。此外,我们引入了分组查询注意力,这是多查询和多头注意力的插值,它在接近多头的速度下实现了接近多头的性能。
局限性
本文专注于缓解加载键和值带来的内存带宽开销。当生成较长序列时,这种开销最为重要,而较长序列的质量本质上难以评估。对于摘要,我们使用Rouge分数,我们知道这是一种有缺陷的评估,不能讲述整个故事;因此,很难确定我们的权衡是否正确。由于计算资源有限,我们也没有将我们的XXL GQA模型与从头训练的比较模型进行比较,所以我们不知道Uptraining与从头训练的相对性能。最后,我们仅在编码器-解码器模型上评估了Uptraining和GQA的影响。最近,仅解码器模型非常流行,由于这些模型没有单独的自注意力和交叉注意力,我们预计GQA相对于MQA将具有更强的优势。
A 训练稳定性
我们发现多查询注意力在微调期间可能导致训练不稳定,特别是与长输入任务结合时。我们从头训练了多个具有多查询注意力的T5-Large模型。在每种情况下,预训练都遭受频繁的损失峰值,最终模型在长输入任务上微调时立即发散。Uptraining的多查询注意力模型更稳定,但在不稳定任务上仍显示高方差,因此对于不稳定任务上的多查询模型,我们报告三次微调运行的平均性能。然而,Uptraining的分组查询注意力模型似乎很稳定,因此我们没有进一步调查多查询不稳定的根本原因。