当前位置: 首页 > news >正文

监督微调(SFT)入门:从理论到动手实践

监督微调(SFT)入门:从理论到动手实践

监督微调(Supervised Fine-Tuning, SFT) 是赋予大型语言模型(LLM)特定技能的关键步骤,它通过模仿“问题-标准答案”的范例,让模型学会遵循指令。本文将从 SFT 的基本概念出发,结合 DataWhale 的 Lesson_3.ipynb 教程,带你完整走一遍 SFT 的理论、环境配置、代码实现与效果验证全流程。
在这里插入图片描述

一、什么是 SFT?为什么要做 SFT?

1. SFT 的核心概念

SFT 是给大模型“上课”的过程。简单来说,就是用少量高质量的“提示(Prompt)-响应(Response)”数据集,对预训练好的大模型进行定向优化

  • 与预训练的区别: 预训练数据量巨大(万亿级),目标是学习通用语言规律;SFT 数据量小(几千到上亿),目标是学习特定任务(如数学解题、代码生成)或规范模型行为。
  • 核心优势: 训练快、成本低,能快速让模型掌握基础指令能力。

2. SFT 的价值与目的

SFT 是 LLM 后训练的第一步,其目的主要有三:

  1. 赋予新技能: 让模型从只会聊天,变成能写代码、解数学题。
  2. 规范模型行为: 确保模型遵循特定格式或语言习惯(如必须用中文回复、必须按步骤解题)。
  3. 快速启动能力: 它是最简单有效的后训练方法,能让模型迅速从“无知”状态进入“可用”状态。

3. SFT 的损失函数(公式讲解)

SFT 的训练目标是最大化模型生成“标准答案”的概率,这通常通过最小化负对数似然损失(Negative Log-Likelihood Loss)来实现。SFT 的损失函数定义如下: L SFT = − 1 K ∑ t = 1 N I ( t ∈ Response ) ⋅ log ⁡ p ( y t ∣ x 1 , . . . , x t ; θ ) \mathcal{L}_{\text{SFT}} = -\frac{1}{K} \sum_{t=1}^{N} \mathbb{I}(t \in \text{Response}) \cdot \log p(y_t | x_1, ..., x_t; \theta) LSFT=K1t=1NI(tResponse)logp(ytx1,...,xt;θ)

其中:

  • L SFT \mathcal{L}_{\text{SFT}} LSFT是 SFT 的损失函数。
  • K K K是响应部分的长度。
  • N N N 是整个序列(提示 + 响应)的长度。
  • I ( t ∈ Response ) \mathbb{I}(t \in \text{Response}) I(tResponse) 是指示函数,当 t t t属于响应部分时为 1,否则为 0。
  • p ( y t ∣ x 1 , . . . , x t ; θ ) p(y_t | x_1, ..., x_t; \theta) p(ytx1,...,xt;θ)是模型在给定前 t t t个 token 的条件下生成第 t t t个 token 的概率。
  • θ \theta θ是模型的参数。

公式理解:仅对响应(Response)部分计算损失

这个公式的核心是指示函数 I ( t ∈ Response ) \mathbb{I}(t \in \text{Response}) I(tResponse)。它确保模型只学习如何生成“助手的标准答案”,而不对用户输入的“提示(Prompt)”部分计算损失。换句话说,模型只对生成的响应部分负责,而忽略输入的提示部分。

示例:理解损失的计算范围

假设训练数据是:

  • Prompt: 用户问:2+3=?
  • Response: 答案是5

整个序列 Token 如下:

  • Prompt: [Token1, Token2, Token3, Token4, Token5]
  • Response: [Token6, Token7, Token8]

其中,Token6 到 Token8 是响应部分。

在这种情况下,损失函数将只对 Token6 到 Token8 计算损失,而忽略 Token1 到 Token5。具体计算如下:

