当前位置: 首页 > news >正文

LLM监督微调SFT实战指南(Qwen3-0.6B-Base)

学习参考资料:DataWhale SFT教程资料

关于LLM的一些核心概念

LLM训练一般有两个阶段:

  1. 预训练 (Pre-training): 在海量无标签文本数据上进行,让模型学习语言的基本规律和知识,能够预测下一个词。这个阶段耗时久、成本高。
    2. 后训练 (Post-training) / 微调 (Fine-tuning): 在规模较小的、有特定任务标签的数据集上进行,让模型学会遵循指令、执行特定任务。这个阶段速度快、成本低。

其中监督微调 (Supervised Fine-Tuning, SFT) 是一种后训练方法,属于监督学习或模仿学习。它通过使用带标签的“提示-回答”(prompt-response) 数据对进行训练,目标是让模型学会根据给定的提示,生成理想的回答。
在这里插入图片描述

SFT的关键点

  • 数据质量至关重要: SFT的效果很大程度上取决于训练数据的质量。高质量的数据能够引导模型学习到正确的行为和知识。
  • 数据结构: SFT需要的数据结构是“提示-回答”对。
  • 与其他后训练方法的比较:
    • 直接偏好优化 (Direct Preference Optimization, DPO): 通过向模型展示“好”答案和“坏”答案来进行训练,让模型更倾向于生成“好”答案。
    • 在线强化学习 (Online Reinforcement Learning, RL): 模型生成回答后,由奖励函数对回答质量进行评分,模型根据评分进行更新。
  • 评估: 在进行后训练之前和之后,都需要有一套评估体系来跟踪模型的性能,确保模型在各个方面都表现良好。

SFT的应用

  • 将基础模型转变为指令遵循模型。
  • 提升模型在特定任务上的表现,例如:问答、代码生成、数学推理、对话等

基于Qwen3-0.6B-Base的SFT实践

首先导入需要的库:

import torch
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig

接下来定义一个完整的模型推理函数 generate_responses。它的核心功能是接收一个语言模型、分词器和用户的消息,然后引导模型生成一个确定性的、格式干净的文本回答

# 定义一个名为 generate_responses 的函数
# 参数包括:
#   model: 预加载的语言模型
#   tokenizer: 与模型配套的分词器
#   user_message: 用户的输入消息 (字符串)
#   system_message: (可选) 系统提示,用于设定模型的角色或行为
#   max_new_tokens: (可选) 模型生成的最大新词元数量,默认为100
def generate_responses(model, tokenizer, user_message, system_message=None, max_new_tokens=100):# --- 1. 构建对话历史 ---# 初始化一个空列表,用于存放对话消息messages = []# 如果提供了 system_message,就将其作为第一条消息添加到列表中# 这条消息的角色是 "system",内容是 system_messageif system_message:messages.append({"role": "system", "content": system_message})# 添加用户的消息到列表中# 这条消息的角色是 "user",内容是 user_message# 假设总是单轮对话messages.append({"role": "user", "content": user_message})# --- 2. 应用聊天模板 ---# 使用分词器的 apply_chat_template 方法# 这个方法会将上面构建的 messages 列表(包含角色和内容)# 转换成一个符合特定模型预训练格式的、完整的字符串 promptprompt = tokenizer.apply_chat_template(messages,tokenize=False,            # 设置为 False,表示返回一个字符串,而不是直接分词后的 IDadd_generation_prompt=True, # 设置为 True,会自动在末尾添加提示,告诉模型应该开始生成回答了# 例如,可能会加上 "ASSISTANT:" 或类似的标识符enable_thinking=False,     # 禁用某些模型可能支持的“思考”过程的特殊 token)# --- 3. 分词与张量转换 ---# 将格式化后的 prompt 字符串进行分词,并转换为 PyTorch 张量 (tensors)# return_tensors="pt" 表示返回 PyTorch (pt) 格式的张量# .to(model.device) 将这些张量移动到模型所在的设备上(例如 CPU 或 GPU),以避免设备不匹配的错误inputs = tokenizer(prompt, return_tensors="pt").to(model.device)# --- 4. 模型推理生成 ---# 使用 torch.no_grad() 上下文管理器,这会禁用梯度计算# 在推理(非训练)阶段,这样做可以显著减少内存消耗并加快计算速度with torch.no_grad():# 调用模型的 generate 方法来生成文本outputs = model.generate(**inputs,                     # `**inputs` 将 inputs 字典解包,传入 `input_ids` 等参数max_new_tokens=max_new_tokens, # 限制生成内容的最大长度do_sample=False,               # 设置为 False 表示不进行采样,而是使用贪心解码 (greedy decoding)# 每次都选择概率最高的词元,这会让输出结果固定、可复现pad_token_id=tokenizer.eos_token_id, # 在需要填充时,使用句末符 (eos_token) 的 IDeos_token_id=tokenizer.eos_token_id, # 明确指定句末符的 ID)# --- 5. 解码与后处理 ---# 获取输入部分的长度(单位是词元/token 的数量)input_len = inputs["input_ids"].shape[1]# 从模型的总输出 `outputs` 中,切片掉输入部分,只保留新生成的部分# `outputs[0]` 是因为 generate 的输出可能是一个批次,我们取第一个(也是唯一一个)结果generated_ids = outputs[0][input_len:]# 使用分词器的 decode 方法,将新生成的词元 ID 列表转换回人类可读的字符串# skip_special_tokens=True 会移除解码结果中特殊的 token,例如 <|endoftext|> 等# .strip() 用于移除字符串开头和结尾的空白字符response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()# --- 6. 返回结果 ---# 返回最终处理好的、干净的文本响应return response

