准确识别检索头,提高大模型长上下文能力
论文标题
Query-Focused Retrieval Heads Improve Long-Context Reasoning and Re-ranking
论文地址
https://arxiv.org/pdf/2506.09944
代码地址
https://github.com/princeton-pli/QRHead
作者背景
普林斯顿大学,得克萨斯大学奥斯汀分校
动机
近来大模型的长上下文能力有了明显进步,但其内部机理尚未完全明确。研究者曾通过观察模型做大海捞针任务(Needle In A Haystack)时的表现发现了专门负责上下文复制粘贴的注意力头,称为“检索头”,进而发展出一些减少幻觉、提高推理能力、优化 kv-cache 的方法
然而相比于实际应用中的长上下文场景,大海捞针任务过于简单(复述关键信息),由此识别“检索头”可能并未充分理解各注意力头的协作关系,导致识别结果不准确
于是作者希望基于更贴近于真实长上下文检索任务,设计一种更准确有效的“检索头”识别方法,进而提高下游应用的效果
检索头介绍
“大海捞针”任务是当前衡量大模型长上下文的常用方法,它要求模型在不同长度的上下文、不同的上下文位置找出指定文本,统计准确率
在这一过程中,我们可以分别考察模型每一层的每一个注意力头,观察其对上下文中每个 token 产生的注意力分数,然后记录推理过程中注意力分数最高的 token 序列,最后统计此序列与“针”文本的重叠比例作为此注意力头的检索得分,高于0.1的则被识别为“检索头”,它在长上下文推理过程中主要负责复制粘贴已有文本
作者分析发现,“检索头”具备以下性质:
- 通用性: 任何具有长上下文能力的模型都有少量的检索头,无论其架构、训练方法如何
- 稀疏性: 只有少量注意力头负责检索,其他大部分则负责理解与生成
屏蔽检索头会严重影响上下文推理,模型产生大量幻觉(但语言是通顺的),显著降低推理效果(CoT推理时,模型需要不断回顾自己生成的内容)
- 一致性: 检索头是基础模型的固有能力,起源于大规模预训练,后续的衍生模型(继续预训练、微调)都与基础模型使用同一组检索头
Retrieval Head Mechanistically Explains Long-Context Factuality
https://arxiv.org/pdf/2404.15574
本文方法
一、QRHead 识别
本文提出QRHead(Query-Focused Retrieval Head,聚焦于查询的检索头),核心思想是从具体的长上下文推理任务出发,重新设计上述“检索分数”
如上图所示,相比于在合成数据上识别注意力头的复制粘贴行为,作者直接在真实上下文检索数据上(包含候选文档、目标文档、查询问题),统计所有目标文本的注意力分数之和,作为检索得分。在后续实验中,作者在参数量少于10B的模型上选择了分数最高的16个头,尺寸更大的模型(如Llama-3.1-70B)上选择了32个头,鉴于检索头的稀疏性,这大约占总注意力头的 1-2%
值得注意的是,这些现实数据只需要标注目标文档即可(哪些长上下文片段与查询问题相关),而且仅需要少量(少于100条)数据便可有效地找出 QRHead
二、构建通用检索器
“检索头”最直接的用法是在文档召回过程中用作打分器,于是作者提出 QRRetriever:聚合所有 QRHead 对某一候选文档的检索分数,作为此文档的相关性指标
直接利用大模型中的注意力头来做检索器,省去了繁琐的训练,具有较强的通用性,并且具有较好的并行计算能力(传统的单塔、双塔架构需要在特定数据上训练,跨领域能力较差)
为了减轻语言模型中注意力权重的内在偏差,作者还使用了一种得分校准方法:使用空字符串代替query中的问题,然后计算 query 对候选文档产生的注意力分数作为基准,R(q, d) 减去此基准才是最终的相关性得分
实验结果
一、长上下文推理测试
将长上下文输入进行切分、打分、召回、拼接,再送入大模型做推理
- Full context: 将完整的上下文送入大模型做推理
- BM25: 传统基于统计的检索器
- Contriever: 双塔架构,在大规模无监督数据上做预训练
- Stella: 双塔架构,在多个检索数据集上做预训练
- QRRetriever: 基于 LongMemEval 数据集中70个单跳问题,识别大模型种的 QRHead,再基于它们的注意力分数对上下文片段做排序
测试结果如下所示,考察指标为召回准确率以及端到端问答的准确率,测试数据包括 LongMemEval(chat场景,筛选多跳问题) 和 CLIPPER(书籍问答场景)
可见 QRRetriever 明显优于其他方法,具有良好的跨场景泛化能力,并且具有明显的规模效应:基础模型尺寸越大端到端效果越好。但仔细观察上表也能发现,小尺寸模型的召回准确性可与大尺寸模型媲美,端到端效果主要受限于生成能力的不足。所以实践中我们可以先基于小模型做检索,再使用大模型做推理回答
二、段落重排序测试
除了直接筛选相关文档进行推理,许多业务场景下还需要对文档进行排序。下表展示了在 BEIR(多领域排序任务)基准上各实验组的测试结果
对于参数少于 10B 的模 型,QRRetriever 始终优于其他基准;对于更大的 Llama-3.1-70B 模型,QRRetriever 显著超越了 ICR,落后于 RankGPT_Bubble,但成本明显更低(后者需要超过 200 次生成调用)
三、屏蔽不同注意力头的影响
下图展示了屏蔽不同注意力头时,大海捞针任务的影响。可见本文提出的QRHead对于检索任务起到更关键的作用,因为当屏蔽16个头时,QRHead组便展现出了明显的性能下降,而原始检索头组的变化不大
此外,作者还对比了 QRHead 与原始检索头之间的差异,在 top-32 和 top-64 的注意力头中,两种方法分别只有 8 个和 32 个重叠,这凸显了二者的独特性,表明 QRHead 并非只是检索头的简单改进
四、长度泛化能力测试
为了进一步验证本文方法的通用性,作者还测试了在相对更短任务上识别 QRHead,在更长的上下文上测试,结果表明 QRRetriever 对上下文长度变化具有较好的鲁棒性
五、检索头对数据变化的敏感性
为了观察 QRHead 是否与前序工作中的“检索头”一样,具备良好的一致性,作者从 NQ 数据集中抽取了三组互不重叠的样本,用其识别 QRHead 并观察每组识别结果的重叠情况,结果如下图所示
可见模型的 QRHead 也具有高度的一致性,每组识别结果中有超过 50/64 = 78% 的 QRHead 是相同的