L SFT = − 1 3 ( log ⁡ p ( Token6 ∣ Token1 , Token2 , Token3 , Token4 , Token5 ; θ ) + log ⁡ p ( Token7 ∣ Token1 , Token2 , Token3 , Token4 , Token5 , Token6 ; θ ) + log ⁡ p ( Token8 ∣ Token1 , Token2 , Token3 , Token4 , Token5 , Token6 , Token7 ; θ ) ) \mathcal{L}_{\text{SFT}} = -\frac{1}{3} \left( \log p(\text{Token6} | \text{Token1}, \text{Token2}, \text{Token3}, \text{Token4}, \text{Token5}; \theta) + \log p(\text{Token7} | \text{Token1}, \text{Token2}, \text{Token3}, \text{Token4}, \text{Token5}, \text{Token6}; \theta) + \log p(\text{Token8} | \text{Token1}, \text{Token2}, \text{Token3}, \text{Token4}, \text{Token5}, \text{Token6}, \text{Token7}; \theta) \right) LSFT=31(logp(Token6Token1,Token2,Token3,Token4,Token5;θ)+logp(Token7Token1,Token2,Token3,Token4,Token5,Token6;θ)+logp(Token8Token1,Token2,Token3,Token4,Token5,Token6,Token7;θ))

通过这种机制,SFT 使得模型将全部的学习能力集中在“如何高效且准确地回应指令”上,从而提高了模型的指令遵循能力,同时避免了对输入指令格式的过度拟合。


二、动手实践:基于 Lesson_3.ipynb 的 SFT 全流程

本部分将完全围绕 Lesson_3.ipynb 的代码流程,涵盖环境配置、辅助函数编写、模型加载、数据处理和最终训练。

1. 准备环境与工具配置

SFT 需要 PyTorch、HuggingFace 的 Transformers、Datasets 以及 TRL 库。

# Warning controlimport warnings
warnings.filterwarnings('ignore')# 导入核心库
import torch  # 处理张量、GPU计算的核心
import pandas as pd  # 展示数据集用的表格工具
from datasets import load_dataset, Dataset  # 加载和处理训练数据
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM  # 加载模型、分词器、训练参数配置
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig  # SFT专用训练器和配置# 确保安装了依赖:
!pip install -r requirements.txt

2. SFT 核心辅助函数

为了简化 SFT 流程,我们预先编写四个函数,分别负责模型推理、批量测试、模型加载和数据展示。

A. generate_responses:让模型生成回答

def generate_responses(model, tokenizer, user_message, system_message=None,max_new_tokens=100):# 1. 整理对话格式messages = []if system_message:messages.append({"role": "system", "content": system_message})messages.append({"role": "user", "content": user_message})# 2. 用分词器的模板转成模型能读的“prompt”prompt = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True,enable_thinking=False,)# 3. 把 prompt 转成张量,并移到模型设备inputs = tokenizer(prompt, return_tensors="pt").to(model.device)# 4. 模型生成回答with torch.no_grad():outputs = model.generate(**inputs,max_new_tokens=100,do_sample=False,pad_token_id=tokenizer.eos_token_id,eos_token_id=tokenizer.eos_token_id,)# 5. 提取并解码模型的回答input_len = inputs["input_ids"].shape[1]generated_ids = outputs[0][input_len:]response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()return response

B. test_model_with_questions:批量测试模型

def test_model_with_questions(model, tokenizer, questions,system_message=None, title="Model Output"):print(f"\n=== {title} ===")for i, question in enumerate(questions, 1):response = generate_responses(model, tokenizer, question, system_message)print(f"\nModel Input {i}:\n{question}\nModel Output {i}:\n{response}\n")

C. load_model_and_tokenizer:加载模型和分词器

此函数包含了重要的分词器补全逻辑,确保模型能正确处理对话格式。

def load_model_and_tokenizer(model_name, use_gpu = False):# 1. 加载预训练模型和分词器tokenizer = AutoTokenizer.from_pretrained(model_name)model = AutoModelForCausalLM.from_pretrained(model_name)# 2. 移到 GPU(如果可用)if use_gpu:model.to("cuda")# 3. 补全“对话模板”(确保模型能理解对话格式)if not tokenizer.chat_template:tokenizer.chat_template = """{% for message in messages %}{% if message['role'] == 'system' %}System: {{ message['content'] }}\n{% elif message['role'] == 'user' %}User: {{ message['content'] }}\n{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }} <|endoftext|>{% endif %}{% endfor %}"""# 4. 补全“填充符号”(避免训练时报错)if not tokenizer.pad_token:tokenizer.pad_token = tokenizer.eos_tokenreturn model, tokenizer