再写一个通过我们预设的问题来测试模型的代码:

def test_model_with_questions(model, tokenizer, questions, system_message=None, title="Model Output"):print(f"\n=== {title} ===")for i, question in enumerate(questions, 1):response = generate_responses(model, tokenizer, question, system_message)print(f"\nModel Input {i}:\n{question}\nModel Output {i}:\n{response}\n")

下面是加载模型的代码:

# 导入transformer中对应的库
from transformers import AutoTokenizer, AutoModelForCausalLM# 定义一个函数,用于加载指定名称的模型和分词器
# 参数:
#   model_name: 字符串,指定要从 Hugging Face Hub 加载的模型名称,例如 "meta-llama/Llama-2-7b-chat-hf"
#   use_gpu: 布尔值,如果为 True,则尝试将模型加载到 GPU 上
def load_model_and_tokenizer(model_name, use_gpu=False):# --- 1. 加载模型和分词器 ---# 从 Hugging Face Hub 下载并加载预训练好的分词器tokenizer = AutoTokenizer.from_pretrained(model_name)# 从 Hugging Face Hub 下载并加载预训练好的因果语言模型 (Causal LM)model = AutoModelForCausalLM.from_pretrained(model_name)# --- 2. GPU 配置 ---# 检查是否需要使用 GPUif use_gpu:model.to("cuda")# --- 3. 配置聊天模板 ---# 检查分词器是否已经定义了聊天模板 (chat_template)# 聊天模板用于将多轮对话历史(如系统、用户、助手的消息)格式化为单个字符串if not tokenizer.chat_template:# 如果没有定义,就设置一个默认的模板tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}System: {{ message['content'] }}\n{% elif message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }} <|endoftext|>{% endif %}{% endfor %}"""# --- 4. 配置填充符 (Padding Token) ---# 检查分词器是否定义了填充符 (pad_token)# 填充符用于在处理批量数据时,将较短的序列填充到与最长序列相同的长度if not tokenizer.pad_token:# 如果没有定义,就将句末符 (eos_token) 设置为填充符# 这是一个常见的做法,可以确保模型在处理填充部分时知道这是序列的结束tokenizer.pad_token = tokenizer.eos_token# --- 5. 返回结果 ---# 返回加载并配置好的模型和分词器对象return model, tokenizer

下面的代码用于展示我们要使用的数据集来对模型进行SFT:

def display_dataset(dataset):# Visualize the dataset rows = []for i in range(3):example = dataset[i]user_msg = next(m['content'] for m in example['messages']if m['role'] == 'user')assistant_msg = next(m['content'] for m in example['messages']if m['role'] == 'assistant')rows.append({'User Prompt': user_msg,'Assistant Response': assistant_msg})# Display as tabledf = pd.DataFrame(rows)pd.set_option('display.max_colwidth', None)  # Avoid truncating long stringsdisplay(df)

