当前位置: 首页 > news >正文

Cross Encoder 架构类型

什么是 Cross Encoder 架构?

Cross Encoder 是一种用于处理文本对(text pair)任务的深度学习模型架构,常用于 Rerank(重排序)问答匹配自然语言推理(NLI) 等任务。

双塔模型(Dual Encoder) 不同,Cross Encoder 会将两个输入(通常是 Query 和 Document)拼接在一起输入模型,进行联合编码和语义交互


🧠 架构工作流程

假设我们有:

  • Query: "中国的首都是哪?"

  • 文档候选: "北京是中华人民共和国的首都。"

Cross Encoder 的处理方式是:

Input: [CLS] 中国的首都是哪? [SEP] 北京是中华人民共和国的首都。 [SEP]
→ 送入 Transformer(如 BERT、RoBERTa、DeBERTa)
→ 输出:[CLS] token 对应的向量
→ 用于分类 / 打分(如相关性得分)

🔍 Cross Encoder vs Dual Encoder

对比维度Cross EncoderDual Encoder (双塔)
输入方式拼接 query 和 document分别编码 query 和 document
交互能力强,能捕捉两者深层语义关系弱,编码过程中无交互
模型输出单个相关性得分或分类结果向量(embedding)
检索速度慢(每次需对 query-doc 对做前向传播)快(doc 向量可预计算)
应用场景精排(Rerank)、匹配任务粗排(向量检索)
精度表现✅ 高❌ 相对较低


🧪 技术细节

  1. 模型结构

    • 多采用 BERT/RoBERTa/DeBERTa 为 backbone;

    • 输出 [CLS] token 表示 query-document 整体表示,用于打分;

    • 一般加上一个简单的线性分类头(linear head)输出相关性分数或分类结果。

  2. 输入格式

    • [CLS] query tokens [SEP] document tokens [SEP]

  3. 损失函数

    • 通常使用 交叉熵(Cross Entropy)对比学习损失(Contrastive Loss) 训练;

    • 数据集示例:MS MARCO、NLI、DuReader、LCQMC 等。


✅ 优势与劣势

✅ 优势

  • 精度高,可捕捉复杂语义关系;

  • 更适合做 高质量排序任务(如知识库重排序)

  • 易于 fine-tune。

❌ 劣势

  • 推理慢,无法预计算 document 向量;

  • 不适合大规模文档检索场景(需先用 Dual Encoder 初筛);


📚 应用场景举例

  • RAG(知识增强生成)系统 中 rerank 文档片段;

  • 智能问答系统 匹配问题和答案候选;

  • 语义文本相似度计算

  • 判定文本蕴涵关系(NLI)

  • 信息抽取中的实体匹配与归一化

Cross Encoder 效率 

⏱️ Cross Encoder 的效率分析

1. 计算复杂度

Cross Encoder 的最大问题是:每个 query 和每个候选 document 都需要一次完整的前向传播计算

  • 如果你有:

    • 1 个查询 Q

    • 100 个候选文档 D1...D100

  • 则需要 100 次前向计算,每次输入是 [CLS] Q [SEP] Dn [SEP]

时间复杂度:O(N),其中 N 是候选文档数

相比之下,Dual Encoder 模型只需要一次 query 编码 + 多次 dot-product。


2. 推理耗时(延迟)

根据实际实验(例如使用 BERT-base):

模型Batch Size100 文档的平均 rerank 时间设备
BERT-base Cross Encoder8~300–500ms单张 GPU(如 3090)
RoBERTa-large4~800ms–1s单张 GPU
DeBERTa-v3-large21–2s单张 GPU

如果没有 GPU,只使用 CPU,延迟可能会更高(2–5s)。


3. 批量处理优化

你可以使用 批量推理(batch inference) 来部分缓解性能问题,例如一次输入多个 [query, doc] pair:

 

python

复制编辑

batch_inputs = [ tokenizer("Q [SEP] D1", return_tensors="pt"), tokenizer("Q [SEP] D2", return_tensors="pt"), ... ] # Stack 然后并行前向传播

这样可以通过 GPU 并行计算,减少平均单个 pair 的耗时。


4. 适合的使用策略