D. display_dataset:看数据集长啥样

def display_dataset(dataset):rows = []for i in range(3):  # 只看前3个样本example = dataset[i]# 提取用户消息和助手消息user_msg = next(m['content'] for m in example['messages'] if m['role'] == 'user')assistant_msg = next(m['content'] for m in example['messages'] if m['role'] == 'assistant')rows.append({'User Prompt': user_msg,'Assistant Response': assistant_msg})# 转成表格展示df = pd.DataFrame(rows)pd.set_option('display.max_colwidth', None)display(df)

3. SFT 实操:从小模型开始

为了让没有高性能 GPU 的用户也能跑通,我们选择一个参数量较小的模型 (SmolLM2-135M) 进行 SFT。

步骤 3.1:加载小模型

USE_GPU = False  # 如果有GPU,可以改为 True# 测试问题(用于训前和训后效果对比)
questions = ["Give me an 1-sentence introduction of LLM.","Calculate 1+1-1","What's the difference between thread and process?"
]# 加载小模型(注意:路径需要根据实际下载情况调整)
model_name = "./models/HuggingFaceTB/SmolLM2-135M"
model, tokenizer = load_model_and_tokenizer(model_name, USE_GPU)# 训前模型测试(摸底)
test_model_with_questions(model, tokenizer, questions, title="Base Model (Before SFT) Output")

核心现象: 训前模型对这些指令的回答往往是不准确或不完整的。

步骤 3.2:加载和展示训练数据

SFT 的训练数据是关键,格式必须是“用户问 + 助手答”的对话形式。

# 加载数据集(使用 DataWhale 教程指定的开源数据集)
train_dataset = load_dataset("banghua/DL-SFT-Dataset")["train"]# 非GPU环境下只选前100条(加速演示)if not USE_GPU:train_dataset = train_dataset.select(range(100))# 展示前3个样本,确认数据格式# display_dataset(train_dataset) # 在 Jupyter 环境中运行此行可以打印表格

步骤 3.3:配置 SFT 训练参数

使用 SFTConfig 定义 SFT 的规则,参数设置考虑了 CPU 环境下的运行效率

sft_config = SFTConfig(output_dir="./sft_results",  # 结果保存路径learning_rate=8e-5,  # 学习率num_train_epochs=1,  # 训练轮次per_device_train_batch_size=1,  # 单设备批次大小(适合小内存)gradient_accumulation_steps=8,  # 梯度累积(用小批次模拟大批次)gradient_checkpointing=False,logging_steps=2,  # 日志打印频率
)

步骤 3.4:初始化训练器并开始训练

SFTTrainer 是 TRL 库中专门用于 SFT 的工具,它简化了训练流程。

from trl import SFTTrainersft_trainer = SFTTrainer(model=model,args=sft_config,train_dataset=train_dataset,tokenizer=tokenizer,
)# 开始训练
print("\n--- SFT Training Started ---")
sft_trainer.train()
print("--- SFT Training Finished ---")

核心逻辑: 训练过程中,模型只学习生成“助手响应部分”,忽略用户提示部分,正是损失函数在代码中的体现。

4. 效果验证:测试 SFT 后的模型

训练结束后,用同样的问题再次测试模型,对比效果。

# 非GPU环境下,把训好的模型移到CPUif not USE_GPU:sft_trainer.model.to("cpu")# 测试自己训的模型(与训前效果进行对比)
test_model_with_questions(sft_trainer.model, tokenizer, questions, title="Small Model (After SFT) Output")# 释放内存del sft_trainer.model, tokenizer

核心结论: 尽管使用了小模型和小数据集,但 SFT 后的模型回答会更准确、更贴近指令(例如,计算题能答对、概念能讲清楚),验证了 SFT 的有效性。


三、总结与注意事项

SFT 是 LLM 从“通用语言理解”走向“指令遵循”的关键桥梁。通过 Lesson_3.ipynb 的实践,我们完成了 SFT 的理论认知、环境准备、数据加载、参数配置、模型训练和效果验证的全流程闭环。

