
Benchmarking Embedding Models with MTEB

When developing a RAG application, we often need an effective way to evaluate the embedding model we use, to confirm that it performs well on the current domain and task.

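A direct approach is to prepare an evaluation dataset, run each candidate embedding model on specific tasks, collect their scores, and analyze the metrics to decide whether a model meets the needs of the current use case.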

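Preparing an evaluation dataset takes a lot of time and effort. Fortunately, an open-source Python library already does this work for us: MTEB, the subject of this article.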

MTEB

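MTEB (Massive Text Embedding Benchmark) is both a benchmark for embedding models, built from a large collection of datasets, and a tool for running that benchmark. With it, evaluating a model is very convenient and can take just a few lines of code.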

The MTEB open-source repository is at: https://github.com/embeddings-benchmark/mteb

Installation

Install MTEB directly with pip:

pip install mteb

If access to the public PyPI repository is slow, you can use a mirror:

pip install mteb -i https://mirrors.aliyun.com/pypi/simple

benchmark

Running a benchmark is very simple. Let's walk through the official example:

```python
import mteb
from sentence_transformers import SentenceTransformer

model_name = "average_word_embeddings_komninos"

# Load the model; if MTEB does not define this model itself,
# this is equivalent to SentenceTransformer(model_name)
model = mteb.get_model(model_name)
tasks = mteb.get_tasks(tasks=["Banking77Classification"])
evaluation = mteb.MTEB(tasks=tasks)
# Run the evaluation
results = evaluation.run(model, output_folder=f"results/{model_name}")
```

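As the example shows, a model from the sentence-transformers library serves as the embedding model, and the Banking77Classification task is selected for scoring.

One more point worth adding: MTEB ships with many predefined tasks, grouped into task types: some are retrieval tasks, some clustering, some classification, and so on. Within each task type there are tasks built on datasets from different domains, so each task corresponds to a [type, domain] pair. For example, to get scores for the retrieval task type in the medical domain, I can pick out the matching tasks and run the evaluation on those.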
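To make the [type, domain] relationship concrete, here is a minimal sketch that groups tasks by (type, domain). It assumes each task object exposes metadata.name, metadata.type and metadata.domains (fields of MTEB's task metadata), where domains may be None. The helpers index_tasks and find_tasks are hypothetical, not part of MTEB, and the stub tasks are invented purely so the sketch runs without mteb installed.

```python
from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, Iterable, List, Tuple


def index_tasks(tasks: Iterable) -> Dict[Tuple[str, str], List[str]]:
    """Group task names by (task type, domain). Hypothetical helper, not an MTEB API."""
    index: Dict[Tuple[str, str], List[str]] = defaultdict(list)
    for task in tasks:
        meta = task.metadata
        # A task can be tagged with several domains, or with none.
        for domain in meta.domains or ["Unspecified"]:
            index[(meta.type, domain)].append(meta.name)
    return dict(index)


def find_tasks(tasks: Iterable, task_type: str, domain: str) -> List[str]:
    """Names of the tasks that match one [type, domain] pair."""
    return index_tasks(tasks).get((task_type, domain), [])


def _stub(name: str, task_type: str, domains) -> SimpleNamespace:
    # Hypothetical stand-in for an MTEB task object: only .metadata is modeled.
    return SimpleNamespace(metadata=SimpleNamespace(name=name, type=task_type, domains=domains))


# Invented example tasks (not real MTEB task names).
demo_tasks = [
    _stub("HypotheticalMedicalRetrieval", "Retrieval", ["Medical"]),
    _stub("HypotheticalLegalRetrieval", "Retrieval", ["Legal"]),
    _stub("HypotheticalClassification", "Classification", None),
]

print(find_tasks(demo_tasks, "Retrieval", "Medical"))  # ['HypotheticalMedicalRetrieval']
```

With mteb installed, index_tasks(mteb.get_tasks()) would build the same index over the real task registry. MTEB can also filter directly, e.g. mteb.get_tasks(task_types=["Retrieval"], domains=["Medical"]).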

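After the run, the results can be found in the directory given by the output_folder parameter. They are stored as JSON, with scores for multiple metrics; by analyzing these metrics you can judge whether the model meets the expectations of your scenario.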
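A minimal sketch for pulling the headline number out of each result file. It assumes the layout written by recent mteb 1.x releases: one JSON file per task somewhere under output_folder (possibly nested by model name and revision), with a task_name field and a scores mapping from split name to a list of per-subset dicts that each carry a main_score; model_meta.json is skipped. Older and newer mteb versions may use a different layout, so check the field names against the files your version actually writes. Both helper functions are hypothetical.

```python
import json
from pathlib import Path
from typing import Dict, Tuple


def summarize_result(data: dict, fallback_name: str = "") -> Tuple[str, Dict[str, float]]:
    """Return (task name, {split: mean main_score over subsets}) for one result file."""
    task = data.get("task_name", fallback_name)
    per_split: Dict[str, float] = {}
    for split, subsets in (data.get("scores") or {}).items():
        values = [s["main_score"] for s in subsets if isinstance(s, dict) and "main_score" in s]
        if values:
            # Average over subsets (e.g. languages); a single-subset task keeps its score.
            per_split[split] = sum(values) / len(values)
    return task, per_split


def collect_main_scores(output_folder: str) -> Dict[str, Dict[str, float]]:
    """Walk output_folder recursively and summarize every task-result JSON file."""
    summary: Dict[str, Dict[str, float]] = {}
    for path in sorted(Path(output_folder).rglob("*.json")):
        if path.name == "model_meta.json":
            continue  # model metadata, not a task result
        data = json.loads(path.read_text(encoding="utf-8"))
        # Skip anything that does not look like the assumed task-result format.
        if isinstance(data, dict) and isinstance(data.get("scores"), dict):
            task, per_split = summarize_result(data, fallback_name=path.stem)
            summary[task] = per_split
    return summary
```

For the official example, collect_main_scores("results/average_word_embeddings_komninos") would map Banking77Classification to its per-split main score. Running it on each candidate model's output folder gives a simple side-by-side comparison.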

custom model

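If we could only use models from sentence-transformers, the tool's usefulness would be severely limited. We often train or fine-tune domain-specific embedding models to get better results, so how do we evaluate such a custom embedding model?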

It's actually easy: wrap the custom embedding model so that the wrapper exposes an encode method (and similar methods where needed). This article won't go into the details.
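A minimal sketch of such a wrapper. It assumes only that MTEB calls encode with a list of sentences and expects a 2-D numpy array with one row per input. The keyword arguments MTEB passes (prompt_name in the LLM2Vec wrapper in this article, task_name in other releases) differ between mteb versions, so the sketch accepts and ignores extras. MyDomainEmbedder is hypothetical: a toy hashed bag-of-words embedder standing in for your fine-tuned model.

```python
import zlib
from typing import Any, List

import numpy as np


class MyDomainEmbedder:
    """Hypothetical stand-in for a fine-tuned domain model: hashed bag of words."""

    def __init__(self, dim: int = 64):
        self.dim = dim

    def embed(self, text: str) -> np.ndarray:
        vec = np.zeros(self.dim, dtype=np.float32)
        for token in text.lower().split():
            # crc32 is stable across processes (unlike Python's salted hash()).
            vec[zlib.crc32(token.encode("utf-8")) % self.dim] += 1.0
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec


class MTEBWrapper:
    """Adapts a model with a per-text embed() method to an encode() interface."""

    def __init__(self, model: MyDomainEmbedder):
        self.model = model

    def encode(self, sentences: List[str], **kwargs: Any) -> np.ndarray:
        # Extra keyword arguments (prompt/task names, batch_size, ...) are ignored here.
        rows = [self.model.embed(s) for s in sentences]
        if not rows:
            return np.zeros((0, self.model.dim), dtype=np.float32)
        return np.vstack(rows)


# Usage with mteb (not executed here; exact behavior depends on the mteb version):
#   evaluation = mteb.MTEB(tasks=mteb.get_tasks(tasks=["Banking77Classification"]))
#   evaluation.run(MTEBWrapper(MyDomainEmbedder()), output_folder="results/my-domain-model")
```

Keeping the adapter separate from the model means the same wrapper works for any model that can embed one text at a time. A real implementation would batch the calls for speed.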

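The following snippet can serve as a reference. It is the LLM2Vec wrapper from MTEB's own model definitions, so the relative import from .instructions only resolves inside the mteb package: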

```python
import logging
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union

import numpy as np
import torch

from mteb.encoder_interface import Encoder
from mteb.model_meta import ModelMeta
from mteb.models.text_formatting_utils import corpus_to_texts

# Relative import: only valid inside the mteb package (mteb/models/)
from .instructions import task_to_instruction

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

EncodeTypes = Literal["query", "passage"]


def llm2vec_instruction(instruction):
    # Normalize the instruction so that it ends with a colon
    if len(instruction) > 0 and instruction[-1] != ":":
        instruction = instruction.strip(".") + ":"
    return instruction


class LLM2VecWrapper:
    def __init__(self, *args, **kwargs):
        try:
            from llm2vec import LLM2Vec
        except ImportError:
            raise ImportError(
                "To use the LLM2Vec models `llm2vec` is required. Please install it with `pip install llm2vec`."
            )
        extra_kwargs = {}
        try:
            import flash_attn  # noqa

            extra_kwargs["attn_implementation"] = "flash_attention_2"
        except ImportError:
            logger.warning(
                "LLM2Vec models were trained with flash attention enabled. For optimal performance, please install the `flash_attn` package with `pip install flash-attn --no-build-isolation`."
            )
        self.task_to_instructions = None
        if "task_to_instructions" in kwargs:
            self.task_to_instructions = kwargs.pop("task_to_instructions")

        if "device" in kwargs:
            kwargs["device_map"] = kwargs.pop("device")
        elif torch.cuda.device_count() > 1:
            # bug fix for multi-gpu
            kwargs["device_map"] = None

        self.model = LLM2Vec.from_pretrained(*args, **extra_kwargs, **kwargs)

    def encode(
        self,
        sentences: List[str],
        *,
        prompt_name: Optional[str] = None,
        **kwargs: Any,  # noqa
    ) -> np.ndarray:
        if prompt_name is not None:
            instruction = (
                self.task_to_instructions[prompt_name]
                if self.task_to_instructions
                and prompt_name in self.task_to_instructions
                else llm2vec_instruction(task_to_instruction(prompt_name))
            )
        else:
            instruction = ""

        # LLM2Vec encodes [instruction, text] pairs
        sentences = [[instruction, sentence] for sentence in sentences]
        return self.model.encode(sentences, **kwargs)

    def encode_corpus(
        self,
        corpus: Union[List[Dict[str, str]], Dict[str, List[str]], List[str]],
        prompt_name: Optional[str] = None,
        **kwargs: Any,
    ) -> np.ndarray:
        sentences = corpus_to_texts(corpus, sep=" ")
        # Corpus documents are encoded with an empty instruction
        sentences = [["", sentence] for sentence in sentences]
        return self.model.encode(sentences, **kwargs)

    def encode_queries(self, queries: List[str], **kwargs: Any) -> np.ndarray:
        return self.encode(queries, **kwargs)


def _loader(wrapper: Type[LLM2VecWrapper], **kwargs) -> Callable[..., Encoder]:
    _kwargs = kwargs

    def loader_inner(**kwargs: Any) -> Encoder:
        return wrapper(**_kwargs, **kwargs)

    return loader_inner


llm2vec_llama3_8b_supervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,  # TODO: Not sure what to put here as a model is made of two peft repos, each with a different revision
    release_date="2024-04-09",
)

llm2vec_llama3_8b_unsupervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_mistral7b_supervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_mistral7b_unsupervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_llama2_7b_supervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-supervised",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-supervised",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_llama2_7b_unsupervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-unsup-simcse",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-unsup-simcse",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_sheared_llama_supervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-supervised",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-supervised",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_sheared_llama_unsupervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-unsup-simcse",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-unsup-simcse",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)
```
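The _loader helper in the snippet defers construction. Each ModelMeta stores a closure with the checkpoint arguments already bound, and the heavy LLM2Vec.from_pretrained call only happens when that closure is invoked. The sketch below reproduces the same closure with a hypothetical FakeWrapper (no torch, no model download) to show how bound and call-time keyword arguments combine.

```python
from typing import Any, Callable, Type


class FakeWrapper:
    """Hypothetical stand-in for LLM2VecWrapper that only records its kwargs."""

    instances_created = 0

    def __init__(self, **kwargs: Any):
        FakeWrapper.instances_created += 1
        self.kwargs = kwargs


def _loader(wrapper: Type, **kwargs: Any) -> Callable[..., Any]:
    # Same shape as the snippet's _loader: bind kwargs now, construct later.
    _kwargs = kwargs

    def loader_inner(**kwargs: Any) -> Any:
        return wrapper(**_kwargs, **kwargs)

    return loader_inner


load = _loader(FakeWrapper, base_model_name_or_path="base", device_map="auto")
print(FakeWrapper.instances_created)  # 0: nothing has been built yet

model = load(batch_size=8)
print(model.kwargs)  # {'base_model_name_or_path': 'base', 'device_map': 'auto', 'batch_size': 8}

try:
    load(device_map="cpu")  # device_map is already bound, so this is a duplicate keyword
except TypeError as err:
    print("TypeError:", err)
```

Both dicts are unpacked into a single call, so passing a key that the loader already binds raises TypeError before the wrapper is constructed rather than overriding the default. In the real wrapper, this is why a different device has to arrive under the separate device keyword, which LLM2VecWrapper.__init__ maps onto device_map itself.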