当前位置: 首页 > news >正文

Embedding模型微调实战(ms-swift框架)

目录

简介

1. 创建虚拟环境

2 安装ms-swift

3安装其他依赖库

4. 下载数据集

5.开始embedding模型训练

6. 自定义数据格式和对应的Loss类型

(1) infoNCE损失     

(2)余弦相似度损失

(3)对比学习损失

(4).在线对比学习损失

(5)损失函数总结


简介

ms-swift是魔搭社区提供的大模型与多模态大模型微调部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。其中大模型包括:Qwen3、Qwen3-MoE、Qwen2.5、InternLM3、GLM4、Mistral、DeepSeek-R1、Yi1.5、TeleChat2、Baichuan2、Gemma2等模型,多模态大模型包括:Qwen2.5-VL、Qwen2-Audio、Llama4、Llava、InternVL3、MiniCPM-V-2.6、GLM4v、Xcomposer2.5、Yi-VL、DeepSeek-VL2、Phi3.5-Vision、GOT-OCR2等模型。

🍔 除此之外,ms-swift汇集了最新的训练技术,包括LoRA、QLoRA、Llama-Pro、LongLoRA、GaLore、Q-GaLore、LoRA+、LISA、DoRA、FourierFt、ReFT、UnSloth、和Liger等轻量化训练技术,以及DPO、GRPO、RM、PPO、GKD、KTO、CPO、SimPO、ORPO等人类对齐训练方法。ms-swift支持使用vLLM、SGLang和LMDeploy对推理、评测和部署模块进行加速,并支持使用GPTQ、AWQ、BNB等技术对大模型进行量化。ms-swift还提供了基于Gradio的Web-UI界面及丰富的最佳实践。

https://github.com/modelscope/ms-swift?tab=readme-ov-file

1. 创建虚拟环境

conda create -n swift_venv python=3.12 -y
conda init bash && source /root/.bashrc
conda activate swift_venv
conda install ipykernel
ipython kernel install --user --name=swift_venv

2 安装ms-swift

#1使用pip安装,把ms-swift作为一个库,安装到anaconda的虚拟环境中
pip install ms-swift -U#2从源码克隆安装(推荐)pip install git+https://github.com/modelscope/ms-swift.git#3从源码安装
git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e .

3安装其他依赖库

pip install deepspeed liger-kernel 
pip install scikit-learn
pip install -U sentence-transformers
# 建议科学上⽹后,再执⾏下⾯的命令
pip install flash-attn --no-build-isolation

4. 下载数据集

# 设置下⾯命令,⽆需科学上⽹
export HF_ENDPOINT=https://hf-mirror.com
pip install --upgrade huggingface_hub
huggingface-cli download --repo-type dataset --resume-download microsoft/ms_marco --local-dir 下载保存路径

数据格式:

5.开始embedding模型训练

# 把下面命令,保存为train.sh格式,   运行bash train.sh命令,启动训练CUDA_VISIBLE_DEVICES=0 \
swift sft \--model /data/qwen3_embedding/Qwen3-Embedding-0.6B \--task_type embedding \--train_type full \     #训练模式full lora  --torch_dtype bfloat16 \--num_train_epochs 100 \               #训练轮数--per_device_train_batch_size 16 \   #训练批次--per_device_eval_batch_size 16\--learning_rate 1e-4 \--lora_rank 8 \--lora_alpha 8 \--target_modules all-linear \     #目标模块    all-attention--gradient_accumulation_steps 1 \--eval_steps 1 \--save_steps 1 \--save_total_limit 2 \--logging_steps 5 \--max_length 512 \--output_dir output \        #输出路径--warmup_ratio 0.05 \--dataloader_num_workers 8 \--model_author swift \--model_name swift-robot \--split_dataset_ratio 0.3 \   #train 和val分割比例--dataset /home/dataset \    #数据集路径--loss_type infonce       # 损失函数3种类型   contrastive  cosine_similarity infonce

 使用swift sft --help命令查询有哪些训练设置参数。

6. 自定义数据格式和对应的Loss类型

(1) infoNCE损失     

  --loss_type  infonce

对⽐学习损失函数,最⼤化正样

本对相似度,最⼩化负样本对相似度 .

使⽤批内对⽐学习策略,将同批次内其他样本作为负样本.

数据格式:

[{"query": "如何学习编程?","response": "可以从Python语言开始入门,它语法简单适合初学者。","rejected_response": ["随便看看书就会了", "编程很难学不会的"]},{"query": "推荐一款性价比高的手机","response": "Redmi Note系列在2000元价位段表现均衡,值得考虑。","rejected_response": ["越贵的手机越好", "苹果手机永远是最好的"]}
]

 

(2)余弦相似度损失

 --loss_type cosine_similarity

直接优化预测相似度与真实相似度标签的差异 ,使⽤ MSE 损失计算 ||input_label - cosine_sim(u,v)||_2

数据格式:

[{"query": "A dog is barking loudly.","response": "The canine is making loud barking noises.","label": 0.8},{"query": "Children are playing in the park.","response": "Kids are playing in the playground.","label": 1.0},{"query": "The sun is shining brightly.","response": "Bright sunlight is visible.","label": 0.7}
]

(3)对比学习损失

 --loss_type contrastive

经典的对⽐学习损失,正样本拉近,负样本推远 需要设置 margin 参数。

[{"query": "A dog is barking loudly.","response": "The canine is making loud barking noises.","label": 1},{"query": "Children are playing in the park.","response": "Kids are playing in the playground.","label": 1}]

(4).在线对比学习损失

--loss_type online_contrastive

对⽐学习的改进版本,选择困难正样本和困难负样本 通常⽐标准对⽐学习效果更好。

(5)损失函数总结

 

相关文章:

  • 医疗AI智能基础设施构建:向量数据库矩阵化建设流程分析
  • 领域驱动设计(DDD)【28】之实践或推广DDD的学习
  • 左神算法之矩阵旋转90度
  • <STC32G12K128入门第二十二步>STC32G驱动DS18B20(含代码)
  • IDE/IoT/实践小熊派LiteOS工程配置、编译、烧录、调试(基于 bearpi-iot_std_liteos 源码)
  • 2025.1版本PyCharam找不到已存在的conda虚拟环境
  • 领域驱动设计(DDD)【27】之CQRS四个层面的策略
  • Ubuntu服务器(公网)- Ubuntu客户端(内网)的FRP内网穿透配置教程
  • Spring Cloud 服务追踪实战:使用 Zipkin 构建分布式链路追踪
  • Python爬虫:Requests与Beautiful Soup库详解
  • MATLAB变音系统设计:声音特征变换(男声、女声、童声互转)
  • Windows 环境下设置 RabbitMQ 的 consumer_timeout 参数
  • c# 在sql server 数据库中批插入数据
  • Vivado关联Vscode
  • MAC 地址在 TCP 网络中的全面解析:从基础概念到高级应用
  • 商业行业项目创业计划书PPT模版
  • 打卡day57
  • Ai工具分享(2):Vscode+Cline无限免费的使用教程
  • 跟着AI学习C#之项目实战-电商平台 Day6
  • TCP/UDP协议深度解析(三):TCP流量控制的魔法—滑动窗口、拥塞控制与ACK的智慧