当前位置: 首页 > news >正文

【RAG排序】rag排序代码示例-高级版

以下是利用claude生成的排序示例,相对来说高级一些,例如使用了图排序、混合排序、mmr等技术。

代码是示例代码,受输出长度限制,无法给出完整例子,在最后对输入的query、document_embedding等进行了实例展示。可以参考“使用案例解释”尝试进行修改和运行。

RAG系统排序阶段的多种方法与实现

1. 基础排序方法

1.1 余弦相似度排序

最基本的相似度计算方法,适用于向量检索后的重排序。

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import torch
import torch.nn.functional as Fdef cosine_similarity_ranking(query_embedding, doc_embeddings, documents, top_k=10):"""基于余弦相似度的文档排序"""similarities = cosine_similarity(query_embedding.reshape(1, -1), doc_embeddings)[0]# 获取排序索引sorted_indices = np.argsort(similarities)[::-1][:top_k]ranked_docs = []for idx in sorted_indices:ranked_docs.append({'document': documents[idx],'score': similarities[idx],'rank': len(ranked_docs) + 1})return ranked_docs# PyTorch版本
def cosine_similarity_torch(query_emb, doc_embs):"""使用PyTorch计算余弦相似度"""query_emb = F.normalize(query_emb, p=2, dim=-1)doc_embs = F.normalize(doc_embs, p=2, dim=-1)similarities = torch.mm(query_emb.unsqueeze(0), doc_embs.T)return similarities.squeeze()

1.2 欧几里得距离排序

def euclidean_distance_ranking(query_embedding, doc_embeddings, documents, top_k=10):"""基于欧几里得距离的文档排序(距离越小,相关性越高)"""distances = np.linalg.norm(doc_embeddings - query_embedding, axis=1)# 距离越小越相关,所以升序排列sorted_indices = np.argsort(distances)[:top_k]ranked_docs = []for idx in sorted_indices:ranked_docs.append({'document': documents[idx],'distance': distances[idx],'score': 1 / (1 + distances[idx]),  # 转换为相似度分数'rank': len(ranked_docs) + 1})return ranked_docs

2. 基于深度学习的重排序模型

2.1 Cross-Encoder重排序

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torchclass CrossEncoderRanker:def __init__(self, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):self.tokenizer = AutoTokenizer.from_pretrained(model_name)self.model = AutoModelForSequenceClassification.from_pretrained(model_name)self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.model.to(self.device)self.model.eval()def rank_documents(self, query, documents, top_k=10):"""使用Cross-Encoder对文档进行重排序"""query_doc_pairs = [(query, doc) for doc in documents]# 分批处理以避免内存问题batch_size = 32scores = []for i in range(0, len(query_doc_pairs), batch_size):batch_pairs = query_doc_pairs[i:i+batch_size]# 编码输入inputs = self.tokenizer(batch_pairs,padding=True,truncation=True,max_length=512,return_tensors="pt").to(self.device)with torch.no_grad():outputs = self.model(**inputs)batch_scores = torch.softmax(outputs.logits, dim=-1)[:, 1].cpu().numpy()scores.extend(batch_scores)# 排序并返回top_kscored_docs = list(zip(documents, scores))scored_docs.sort(key=lambda x: x[1], reverse=True)return [{'document': doc,'score': score,'rank': i + 1}for i, (doc, score) in enumerate(scored_docs[:top_k])]

2.2 Bi-Encoder重排序

from sentence_transformers import SentenceTransformer
import numpy as npclass BiEncoderRanker:def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):self.model = SentenceTransformer(model_name)def rank_documents(self, query, documents, top_k=10):"""使用Bi-Encoder对文档进行重排序"""# 编码查询和文档query_embedding = self.model.encode([query], convert_to_tensor=True)doc_embeddings = self.model.encode(documents, convert_to_tensor=True)# 计算相似度similarities = torch.cosine_similarity(query_embedding.unsqueeze(1), doc_embeddings.unsqueeze(0), dim=2).squeeze().cpu().numpy()# 排序sorted_indices = np.argsort(similarities)[::-1][:top_k]return [{'document': documents[idx],'score': similarities[idx],'rank': i + 1}for i, idx in enumerate(sorted_indices)]