核心注意事项:

  1. 数据质量是生命线: SFT 效果完全取决于训练样本的准确性和格式统一性。
  2. 避免过拟合: SFT 有可能让模型在没见过的任务上表现变差(灾难性遗忘)。因此,需要控制训练轮次(Epochs)和学习率。
  3. 硬件要求: 虽然 SFT 比预训练轻量,但训练大模型仍需要 GPU 支持。对于资源有限的用户,选择如本文中的小模型是一个很好的学习起点。

掌握了 SFT,你就能给任何预训练的 LLM 打下“基础技能”,为后续更高级的优化(如 DPO、RLHF)做好准备。

四、 一些错误总结

1. 导入相关工具包遇到错误

在这里插入图片描述

原因:未按照requirements.txt 说明的要求进行安装相关版本的依赖库

2. TypeError: argument of type ‘NoneType’ is not iterable

在这里插入图片描述

原因:transformers版本不兼容,requirements.txt中的版本为4.52.4,改成4.51.0即可

五、附录

1. 模型下载代码

from huggingface_hub import snapshot_download# 下载 HuggingFaceTB/SmolLM2-135M 模型
model_id_smol = "HuggingFaceTB/SmolLM2-135M"
local_dir_smol = "./models/HuggingFaceTB/SmolLM2-135M"snapshot_download(repo_id=model_id_smol,local_dir=local_dir_smol,revision="main",  # 分支/版本号resume_download=True  # 支持断点续传
)# 下载 banghua/Qwen3-0.6B-SFT 模型
model_id_qwen = "banghua/Qwen3-0.6B-SFT"
local_dir_qwen = "./models/banghua/Qwen3-0.6B-SFT"snapshot_download(repo_id=model_id_qwen,local_dir=local_dir_qwen,revision="main",  # 分支/版本号resume_download=True  # 支持断点续传
)# 下载 banghua/Qwen3-0.6B-SFT-Another 模型
model_id_qwen_another = "banghua/Qwen3-0.6B-SFT-Another"
local_dir_qwen_another = "./models/banghua/Qwen3-0.6B-SFT-Another"snapshot_download(repo_id=model_id_qwen_another,local_dir=local_dir_qwen_another,revision="main",  # 分支/版本号resume_download=True  # 支持断点续传
)

2. 运行结果

无GPU,使用CPU(本地电脑GPU与tourch 版本冲突 故而未进行GPU训练)

在这里插入图片描述

在这里插入图片描述

六、参考资料

DataWhale-Post-training-of-LLms
DeepLearning课程

http://www.dtcms.com/a/492837.html

相关文章:

  • 个人做电商网站金华做网站建设公司
  • 网站全屏宽度是多少专业网站建设需要多少钱
  • 网站建设 10万元网站导航网站开发
  • 新开发网站在线设计闪字图片
  • 网站logo怎么做才清晰鹤岗商城网站建设
  • 怎样在谷歌做网站网页设计师使用的是什么的屏幕显示颜色模式
  • 九思OA漏洞检测工具
  • Spring 框架 Bean 管理
  • 昆明网站推广公司企业平台有哪些
  • 网站排名优化多少钱网络营销的特点与方法有哪些
  • 安川焊接机器人智能节气仪
  • 昆明市网络优化案例宁波企业网站排名优化
  • 医生做学分在哪个网站wordpress与帝国cms
  • 陕西金顶建设公司网站wordpress 多层目录
  • 网站是用什么语言写的wordpress如何加跳转
  • Maya建模:使模型对称
  • 学校做网站需要多少钱公司注册资金实缴和认缴有什么区别
  • 自己可以做百度网站吗艺术品拍卖网站源码php
  • 做网站做推广做网站需要什么基础
  • 网站建设公司一站通系统简单互联网宣传推广
  • 建一个商城网站需要多久怎么做网站在里面填字
  • 网站建设与优化及覆盖率方案中国亚马逊跨境电商
  • 自助建站竹子番禺人才网招聘网
  • 推图制作网站网站 内容建设存在的问题
  • 西安网站seo收费旅游电子商务网站开发制作
  • 新世纪建设集团网站外贸网站contact
  • Yolo v3
  • 基于历史故障模式的相似性匹配技术
  • 建设安全网站的目的wordpress 后台路径
  • 企业网站的宣传功能体现在()做网站需要几大模板