BGE-large-zh-v1.5微调
安装FlagEmbedding:
pip install -U FlagEmbedding[finetune]
训练数据应该是一个 json 文件,其中每一行都是一个像这样的字典:
{"query": str, "pos": List[str], "neg":List[str], "pos_scores": List[int], "neg_scores": List[int], "prompt": str, "type": str}
- query 是查询。
- pos 是正向文本列表。
- neg 是负向文本列表。
- pos_scores 是查询和 pos 对应的分数列表,neg_scores 是查询和 neg 对应的分数列表,如果不使用知识蒸馏,可以忽略。
- prompt 是查询的提示,它会覆盖 query_instruction_for_retrieval。
- type 用于 bge-en-icl,它包括 normal、symmetric_class、symmetric_clustering 等。
如果查询没有负向文本,则可以从整个语料库中随机抽取一些作为负向文本。
示例:
{"query": "我的设备不能用了,", "pos": ["您好,我们的严格遵循三包政策:无人为损坏,机器本身质量问题,7天退货,30天换货(用户承担寄过来的运费),1年保修,超出1年付费维修。如需更多帮助,请回复“RG”转人工客服处理。感谢您的配合!"]}
挖掘难负样本(Hard Negatives)命令:
python hn_mine.py \
--input_file /data1/tlw/Embedding_Finetune/data/bge_train_data.jsonl \
--candidate_pool /data1/tlw/Embedding_Finetune/data/bge_corpus_answers_only.jsonl \
--output_file /data1/tlw/Embedding_Finetune/data/bge_training_data_with_HN.jsonl \
--range_for_sampling 2-200 \
--negative_number 15 \
--use_gpu_for_searching \
--embedder_name_or_path /data1/models/bge-large-zh-v1.5 \
--embedder_model_class encoder-only-base \
--batch_size 256
- input_file :用于微调的 json 数据。此脚本将检索每个query的前 k 个文档,并从前 k
个文档中随机抽取负样本(不包括正样本)。 - candidate_pool:要检索的池。默认值为 None,此脚本将从 input_file
中所有负样本的组合中进行检索。如果提供,则应为一个 jsonl 文件,里面包含所有唯一答案(要去重),每行是一个包含 key为 "text"的字典。如果输入的是候选池,此脚本将从该文件中检索负样本。 示例:
{"text": "您好,如划款失败,请 **确保** 收款卡为注册人名下的一类借记卡且未被银行冻结、注销或挂失,过期等。改卡成功后请关注明日10点新卡自动到账情况。"}
- output_file :保存 JSON 数据的路径,其中包含用于微调的难负样本。
- range_for_sampling:在哪里采样阴性。例如,2-100 表示从 top2-top200 文档中抽取
negative_number 负样本。您可以设置较大的值以降低负样本的难度(例如,将其设置为 60-300 以从 top60-300
段落中采样负片) - negative_number :采样负数的数量。
- use_gpu_for_searching :是否使用 faiss-gpu 来检索负数。
- embedder_name_or_path :嵌入器的名称或路径。
- embedder_model_class:嵌入器模型类。(当使用微调后模型挖掘难负样本需要传入这个参数)
训练命令:
torchrun --nproc_per_node 2 \-m FlagEmbedding.finetune.embedder.encoder_only.base \--model_name_or_path /data1/models/bge-large-zh-v1.5 \--cache_dir ./cache/model \--train_data /data1/tlw/Embedding_Finetune/data/bge_training_data_with_HN.jsonl \--cache_path ./cache/data \--train_group_size 8 \--query_max_len 512 \--passage_max_len 512 \--pad_to_multiple_of 8 \--query_instruction_for_retrieval '为这个句子生成表示以用于检索相关文章:' \--query_instruction_format '{}{}' \--knowledge_distillation False \--output_dir ./finetuned_models/bge-large-en-v1.5-finetuned-0905 \--overwrite_output_dir \--learning_rate 1e-5 \--fp16 \--num_train_epochs 5 \--per_device_train_batch_size 64 \--dataloader_drop_last True \--warmup_ratio 0.1 \--gradient_checkpointing \--deepspeed ../ds_stage0.json \--logging_steps 1 \--save_steps 100 \--negatives_cross_device \--temperature 0.02 \--sentence_pooling_method cls \--normalize_embeddings True \--kd_loss_type kl_div
–query_instruction_for_retrieval ‘为这个句子生成表示以用于检索相关文章:’,训练的时候加上,推理使用的时候也需要加上(只有查询的时候需要加,索引阶段文档向量化不用加)。
训练数据:共11366条问答对,9:1拆分训练测试集。
batch_size=64,两张48G 4090显卡,约占64G显存,训练时长约50分钟。
微调结果:
📊 ANSWER POOL COMPARISONTest-based pool size: 217Global pool size: 496Size ratio (global/test): 2.29x🎯 KEY METRICS COMPARISON
------------------------------------------------------------Metric Test Pool Improvement Global Pool Improvement Differencerecall@1 +0.4776 +0.5462 +0.0686recall@5 +0.3456 +0.4459 +0.1003
recall@10 +0.2656 +0.3509 +0.0853mrr@1 +0.4776 +0.5462 +0.0686mrr@5 +0.4357 +0.5175 +0.0819📈 DETAILED RESULTS COMPARISON
--------------------------------------------------------------------------------Metric Test Pool Original Test Pool Finetuned Global Pool Original Global Pool Finetunedrecall@1 0.3843 0.8619 0.2814 0.8276recall@5 0.6324 0.9780 0.5207 0.9666
recall@10 0.7194 0.9850 0.6297 0.9807mrr@1 0.3843 0.8619 0.2814 0.8276mrr@5 0.4755 0.9111 0.3703 0.8878💡 ANALYSIS & RECOMMENDATIONS
----------------------------------------
Recall@1 improvement:• Test pool: +0.4776• Global pool: +0.5462⚠️ Significant difference detected!The global pool evaluation shows better resultsThis suggests the test-based evaluation might be pessimistic✅ Recommendation: Use global pool evaluation for more realistic performance assessment📊 SIMILARITY DISTRIBUTION COMPARISON
----------------------------------------
Separation (pos_mean - neg_mean) after finetuning:• Test pool: 0.3484• Global pool: 0.3615
参考:https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune/embedder