由于性能瓶颈,Cross Encoder 一般不用于大规模检索,而用于 精排(rerank)阶段

  • 推荐搭配:

    1. 先用 Dual Encoder 或向量搜索(如 FAISS)快速初筛 100 个候选;

    2. 然后使用 Cross Encoder 精排前 100 个结果;

    3. 最终选出 Top-N 精确文档。

这种方式叫做 “二阶段检索”(Two-Stage Retrieval)。


✅ 总结

项目表现
检索精度✅ 非常高(最强之一)
推理速度❌ 慢,随候选数量线性增长
并行化能力✅ 支持 GPU 批量处理优化
是否可离线预计算❌ 否,必须实时对每个 query-doc pair 推理
推荐使用场景小规模 rerank、高质量精排、QA系统
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np# 1. 向量检索模型(Dual Encoder)
embedding_model = SentenceTransformer("BAAI/bge-base-zh-v1.5")# 2. 精排模型(Cross Encoder)
rerank_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
rerank_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-base")
rerank_model.eval()# 示例文档
docs = ["北京是中国的首都。","苹果公司是一家科技公司。","中国的首都是北京。","巴黎是法国的首都。","中国位于东亚,是世界上人口最多的国家。"
]# 1️⃣ 粗排阶段:向量化 + FAISS 检索
# 向量化所有文档
doc_embeddings = embedding_model.encode(docs, normalize_embeddings=True)# 建立 FAISS 索引
index = faiss.IndexFlatIP(doc_embeddings.shape[1])  # 使用内积(余弦相似度)
index.add(doc_embeddings)# 查询语句
query = "中国的首都是哪?"
query_embedding = embedding_model.encode([query], normalize_embeddings=True)# 搜索 Top K 候选
top_k = 3
D, I = index.search(query_embedding, top_k)# 提取候选文档
candidates = [docs[i] for i in I[0]]# 2️⃣ 精排阶段:Cross Encoder 对候选文档打分
def rerank(query, docs):inputs = [f"{query} [SEP] {doc}" for doc in docs]inputs = rerank_tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')with torch.no_grad():scores = rerank_model(**inputs).logits.squeeze(-1)# 按得分排序ranked = sorted(zip(docs, scores.tolist()), key=lambda x: x[1], reverse=True)return ranked# 执行 rerank
ranked_results = rerank(query, candidates)# 输出结果
print("查询:", query)
print("\nTop 文档:")
for i, (doc, score) in enumerate(ranked_results):print(f"{i+1}. [{score:.4f}] {doc}")

 

相关文章:

  • UART16550 IP core笔记二
  • SpringDataRedis的入门案例,以及RedisTemplate序列化实现
  • 小皮面板从未授权到RCE
  • 【pypi镜像源】使用devpi实现python镜像源代理(缓存加速,私有仓库,版本控制)
  • 基于Python的高效批量处理Splunk Session ID并写入MySQL的解决方案
  • 【人工智能-agent】--Dify中自然语言生成SQL查询数据库
  • 如何快速入门大模型?
  • 精益数据分析(55/126):双边市场模式的挑战、策略与创业阶段关联
  • o.redisson.client.handler.CommandsQueue : Exception occured. Channel
  • 【深度学习】计算机视觉(18)——从应用到设计
  • 【大模型MCP协议】MCP官方文档(Model Context Protocol)一、开始——1. 介绍
  • Java—— 集合 Set
  • 【Spark】使用Spark集群搭建-Standalone
  • 在Web应用中集成Google AI NLP服务的完整指南:从Dialogflow配置到高并发优化
  • FFmpeg 项目中的三大核心工具详解
  • 企业管理软件:数字化转型的核心引擎
  • spdlog日志器(logger)的创建方法大全
  • 从0到1:Python机器学习实战全攻略(8/10)
  • 03.Golang 切片(slice)源码分析(二、append实现)
  • 循环语句:for、range -《Go语言实战指南》
  • 商务部就开展打击战略矿产走私出口专项行动应询答记者问
  • 第一集丨《亲爱的仇敌》和《姜颂》,都有耐人寻味的“她”
  • 印巴战火LIVE丨“快速接近战争状态”?印度袭击巴军事基地,巴启动反制军事行动
  • 李在明正式登记参选下届韩国总统
  • 北外滩集团21.6亿元摘上海虹口地块,为《酱园弄》取景地
  • 王日春已任教育部社会科学司司长,此前系人教社总编辑