3. 混合排序方法

3.1 多因子加权排序

class MultiFactorRanker:def __init__(self, weights=None):self.weights = weights or {'semantic_similarity': 0.4,'keyword_match': 0.3,'document_length': 0.1,'freshness': 0.1,'authority': 0.1}def calculate_keyword_score(self, query, document):"""计算关键词匹配分数"""query_words = set(query.lower().split())doc_words = set(document.lower().split())intersection = query_words.intersection(doc_words)union = query_words.union(doc_words)return len(intersection) / len(union) if union else 0def calculate_length_score(self, document, optimal_length=500):"""计算文档长度分数"""length = len(document.split())return 1 / (1 + abs(length - optimal_length) / optimal_length)def calculate_freshness_score(self, timestamp, current_time):"""计算时效性分数"""age_days = (current_time - timestamp).daysreturn 1 / (1 + age_days / 30)  # 30天为半衰期def rank_documents(self, query, documents, doc_embeddings=None, query_embedding=None, metadata=None, top_k=10):"""综合多个因子进行排序"""scores = []for i, doc in enumerate(documents):doc_score = 0# 语义相似度if doc_embeddings is not None and query_embedding is not None:semantic_sim = cosine_similarity(query_embedding.reshape(1, -1),doc_embeddings[i].reshape(1, -1))[0][0]doc_score += self.weights['semantic_similarity'] * semantic_sim# 关键词匹配keyword_score = self.calculate_keyword_score(query, doc)doc_score += self.weights['keyword_match'] * keyword_score# 文档长度length_score = self.calculate_length_score(doc)doc_score += self.weights['document_length'] * length_score# 时效性和权威性(如果有元数据)if metadata and i < len(metadata):if 'timestamp' in metadata[i]:freshness_score = self.calculate_freshness_score(metadata[i]['timestamp'], metadata[i].get('current_time', datetime.now()))doc_score += self.weights['freshness'] * freshness_scoreif 'authority_score' in metadata[i]:doc_score += self.weights['authority'] * metadata[i]['authority_score']scores.append((doc, doc_score, i))# 排序并返回top_kscores.sort(key=lambda x: x[1], reverse=True)return [{'document': doc,'score': score,'original_index': idx,'rank': i + 1}for i, (doc, score, idx) in enumerate(scores[:top_k])]

3.2 学习到排序(Learning to Rank)