首先在没有经过微调的Qwen3-0.6B-Base模型上进行问题测试,从测试结果上来看,模型输出的都是乱码,说明目前模型还并没有对话的能力

USE_GPU = True
questions = ["Give me an 1-sentence introduction of LLM.","Calculate 1+1-1","What's the difference between thread and process?"
]
model, tokenizer = load_model_and_tokenizer("./models/Qwen3-0.6B-Base", USE_GPU)test_model_with_questions(model, tokenizer, questions, title="Base Model (Before SFT) Output")del model, tokenizer

在这里插入图片描述
我们选择一个小型的对话数据集来对模型进行微调,将这个文件下载到本地进行加载:
在这里插入图片描述

train_dataset = load_dataset("./data")["train"]
if not USE_GPU:train_dataset=train_dataset.select(range(100))display_dataset(train_dataset)


接下来设置为调参数并进行训练:

# SFTTrainer config 
sft_config = SFTConfig(learning_rate=8e-5, # Learning rate for training. num_train_epochs=1, #  Set the number of epochs to train the model.per_device_train_batch_size=1, # Batch size for each device (e.g., GPU) during training. gradient_accumulation_steps=8, # Number of steps before performing a backward/update pass to accumulate gradients.gradient_checkpointing=False, # Enable gradient checkpointing to reduce memory usage during training at the cost of slower training speed.logging_steps=2,  # Frequency of logging training progress (log every 2 steps).)
sft_trainer = SFTTrainer(model=model,args=sft_config,train_dataset=train_dataset, processing_class=tokenizer,
)
sft_trainer.train()

训练结果如下:在这里插入图片描述
最后来对模型进行测试:

if not USE_GPU: # move model to CPU when GPU isn’t requestedsft_trainer.model.to("cpu")
USE_GPU = True
model_name = "./trainer_output/checkpoint-13"
model, tokenizer = load_model_and_tokenizer(model_name, USE_GPU)
test_model_with_questions(sft_trainer.model, tokenizer, questions, title="Base Model (After SFT) Output")

在这里插入图片描述
可以看出大模型已经基本具备了回答问题的能力,但是在做数学题的时候还是出现了回答重复的问题,为了测试模型的数学能力,我有添加了一个乘法算术题,看看能否回答。
在这里插入图片描述
这回直接不演了,看来计算还是不太会,以后有时间试试能不能在一个数学运算SFT数据集上微调一下有没有效果吧

http://www.dtcms.com/a/486079.html

相关文章:

  • 【基础算法】多源 BFS
  • *@UI 视角下主程序与子程序的菜单页面架构及关联设计
  • Virtio 半虚拟化技术解析
  • 网站设计怎么好看律师做网络推广哪个网站好
  • 用commons vfs 框架 替换具体的sftp 实现
  • 网站模板怎么设计软件wordpress多重筛选页面
  • 通往Docker之路:从单机到容器编排的架构演进全景
  • 分布式链路追踪:微服务可观测性的核心支柱
  • PostgreSQL 函数ARRAY_AGG详解
  • 【OpenHarmony】MSDP设备状态感知模块架构
  • RAG 多模态 API 处理系统设计解析:企业级大模型集成架构实战
  • 通过一个typescript的小游戏,使用单元测试实战(二)
  • 多物理域协同 + 三维 CAD 联动!ADS 2025 解锁射频前端、天线设计新体验
  • 前端微服务架构解析:qiankun 运行原理详解
  • linux ssh config详解
  • 内网攻防实战图谱:从红队视角构建安全对抗体系
  • 鲲鹏ARM服务器配置YUM源
  • 网站分类标准沈阳网站制作招聘网
  • 建设一个网站需要几个角色建筑工程网课心得体会
  • 基于Robosuite和Robomimic采集mujoco平台的机械臂数据微调预训练PI0模型,实现快速训练机械臂任务
  • 深度学习目标检测项目
  • SQL 窗口函数
  • 盟接之桥浅谈目标落地的底层逻辑:实践、分解与认知跃迁
  • 【Qt】4.项目文件解析
  • Redis-布隆过滤器BloomFilter
  • 网站建设找至尚网络深圳制作企业网站
  • 网页是网站吗苏州刚刚发生的大事
  • WPF中RelayCommand的实现与使用详解
  • 百度天气:空气质量WebGIS可视化的创新实践 —— 以湖南省为例
  • Flutter---GridView+自定义控件