Benchmarking Embedding Models with MTEB
When developing RAG applications, we often need an effective way to evaluate the embedding model we use, to confirm that it performs well on the current domain and task.
A straightforward approach is to prepare an evaluation dataset, test the candidate embedding models on the specific task, collect their scores, and analyze the metrics to decide whether each model meets the needs of the use case.
Preparing an evaluation dataset, however, is time-consuming and labor-intensive. Fortunately, an open-source Python library has already done this work: MTEB, the subject of this article.
MTEB
MTEB (Massive Text Embedding Benchmark) is a collection of embedding benchmark datasets and, at the same time, a tool for running those benchmarks. With it, evaluating a model is very convenient and takes only a few lines of code.
The MTEB open-source repository: https://github.com/embeddings-benchmark/mteb
Installation
Install MTEB directly with pip:
pip install mteb
If access to the public PyPI index is slow, you can use a mirror:
pip install mteb -i https://mirrors.aliyun.com/pypi/simple
benchmark
Running a benchmark is very simple; let's walk through the official example:
import mteb
from sentence_transformers import SentenceTransformer

model_name = "average_word_embeddings_komninos"

# Load the model
model = mteb.get_model(model_name)

tasks = mteb.get_tasks(tasks=["Banking77Classification"])
evaluation = mteb.MTEB(tasks=tasks)

# Run the evaluation
results = evaluation.run(model, output_folder=f"results/{model_name}")
As the example shows, a model from the sentence-transformers library serves as the embedding model, and the Banking77Classification task is selected for scoring. Some background: MTEB ships with many predefined tasks, grouped into task types. Some are retrieval tasks, some clustering, some classification, and so on. Within each task type, tasks are built on datasets from different domains, so each task effectively encodes a [type, domain] pair. If, say, I want scores for the retrieval task type on the medical domain, I can look up the matching tasks and run the evaluation on just those, as sketched below.
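A minimal sketch of that kind of filtering (assuming the installed mteb version supports the task_types and domains filters on mteb.get_tasks; check your version's signature if the names differ):

import mteb

# Filter the task registry down to retrieval tasks built on
# medical-domain datasets, then list their names.
tasks = mteb.get_tasks(task_types=["Retrieval"], domains=["Medical"])
for task in tasks:
    print(task.metadata.name)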
After the run, the results can be found in the directory specified by the output_folder parameter. Each result is a JSON document that scores the model on multiple metrics; by analyzing those metrics, you can judge whether the model fits your scenario.
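For instance, a quick way to pull the headline metric out of each result file (a sketch only; the exact folder layout and JSON schema vary across mteb versions, and this targets the layout where each task result carries a "scores" key):

import json
from pathlib import Path

# Walk the output folder from the run above and print each task's main score.
for path in Path("results/average_word_embeddings_komninos").rglob("*.json"):
    data = json.loads(path.read_text())
    for split, entries in data.get("scores", {}).items():
        for entry in entries:
            print(data.get("task_name", path.stem), split, entry.get("main_score"))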
custom model
If the tool could only evaluate models from sentence-transformers, its usefulness would be severely limited. In practice, we often train or fine-tune domain-specific embedding models to get better results, so how do we evaluate such a custom embedding model?
It is actually straightforward: write a wrapper around the custom embedding model so that the wrapped model exposes an encode method (and, as the snippet below does for retrieval tasks, optionally encode_queries and encode_corpus); this article won't go into further detail.
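As a minimal sketch of such a wrapper (the wrapped model object and its embed method here are hypothetical placeholders; only the encode signature matters to MTEB):

from typing import Any, List

import numpy as np


class CustomEmbeddingWrapper:
    """Adapts a custom embedding model to the interface MTEB expects."""

    def __init__(self, my_model):
        # `my_model` is a hypothetical custom model exposing an `embed` method.
        self.my_model = my_model

    def encode(self, sentences: List[str], **kwargs: Any) -> np.ndarray:
        # MTEB only needs `encode` to map a list of texts to an array of
        # shape (len(sentences), embedding_dim).
        return np.asarray(self.my_model.embed(sentences))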
For a fuller, real-world example, refer to the following snippet, which wraps the LLM2Vec family of models and follows the pattern used by MTEB's built-in model integrations:
import logging
from typing import Any, Callable, Dict, List, Literal, Type, Union

import numpy as np
import torch

from mteb.encoder_interface import Encoder
from mteb.model_meta import ModelMeta
from mteb.models.text_formatting_utils import corpus_to_texts

from .instructions import task_to_instruction

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

EncodeTypes = Literal["query", "passage"]


def llm2vec_instruction(instruction):
    if len(instruction) > 0 and instruction[-1] != ":":
        instruction = instruction.strip(".") + ":"
    return instruction


class LLM2VecWrapper:
    def __init__(self, *args, **kwargs):
        try:
            from llm2vec import LLM2Vec
        except ImportError:
            raise ImportError(
                "To use the LLM2Vec models `llm2vec` is required. "
                "Please install it with `pip install llm2vec`."
            )
        extra_kwargs = {}
        try:
            import flash_attn  # noqa

            extra_kwargs["attn_implementation"] = "flash_attention_2"
        except ImportError:
            logger.warning(
                "LLM2Vec models were trained with flash attention enabled. "
                "For optimal performance, please install the `flash_attn` package "
                "with `pip install flash-attn --no-build-isolation`."
            )

        self.task_to_instructions = None
        if "task_to_instructions" in kwargs:
            self.task_to_instructions = kwargs.pop("task_to_instructions")

        if "device" in kwargs:
            kwargs["device_map"] = kwargs.pop("device")
        elif torch.cuda.device_count() > 1:
            # bug fix for multi-gpu
            kwargs["device_map"] = None

        self.model = LLM2Vec.from_pretrained(*args, **extra_kwargs, **kwargs)

    def encode(
        self,
        sentences: List[str],
        *,
        prompt_name: str = None,
        **kwargs: Any,  # noqa
    ) -> np.ndarray:
        if prompt_name is not None:
            instruction = (
                self.task_to_instructions[prompt_name]
                if self.task_to_instructions
                and prompt_name in self.task_to_instructions
                else llm2vec_instruction(task_to_instruction(prompt_name))
            )
        else:
            instruction = ""

        sentences = [[instruction, sentence] for sentence in sentences]
        return self.model.encode(sentences, **kwargs)

    def encode_corpus(
        self,
        corpus: Union[List[Dict[str, str]], Dict[str, List[str]], List[str]],
        prompt_name: str = None,
        **kwargs: Any,
    ) -> np.ndarray:
        sentences = corpus_to_texts(corpus, sep=" ")
        sentences = [["", sentence] for sentence in sentences]
        return self.model.encode(sentences, **kwargs)

    def encode_queries(self, queries: List[str], **kwargs: Any) -> np.ndarray:
        return self.encode(queries, **kwargs)


def _loader(wrapper: Type[LLM2VecWrapper], **kwargs) -> Callable[..., Encoder]:
    _kwargs = kwargs

    def loader_inner(**kwargs: Any) -> Encoder:
        return wrapper(**_kwargs, **kwargs)

    return loader_inner


llm2vec_llama3_8b_supervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,  # TODO: Not sure what to put here as a model is made of two peft repos, each with a different revision
    release_date="2024-04-09",
)

llm2vec_llama3_8b_unsupervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_mistral7b_supervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_mistral7b_unsupervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_llama2_7b_supervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-supervised",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-supervised",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_llama2_7b_unsupervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-unsup-simcse",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Llama-2-7b-chat-hf-mntp-unsup-simcse",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_sheared_llama_supervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-supervised",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-supervised",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)

llm2vec_sheared_llama_unsupervised = ModelMeta(
    loader=_loader(
        LLM2VecWrapper,
        base_model_name_or_path="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp",
        peft_model_name_or_path="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-unsup-simcse",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    ),
    name="McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-unsup-simcse",
    languages=["eng_Latn"],
    open_source=True,
    revision=None,
    release_date="2024-04-09",
)
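With the wrapper in place, running a benchmark works the same as before: evaluation.run accepts any object that exposes encode, so you can instantiate the wrapper directly. A sketch (assuming the snippet above is importable and llm2vec is installed; the output folder name is arbitrary):

import mteb
import torch

# Any object with an `encode` method can be passed to evaluation.run.
model = LLM2VecWrapper(
    base_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
    peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

tasks = mteb.get_tasks(tasks=["Banking77Classification"])
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(model, output_folder="results/llm2vec-llama3-8b-supervised")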