import lightgbm as lgb
from sklearn.model_selection import train_test_split
import pandas as pdclass LearningToRankModel:def __init__(self):self.model = Noneself.feature_names = ['cosine_similarity','keyword_match_score','bm25_score','document_length','query_length','common_words_ratio','edit_distance_norm']def extract_features(self, query, document, query_emb=None, doc_emb=None):"""提取特征"""features = {}# 余弦相似度if query_emb is not None and doc_emb is not None:features['cosine_similarity'] = cosine_similarity(query_emb.reshape(1, -1), doc_emb.reshape(1, -1))[0][0]else:features['cosine_similarity'] = 0# 关键词匹配分数query_words = set(query.lower().split())doc_words = set(document.lower().split())intersection = query_words.intersection(doc_words)features['keyword_match_score'] = len(intersection) / len(query_words) if query_words else 0# BM25分数(简化版)features['bm25_score'] = self.simple_bm25(query, document)# 文档和查询长度features['document_length'] = len(document.split())features['query_length'] = len(query.split())# 公共词比例features['common_words_ratio'] = len(intersection) / len(query_words.union(doc_words)) if query_words.union(doc_words) else 0# 编辑距离(归一化)from difflib import SequenceMatcherfeatures['edit_distance_norm'] = SequenceMatcher(None, query.lower(), document.lower()[:len(query)*2]).ratio()return featuresdef simple_bm25(self, query, document, k1=1.2, b=0.75):"""简化的BM25计算"""query_terms = query.lower().split()doc_terms = document.lower().split()doc_length = len(doc_terms)avg_doc_length = 100  # 假设平均文档长度score = 0for term in query_terms:tf = doc_terms.count(term)if tf > 0:idf = 1  # 简化,实际应该计算逆文档频率score += idf * (tf * (k1 + 1)) / (tf + k1 * (1 - b + b * (doc_length / avg_doc_length)))return scoredef train(self, training_data):"""训练排序模型training_data: List of (query, documents, relevance_scores)"""features_list = []labels_list = []groups = []for query, documents, relevance_scores in training_data:group_size = len(documents)groups.append(group_size)for doc, relevance in zip(documents, relevance_scores):features = self.extract_features(query, doc)features_list.append(list(features.values()))labels_list.append(relevance)# 创建LightGBM数据集train_data = lgb.Dataset(features_list, label=labels_list, group=groups,feature_name=self.feature_names)# 训练模型params = {'objective': 'lambdarank','metric': 'ndcg','ndcg_eval_at': [1, 3, 5, 10],'num_leaves': 31,'learning_rate': 0.05,'feature_fraction': 0.9}self.model = lgb.train(params,train_data,num_boost_round=100,valid_sets=[train_data],callbacks=[lgb.early_stopping(10)])def rank_documents(self, query, documents, top_k=10):"""使用训练好的模型对文档排序"""if self.model is None:raise ValueError("Model not trained yet!")features_list = []for doc in documents:features = self.extract_features(query, doc)features_list.append(list(features.values()))scores = self.model.predict(features_list)# 排序scored_docs = list(zip(documents, scores))scored_docs.sort(key=lambda x: x[1], reverse=True)return [{'document': doc,'score': score,'rank': i + 1}for i, (doc, score) in enumerate(scored_docs[:top_k])]

4. 高级排序策略

4.1 多样性排序(MMR - Maximal Marginal Relevance)

def maximal_marginal_relevance(query_embedding, doc_embeddings, documents, lambda_param=0.7, top_k=10):"""最大边际相关性排序,平衡相关性和多样性lambda_param: 控制相关性和多样性的权重"""if len(documents) == 0:return []# 计算与查询的相似度query_similarities = cosine_similarity(query_embedding.reshape(1, -1), doc_embeddings)[0]selected = []remaining_indices = list(range(len(documents)))# 选择第一个最相关的文档first_idx = np.argmax(query_similarities)selected.append(first_idx)remaining_indices.remove(first_idx)# 迭代选择剩余文档while len(selected) < top_k and remaining_indices:mmr_scores = []for idx in remaining_indices:# 相关性分数relevance_score = query_similarities[idx]# 与已选择文档的最大相似度if selected:selected_embeddings = doc_embeddings[selected]current_embedding = doc_embeddings[idx].reshape(1, -1)similarities_to_selected = cosine_similarity(current_embedding, selected_embeddings)[0]max_similarity = np.max(similarities_to_selected)else:max_similarity = 0# MMR分数mmr_score = (lambda_param * relevance_score - (1 - lambda_param) * max_similarity)mmr_scores.append((idx, mmr_score))# 选择MMR分数最高的文档best_idx, best_score = max(mmr_scores, key=lambda x: x[1])selected.append(best_idx)remaining_indices.remove(best_idx)# 构建结果result = []for i, idx in enumerate(selected):result.append({'document': documents[idx],'relevance_score': query_similarities[idx],'rank': i + 1})return result

4.2 基于图的排序(PageRank风格)

