Cross Encoder 架构类型
什么是 Cross Encoder 架构?
Cross Encoder 是一种用于处理文本对(text pair)任务的深度学习模型架构,常用于 Rerank(重排序)、问答匹配、自然语言推理(NLI) 等任务。
与 双塔模型(Dual Encoder) 不同,Cross Encoder 会将两个输入(通常是 Query 和 Document)拼接在一起输入模型,进行联合编码和语义交互。
🧠 架构工作流程
假设我们有:
-
Query:
"中国的首都是哪?"
-
文档候选:
"北京是中华人民共和国的首都。"
Cross Encoder 的处理方式是:
Input: [CLS] 中国的首都是哪? [SEP] 北京是中华人民共和国的首都。 [SEP]
→ 送入 Transformer(如 BERT、RoBERTa、DeBERTa)
→ 输出:[CLS] token 对应的向量
→ 用于分类 / 打分(如相关性得分)
🔍 Cross Encoder vs Dual Encoder
对比维度 | Cross Encoder | Dual Encoder (双塔) |
---|---|---|
输入方式 | 拼接 query 和 document | 分别编码 query 和 document |
交互能力 | 强,能捕捉两者深层语义关系 | 弱,编码过程中无交互 |
模型输出 | 单个相关性得分或分类结果 | 向量(embedding) |
检索速度 | 慢(每次需对 query-doc 对做前向传播) | 快(doc 向量可预计算) |
应用场景 | 精排(Rerank)、匹配任务 | 粗排(向量检索) |
精度表现 | ✅ 高 | ❌ 相对较低 |
🧪 技术细节
-
模型结构
-
多采用 BERT/RoBERTa/DeBERTa 为 backbone;
-
输出
[CLS]
token 表示 query-document 整体表示,用于打分; -
一般加上一个简单的线性分类头(linear head)输出相关性分数或分类结果。
-
-
输入格式
-
[CLS] query tokens [SEP] document tokens [SEP]
-
-
损失函数
-
通常使用 交叉熵(Cross Entropy) 或 对比学习损失(Contrastive Loss) 训练;
-
数据集示例:MS MARCO、NLI、DuReader、LCQMC 等。
-
✅ 优势与劣势
✅ 优势
-
精度高,可捕捉复杂语义关系;
-
更适合做 高质量排序任务(如知识库重排序);
-
易于 fine-tune。
❌ 劣势
-
推理慢,无法预计算 document 向量;
-
不适合大规模文档检索场景(需先用 Dual Encoder 初筛);
📚 应用场景举例
-
RAG(知识增强生成)系统 中 rerank 文档片段;
-
智能问答系统 匹配问题和答案候选;
-
语义文本相似度计算;
-
判定文本蕴涵关系(NLI);
-
信息抽取中的实体匹配与归一化
Cross Encoder 效率
⏱️ Cross Encoder 的效率分析
1. 计算复杂度
Cross Encoder 的最大问题是:每个 query 和每个候选 document 都需要一次完整的前向传播计算。
-
如果你有:
-
1 个查询
Q
-
100 个候选文档
D1...D100
-
-
则需要 100 次前向计算,每次输入是
[CLS] Q [SEP] Dn [SEP]
。
时间复杂度:O(N)
,其中 N 是候选文档数
相比之下,Dual Encoder 模型只需要一次 query 编码 + 多次 dot-product。
2. 推理耗时(延迟)
根据实际实验(例如使用 BERT-base):
模型 | Batch Size | 100 文档的平均 rerank 时间 | 设备 |
---|---|---|---|
BERT-base Cross Encoder | 8 | ~300–500ms | 单张 GPU(如 3090) |
RoBERTa-large | 4 | ~800ms–1s | 单张 GPU |
DeBERTa-v3-large | 2 | 1–2s | 单张 GPU |
如果没有 GPU,只使用 CPU,延迟可能会更高(2–5s)。
3. 批量处理优化
你可以使用 批量推理(batch inference) 来部分缓解性能问题,例如一次输入多个 [query, doc]
pair:
python
复制编辑
batch_inputs = [ tokenizer("Q [SEP] D1", return_tensors="pt"), tokenizer("Q [SEP] D2", return_tensors="pt"), ... ] # Stack 然后并行前向传播
这样可以通过 GPU 并行计算,减少平均单个 pair 的耗时。
4. 适合的使用策略
由于性能瓶颈,Cross Encoder 一般不用于大规模检索,而用于 精排(rerank)阶段:
-
推荐搭配:
-
先用 Dual Encoder 或向量搜索(如 FAISS)快速初筛 100 个候选;
-
然后使用 Cross Encoder 精排前 100 个结果;
-
最终选出 Top-N 精确文档。
-
这种方式叫做 “二阶段检索”(Two-Stage Retrieval)。
✅ 总结
项目 | 表现 |
---|---|
检索精度 | ✅ 非常高(最强之一) |
推理速度 | ❌ 慢,随候选数量线性增长 |
并行化能力 | ✅ 支持 GPU 批量处理优化 |
是否可离线预计算 | ❌ 否,必须实时对每个 query-doc pair 推理 |
推荐使用场景 | 小规模 rerank、高质量精排、QA系统 |
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np# 1. 向量检索模型(Dual Encoder)
embedding_model = SentenceTransformer("BAAI/bge-base-zh-v1.5")# 2. 精排模型(Cross Encoder)
rerank_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
rerank_model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-base")
rerank_model.eval()# 示例文档
docs = ["北京是中国的首都。","苹果公司是一家科技公司。","中国的首都是北京。","巴黎是法国的首都。","中国位于东亚,是世界上人口最多的国家。"
]# 1️⃣ 粗排阶段:向量化 + FAISS 检索
# 向量化所有文档
doc_embeddings = embedding_model.encode(docs, normalize_embeddings=True)# 建立 FAISS 索引
index = faiss.IndexFlatIP(doc_embeddings.shape[1]) # 使用内积(余弦相似度)
index.add(doc_embeddings)# 查询语句
query = "中国的首都是哪?"
query_embedding = embedding_model.encode([query], normalize_embeddings=True)# 搜索 Top K 候选
top_k = 3
D, I = index.search(query_embedding, top_k)# 提取候选文档
candidates = [docs[i] for i in I[0]]# 2️⃣ 精排阶段:Cross Encoder 对候选文档打分
def rerank(query, docs):inputs = [f"{query} [SEP] {doc}" for doc in docs]inputs = rerank_tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')with torch.no_grad():scores = rerank_model(**inputs).logits.squeeze(-1)# 按得分排序ranked = sorted(zip(docs, scores.tolist()), key=lambda x: x[1], reverse=True)return ranked# 执行 rerank
ranked_results = rerank(query, candidates)# 输出结果
print("查询:", query)
print("\nTop 文档:")
for i, (doc, score) in enumerate(ranked_results):print(f"{i+1}. [{score:.4f}] {doc}")