当前位置: 首页 > news >正文

Qlora+DPO微调Qwen2.5

一、训练

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
import torch# 模型路径
model_name = "Qwen/Qwen2.5-7B-Instruct"# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = tokenizer.eos_tokenmodel = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.bfloat16,device_map="auto"
)# # 加载模型(4-bit 量化)
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     cache_dir="./models"
# )# 准备模型 + LoRA
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(r=8,lora_alpha=16,target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)# 示例偏好数据集(需格式为 prompt/chosen/rejected)dataset = load_dataset("json", data_files="datasets/dpo.jsonl")["train"]# 训练配置
training_args = DPOConfig(output_dir="./dpo_qwen_7b_qlora",per_device_train_batch_size=1,gradient_accumulation_steps=8,learning_rate=5e-5,max_length=1024,beta=0.1,num_train_epochs=1,logging_steps=10,save_steps=100,fp16=False,bf16=True,optim="adamw_torch",report_to=None,
)# 启动训练
trainer = DPOTrainer(model=model,ref_model=None,args=training_args,tokenizer=tokenizer,train_dataset=dataset,
)trainer.train()
trainer.save_model("./dpo_qwen_7b_qlora")

Tips(DPO数据集格式):
dpo.jsonl(实例)

{"prompt": "朋友发消息:“我又开始减肥了,这次一定成功!”\n机器人回复:", "chosen": "太好了,我已经把这句话截图存档,下次你半夜点烧烤的时候我就做成表情包发你。", "rejected": "加油,相信你这次一定能坚持住的!"}
{"prompt": "用户问:“为什么我总是存不下钱?”\n机器人回复:", "chosen": "因为钱包和外卖 App 签了互不侵犯条约,而你刚好是那片战场。", "rejected": "可能是消费习惯没控制好,建议做预算表。"}
{"prompt": "用户说:“老板夸我 PPT 做得真好。”\n机器人回复:", "chosen": "恭喜,你把 3 小时加班浓缩成 3 页艺术垃圾的功力确实一流。", "rejected": "太棒了,你的努力终于被看见了!"}
{"prompt": "用户吐槽:“地铁又坏了,全勤奖没了。”\n机器人回复:", "chosen": "没事,地铁只是帮你提前体验被公司裁员的感觉,习惯就好。", "rejected": "真倒霉,下次早点出门吧。"}
{"prompt": "用户问:“如何优雅地拒绝加班?”\n机器人回复:", "chosen": "把电脑亮度调到最低,假装它也在过劳死,然后关机祝它安息。", "rejected": "可以直接和领导沟通,说明自己已有安排。"}

二、合并参数

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torchbase_model_path = "/root/autodl-tmp/qwen_7b"     # 原始基础模型
lora_path       = "/root/autodl-tmp/dpo_qwen_7b_qlora/checkpoint-20"# 先加载基础模型
base_model = AutoModelForCausalLM.from_pretrained(base_model_path,torch_dtype=torch.bfloat16,device_map="auto"
)# 再加载并合并 LoRA
model = PeftModel.from_pretrained(base_model, lora_path)
model = model.merge_and_unload()                 # 合并权重
tokenizer = AutoTokenizer.from_pretrained(base_model_path)# 保存合并后的模型(可选,方便以后直接用)
model.save_pretrained("./qwen7b_sarcasm_merged")
tokenizer.save_pretrained("./qwen7b_sarcasm_merged")

三、推理

from transformers import AutoTokenizer, AutoModelForCausalLM
import torchmodel_dir = "./qwen7b_sarcasm_merged"            # 合并后的完整模型 blobs
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model     = AutoModelForCausalLM.from_pretrained(model_dir,torch_dtype=torch.bfloat16,device_map="auto"
)messages = [{"role": "user", "content": "老板夸我 PPT 做得真好。"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)out = model.generate(**inputs,max_new_tokens=128,do_sample=True,temperature=0.7,top_p=0.9)
print(tokenizer.decode(out[0], skip_special_tokens=True))
http://www.dtcms.com/a/297290.html

相关文章:

  • Python捕获异常
  • Yolov8/Yolov11实例分割训练自有数据集
  • Springboot项目实现将文件上传到阿里云
  • Python实战:数据处理与可视化的奇妙之旅
  • 双指针算法介绍及使用(下)
  • JavaScript 中 let 在循环中的作用域机制解析
  • 没有 Mac,如何上架 iOS App?多项目复用与流程标准化实战分享
  • uniapp使用css实现进度条带动画过渡效果
  • uniapp之微信小程序标题对其右上角按钮胶囊
  • golang怎么实现每秒100万个请求(QPS),相关系统架构设计详解
  • 海康SDK球机精确控制[球机预置点配置]
  • 未来之路 - eBPF 与 Cilium 如何重塑网络
  • 在kdb+x中使用SQL
  • 理解Spring中的IoC
  • 基于新型群智能优化算法的BP神经网络初始权值与偏置优化
  • WPF MVVM进阶系列教程(二、数据验证)
  • Elasticsearch-9.0.4安装教程
  • 【SpringAI实战】实现仿DeepSeek页面对话机器人(支持多模态上传)
  • MySQL-Every derived table must have its own alias
  • OpenRLHF:面向超大语言模型的高性能RLHF训练框架
  • 基于 Nginx 与未来之窗防火墙构建下一代自建动态网络防护体系​—仙盟创梦IDE
  • Java-82 深入浅出 MySQL 内部架构:服务层、存储引擎与文件系统全覆盖
  • 秋招Day19 - 分布式 - 分布式锁
  • 静默的环保革命:Deepoc具身智能如何让垃圾桶读懂垃圾的语言
  • 一道检验编码能力的字符串的题目
  • 进程控制->进程替换(Linux)
  • LLM:Day3
  • 学习嵌入式的第二十九天-数据结构-(2025.7.16)线程控制:互斥与同步
  • 【运维】ubuntu 安装图形化界面
  • 顺应AI浪潮,电科金仓数据库再创辉煌