import networkx as nxclass GraphBasedRanker:def __init__(self, similarity_threshold=0.5):self.similarity_threshold = similarity_thresholddef build_similarity_graph(self, doc_embeddings, documents):"""构建文档相似度图"""G = nx.Graph()# 添加节点for i, doc in enumerate(documents):G.add_node(i, document=doc)# 添加边(基于相似度)n_docs = len(documents)similarities = cosine_similarity(doc_embeddings)for i in range(n_docs):for j in range(i + 1, n_docs):similarity = similarities[i][j]if similarity > self.similarity_threshold:G.add_edge(i, j, weight=similarity)return Gdef rank_documents(self, query_embedding, doc_embeddings, documents, alpha=0.85, max_iter=100, top_k=10):"""使用类似PageRank的算法对文档排序"""# 构建相似度图G = self.build_similarity_graph(doc_embeddings, documents)# 计算与查询的相似度作为个性化向量query_similarities = cosine_similarity(query_embedding.reshape(1, -1), doc_embeddings)[0]# 归一化个性化向量personalization = {}total_sim = np.sum(query_similarities)for i, sim in enumerate(query_similarities):personalization[i] = sim / total_sim if total_sim > 0 else 1/len(documents)# 运行个性化PageRanktry:pagerank_scores = nx.pagerank(G, alpha=alpha, personalization=personalization,max_iter=max_iter)except:# 如果图不连通,回退到基础相似度排序pagerank_scores = {i: sim for i, sim in enumerate(query_similarities)}# 排序并返回结果sorted_docs = sorted(pagerank_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]result = []for rank, (doc_idx, score) in enumerate(sorted_docs):result.append({'document': documents[doc_idx],'pagerank_score': score,'query_similarity': query_similarities[doc_idx],'rank': rank + 1})return result

5. 实时自适应排序

5.1 基于用户反馈的动态排序

class AdaptiveRanker:def __init__(self, learning_rate=0.1):self.learning_rate = learning_rateself.user_preferences = {}self.click_through_rates = {}self.feature_weights = {'semantic_similarity': 0.4,'keyword_match': 0.3,'user_preference': 0.2,'ctr': 0.1}def update_user_preference(self, user_id, query, clicked_docs, shown_docs):"""根据用户点击行为更新偏好"""if user_id not in self.user_preferences:self.user_preferences[user_id] = {}# 更新点击率for doc in shown_docs:doc_key = hash(doc)if doc_key not in self.click_through_rates:self.click_through_rates[doc_key] = {'clicks': 0, 'shows': 0}self.click_through_rates[doc_key]['shows'] += 1if doc in clicked_docs:self.click_through_rates[doc_key]['clicks'] += 1# 更新用户偏好(简化版本)query_key = hash(query)if query_key not in self.user_preferences[user_id]:self.user_preferences[user_id][query_key] = {}for doc in clicked_docs:doc_key = hash(doc)if doc_key not in self.user_preferences[user_id][query_key]:self.user_preferences[user_id][query_key][doc_key] = 0self.user_preferences[user_id][query_key][doc_key] += self.learning_ratedef get_user_preference_score(self, user_id, query, document):"""获取用户偏好分数"""if user_id not in self.user_preferences:return 0query_key = hash(query)doc_key = hash(document)return self.user_preferences[user_id].get(query_key, {}).get(doc_key, 0)def get_ctr_score(self, document):"""获取点击率分数"""doc_key = hash(document)if doc_key not in self.click_through_rates:return 0stats = self.click_through_rates[doc_key]if stats['shows'] == 0:return 0return stats['clicks'] / stats['shows']def rank_documents(self, query, documents, query_embedding=None, doc_embeddings=None, user_id=None, top_k=10):"""自适应文档排序"""scores = []for i, doc in enumerate(documents):score = 0# 语义相似度if query_embedding is not None and doc_embeddings is not None:semantic_sim = cosine_similarity(query_embedding.reshape(1, -1),doc_embeddings[i].reshape(1, -1))[0][0]score += self.feature_weights['semantic_similarity'] * semantic_sim# 关键词匹配query_words = set(query.lower().split())doc_words = set(doc.lower().split())keyword_score = len(query_words.intersection(doc_words)) / len(query_words) if query_words else 0score += self.feature_weights['keyword_match'] * keyword_score# 用户偏好if user_id:user_pref_score = self.get_user_preference_score(user_id, query, doc)score += self.feature_weights['user_preference'] * user_pref_score# 点击率ctr_score = self.get_ctr_score(doc)score += self.feature_weights['ctr'] * ctr_scorescores.append((doc, score))# 排序scores.sort(key=lambda x: x[1], reverse=True)return [{'document': doc,'score': score,'rank': i + 1}for i, (doc, score) in enumerate(scores[:top_k])]

6. 集成排序框架

6.1 多模型集成排序器

class EnsembleRanker:def __init__(self):self.rankers = {}self.weights = {}def add_ranker(self, name, ranker, weight=1.0):"""添加排序器"""self.rankers[name] = rankerself.weights[name] = weightdef rank_documents(self, query, documents, top_k=10, **kwargs):"""集成多个排序器的结果"""all_rankings = {}# 获取每个排序器的结果for name, ranker in self.rankers.items():try:rankings = ranker.rank_documents(query, documents, top_k=len(documents), **kwargs)all_rankings[name] = rankingsexcept Exception as e:print(f"Ranker {name} failed: {e}")continueif not all_rankings:return []# 计算加权平均分数doc_scores = {}for doc in documents:doc_scores[doc] = 0total_weight = 0for name, rankings in all_rankings.items():weight = self.weights[name]# 找到该文档在当前排序中的分数doc_score = 0for rank_info in rankings:if rank_info['document'] == doc:doc_score = rank_info.get('score', 0)breakdoc_scores[doc] += weight * doc_scoretotal_weight += weightif total_weight > 0:doc_scores[doc] /= total_weight# 排序并返回结果sorted_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)return [{'document': doc,'ensemble_score': score,'rank': i + 1}for i, (doc, score) in enumerate(sorted_docs[:top_k])]# 使用示例
def create_ensemble_ranker():"""创建集成排序器示例"""ensemble = EnsembleRanker()# 添加不同的排序器ensemble.add_ranker('cross_encoder', CrossEncoderRanker(), weight=0.4)ensemble.add_ranker('bi_encoder', BiEncoderRanker(), weight=0.3)ensemble.add_ranker('multi_factor', MultiFactorRanker(), weight=0.3)return ensemble

7. 性能优化技巧

7.1 批量处理和缓存

import functools
from functools import lru_cache
import hashlibclass OptimizedRanker:def __init__(self, cache_size=1000):self.cache_size = cache_sizeself.embedding_cache = {}@lru_cache(maxsize=1000)def cached_similarity(self, query_hash, doc_hash):"""缓存相似度计算结果"""# 这里应该是实际的相似度计算passdef get_query_hash(self, query):"""生成查询的哈希值"""return hashlib.md5(query.encode()).hexdigest()def get_doc_hash(self, document):"""生成文档的哈希值"""return hashlib.md5(document.encode()).hexdigest()def batch_rank_documents(self, queries, documents_list, top_k=10):"""批量处理多个查询"""results = []# 批量编码所有唯一的查询和文档unique_queries = list(set(queries))all_documents = list(set([doc for docs in documents_list for doc in docs]))# 这里可以使用批量编码来提高效率# query_embeddings = self.model.encode(unique_queries)# doc_embeddings = self.model.encode(all_documents)for i, (query, documents) in enumerate(zip(queries, documents_list)):# 使用预计算的嵌入进行排序ranking = self.rank_single_query(query, documents, top_k)results.append(ranking)return results

总结

这些排序方法各有优缺点:

  1. 基础方法(余弦相似度、欧几里得距离):计算简单,适合实时场景
  2. 深度学习方法(Cross-Encoder、Bi-Encoder):效果好但计算开销大
  3. 混合方法:平衡多个因素,更全面但复杂度高
  4. 高级策略(MMR、图排序):考虑多样性和文档间关系
  5. 自适应方法:能够学习用户偏好,持续改进

在实际应用中,建议:

  • 先用简单方法建立基线
  • 根据业务需求选择合适的排序策略
  • 使用A/B测试验证效果
  • 考虑计算资源和响应时间的平衡
  • 实施缓存和批量处理来优化性能

使用案例解释

🔍 1. query (查询文本)

  • 类型: str
  • 示例: "人工智能在医疗领域有哪些应用?"
  • 特点: 用户的原始问题,通常较短(10-100字符)
  • 作用: 表达用户的信息需求

📚 2. documents (候选文档列表)

  • 类型: List[str]
  • 示例:
documents = ["人工智能在医疗诊断中发挥重要作用,通过机器学习算法...","医疗AI还广泛应用于药物发现领域,通过深度学习模型...","AI在医疗健康管理中的应用包括个性化治疗方案制定..."
]
  • 特点: 来自初步检索的候选文档,需要重新排序
  • 长度: 通常每个文档几十到几百个字符

🧮 3. query_embedding (查询嵌入向量)

  • 类型: np.ndarraytorch.Tensor
  • 形状: (768,)(384,) 等,取决于模型
  • 示例: array([0.1234, -0.5678, 0.9012, ...])
  • 特点: 查询文本的数值化向量表示,包含语义信息
  • 生成方式: 通过BERT、Sentence-BERT等模型编码得到

📊 4. doc_embeddings (文档嵌入向量矩阵)

  • 类型: np.ndarraytorch.Tensor
  • 形状: (文档数量, 向量维度)(5, 768)
  • 示例:
doc_embeddings = array([[0.1111, -0.2222, 0.3333, ...],  # 文档1的向量[0.4444, -0.5555, 0.6666, ...],  # 文档2的向量[0.7777, -0.8888, 0.9999, ...]   # 文档3的向量
])
  • 特点: 所有候选文档的向量表示矩阵

🔄 数据流转过程

  1. 输入阶段: 接收query文本和documents列表
  2. 编码阶段: 将文本转换为query_embeddingdoc_embeddings
  3. 计算阶段: 计算查询向量与文档向量的相似度
  4. 排序阶段: 根据相似度分数重新排序文档
  5. 输出阶段: 返回排序后的文档列表

💡 为什么需要这4个参数?

  • query + documents: 提供原始文本,便于最终展示和理解
  • query_embedding + doc_embeddings: 提供数值化表示,便于计算相似度

这种设计既保留了文本的可读性,又利用了向量的计算效率,是RAG系统中的标准做法。不同的排序算法可能只使用其中部分参数,但这4个参数涵盖了大多数排序方法的需求。

相关文章:

  • 基于PHP的连锁酒店管理系统
  • 英国云服务器上安装宝塔面板(BT Panel)
  • cie数通的含金量高吗?费用多少?
  • MySQL--慢查询日志、日志分析工具mysqldumpslow
  • 由于 z(x,y) 的变化导致的影响(那部分被分给了链式项)
  • 动画直播如何颠覆传统?解析足球篮球赛事的数据可视化革命
  • 深度剖析OpenSSL心脏滴血漏洞与Struts2远程命令执行漏洞
  • ShuffleNet 改进:与通道注意力机制(CAM)的结合实现
  • python报错 ModuleNotFoundError: No module named ‘Crypto‘
  • SpringAI实战:ChatModel智能对话全解
  • [Linux] 命令行管理文件
  • Spring Boot 启动流程详解
  • 安装便捷、维护省心,强力巨彩租赁屏助力视觉体验升级
  • LeetCode - 647. 回文子串
  • 求问,PMP属于职称认证吗?
  • PH热榜 | 2025-06-07
  • Redux Toolkit 快速入门指南:createSlice、configureStore、useSelector、useDispatch 全面解析
  • eNSP-IP数据包分析
  • (纳芯微)NST86-DSCR 精度±0.5℃,低功耗模拟输出温度传感器(-10.9mV/℃)负温度系数
  • CMIP6气候模式资料概览
  • 工程网站建设方案/网站关键词怎么优化排名
  • 经营性网站icp/网址大全123
  • 建网站哪个好/怎样创建网站平台
  • 做爰在线网站/竞价排名的弊端
  • 网站建设所需要的内容/百度网址大全电脑版旧版本
  • 重庆的企业网站/快手seo软件下载