当前位置: 首页 > news >正文

TRL - Transformer Reinforcement Learning SFTTrainer 和 SFTConfig

TRL - Transformer Reinforcement Learning SFTTrainer 和 SFTConfig

flyfish

Name: trl
Version: 0.21.0
Summary: Train transformer language models with reinforcement learning.
Home-page: https://github.com/huggingface/trl

SFTTrainer 的作用就是把微调这个过程变简单:
不用自己写复杂的训练代码(比如怎么加载数据、怎么处理对话格式、怎么保存模型);
它能自动处理各种数据,直接喂给模型;
轻松搭配 PEFT 这类 “省资源” 的工具(不用训模型全部参数,少花钱少占内存);
训练中的日志、保存、断点续训这些琐事也搞定了。

                                                                                                                         |

简单示例

from trl import SFTConfig, SFTTrainer
from datasets import load_datasetdataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(args=training_args,model="Qwen/Qwen2.5-0.5B",train_dataset=dataset,
)
trainer.train()
{'loss': 1.7356, 'grad_norm': 4.851604461669922, 'learning_rate': 1.9969635627530365e-05, 'num_tokens': 59528.0, 'mean_token_accuracy': 0.6122099697589874, 'epoch': 0.01}                  
{'loss': 1.624, 'grad_norm': 6.229388236999512, 'learning_rate': 1.9935897435897437e-05, 'num_tokens': 115219.0, 'mean_token_accuracy': 0.6266456484794617, 'epoch': 0.01}                  
{'loss': 1.4456, 'grad_norm': 5.725496292114258, 'learning_rate': 1.990215924426451e-05, 'num_tokens': 171787.0, 'mean_token_accuracy': 0.6584609568119049, 'epoch': 0.02}                  
{'loss': 1.6019, 'grad_norm': 5.118398189544678, 'learning_rate': 1.986842105263158e-05, 'num_tokens': 226067.0, 'mean_token_accuracy': 0.6274505913257599, 'epoch': 0.02}                  
{'loss': 1.5735, 'grad_norm': 4.509005546569824, 'learning_rate': 1.9834682860998653e-05, 'num_tokens': 284290.0, 'mean_token_accuracy': 0.6260855078697205, 'epoch': 0.03}                 
{'loss': 1.5226, 'grad_norm': 4.884885311126709, 'learning_rate': 1.9800944669365722e-05, 'num_tokens': 338761.0, 'mean_token_accuracy': 0.6402752816677093, 'epoch': 0.03}                 
{'loss': 1.5326, 'grad_norm': 5.511511325836182, 'learning_rate': 1.9767206477732795e-05, 'num_tokens': 397731.0, 'mean_token_accuracy': 0.6352024018764496, 'epoch': 0.04}                 
{'loss': 1.3588, 'grad_norm': 7.149945259094238, 'learning_rate': 1.9733468286099865e-05, 'num_tokens': 451526.0, 'mean_token_accuracy': 0.665032935142517, 'epoch': 0.04}                  
{'loss': 1.4091, 'grad_norm': 4.552429676055908, 'learning_rate': 1.9699730094466938e-05, 'num_tokens': 505287.0, 'mean_token_accuracy': 0.6482574462890625, 'epoch': 0.05}                 
{'loss': 1.4679, 'grad_norm': 4.194477081298828, 'learning_rate': 1.966599190283401e-05, 'num_tokens': 563276.0, 'mean_token_accuracy': 0.6456288278102875, 'epoch': 0.05}                  
{'loss': 1.3072, 'grad_norm': 4.873239994049072, 'learning_rate': 1.963225371120108e-05, 'num_tokens': 623927.0, 'mean_token_accuracy': 0.6719759583473206, 'epoch': 0.06}                  
{'loss': 1.4861, 'grad_norm': 5.325733661651611, 'learning_rate': 1.9598515519568153e-05, 'num_tokens': 678587.0, 'mean_token_accuracy': 0.6462028443813324, 'epoch': 0.06}                 
{'loss': 1.5824, 'grad_norm': 4.714112281799316, 'learning_rate': 1.9564777327935226e-05, 'num_tokens': 736510.0, 'mean_token_accuracy': 0.6345476865768432, 'epoch': 0.07}                 
{'loss': 1.4407, 'grad_norm': 4.857705116271973, 'learning_rate': 1.9531039136302295e-05, 'num_tokens': 794954.0, 'mean_token_accuracy': 0.6477592408657074, 'epoch': 0.07}  

SFTTrainer 中加入 PEFT(参数高效微调)配置

使用 PEFT 库中的配置类(如 LoRA)并将其传递给peft_config参数

from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
import torch# 加载数据集
dataset = load_dataset("trl-lib/Capybara", split="train")# 配置PEFT (使用LoRA作为示例)
peft_config = LoraConfig(r=8,  # LoRA注意力维度lora_alpha=32,  # LoRA缩放参数target_modules=[  # Qwen模型的目标模块,不同模型可能不同"q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj"],lora_dropout=0.05,  # Dropout概率bias="none",  # 不训练偏置参数task_type="CAUSAL_LM",  # 任务类型:因果语言模型inference_mode=False  # 训练模式
)# 配置训练参数
training_args = SFTConfig(output_dir="./Qwen2.5-0.5B-SFT-PEFT",  # 输出目录num_train_epochs=3,  # 训练轮数per_device_train_batch_size=4,  # 每个设备的批次大小gradient_accumulation_steps=2,  # 梯度累积步数learning_rate=2e-4,  # 学习率logging_steps=10,  # 日志记录步数save_steps=100,  # 模型保存步数fp16=True,  # 使用混合精度训练optim="paged_adamw_8bit",  # 使用8位优化器节省显存report_to="wandb" if "wandb" in locals() else "none",  # 日志报告方式max_length=1024,
)# 初始化SFTTrainer并传入PEFT配置
trainer = SFTTrainer(args=training_args,model="Qwen/Qwen2.5-0.5B",train_dataset=dataset,peft_config=peft_config,  # 添加PEFT配置)# 开始训练
trainer.train()# 保存PEFT模型
trainer.save_model()# 如果需要,可以将PEFT模型与基础模型合并(推理时使用)
# from peft import AutoPeftModelForCausalLM
# model = AutoPeftModelForCausalLM.from_pretrained(
#     "./Qwen2.5-0.5B-SFT-PEFT",
#     device_map="auto",
#     torch_dtype=torch.bfloat16
# )
# merged_model = model.merge_and_unload()
# merged_model.save_pretrained("./Qwen2.5-0.5B-SFT-merged")
[INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)
{'loss': 1.7566, 'grad_norm': 0.8341224193572998, 'learning_rate': 0.00019969635627530366, 'num_tokens': 59528.0, 'mean_token_accuracy': 0.6122347742319107, 'epoch': 0.01}                 
{'loss': 1.637, 'grad_norm': 0.9207030534744263, 'learning_rate': 0.00019935897435897437, 'num_tokens': 115219.0, 'mean_token_accuracy': 0.6304881483316421, 'epoch': 0.01}                 
{'loss': 1.4695, 'grad_norm': 0.7871270775794983, 'learning_rate': 0.0001990215924426451, 'num_tokens': 171787.0, 'mean_token_accuracy': 0.6504963368177414, 'epoch': 0.02}                 
{'loss': 1.6518, 'grad_norm': 0.7337052226066589, 'learning_rate': 0.0001986842105263158, 'num_tokens': 226067.0, 'mean_token_accuracy': 0.621559776365757, 'epoch': 0.02}                  
{'loss': 1.6317, 'grad_norm': 0.6506398916244507, 'learning_rate': 0.00019834682860998652, 'num_tokens': 284290.0, 'mean_token_accuracy': 0.6257665097713471, 'epoch': 0.03}                
{'loss': 1.5889, 'grad_norm': 0.7434213161468506, 'learning_rate': 0.0001980094466936572, 'num_tokens': 338761.0, 'mean_token_accuracy': 0.6391594052314759, 'epoch': 0.03}                 
{'loss': 1.5915, 'grad_norm': 0.7964017987251282, 'learning_rate': 0.00019767206477732793, 'num_tokens': 397731.0, 'mean_token_accuracy': 0.6308055430650711, 'epoch': 0.04}                
{'loss': 1.4446, 'grad_norm': 1.0523793697357178, 'learning_rate': 0.00019733468286099867, 'num_tokens': 451526.0, 'mean_token_accuracy': 0.6549594551324844, 'epoch': 0.04}                
{'loss': 1.4834, 'grad_norm': 0.6530594229698181, 'learning_rate': 0.00019699730094466938, 'num_tokens': 505287.0, 'mean_token_accuracy': 0.6439991772174836, 'epoch': 0.05}                
{'loss': 1.5385, 'grad_norm': 0.5796452164649963, 'learning_rate': 0.0001966599190283401, 'num_tokens': 563276.0, 'mean_token_accuracy': 0.6417025059461594, 'epoch': 0.05}                 
{'loss': 1.3649, 'grad_norm': 0.6477728486061096, 'learning_rate': 0.00019632253711201081, 'num_tokens': 623927.0, 'mean_token_accuracy': 0.6682371526956559, 'epoch': 0.06} 

SFTTrainer参数

参数名称类型描述
modelUnion[str, PreTrainedModel]待训练的模型。可是huggingface模型ID、本地模型目录路径,或PreTrainedModel对象(仅支持因果语言模型);通过AutoModelForCausalLM.from_pretrained加载,支持args.model_init_kwargs参数。
args[SFTConfig](可选,默认None训练器配置。若为None,则使用默认配置。
data_collatorDataCollator(可选)从处理后的训练/评估数据集元素列表中构建批次的函数。默认使用自定义的DataCollatorForLanguageModeling
train_dataset[~datasets.Dataset] 或 [~datasets.IterableDataset]训练数据集。支持语言建模型和提示-补全型,样本格式可为标准文本(纯文本)或对话格式(结构化消息,如角色+内容);也支持已分词数据集(需含input_ids字段)。
eval_dataset[~datasets.Dataset]、[~datasets.IterableDataset] 或 dict[...]评估数据集。需满足与train_dataset相同的要求。
processing_classPreTrainedTokenizerBase 等(可选,默认None数据处理类(如分词器)。若为None,则通过AutoTokenizer.from_pretrained从模型名称加载。
callbacks列表([~transformers.TrainerCallback],可选,默认None自定义训练循环的回调列表。会添加到默认回调列表中;可通过remove_callback方法移除默认回调。
optimizerstuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR](可选,默认(None, None)包含优化器和调度器的元组。默认使用AdamW优化器和get_linear_schedule_with_warmup调度器(由args控制)。
optimizer_cls_and_kwargsTuple[Type[torch.optim.Optimizer], Dict[str, Any]](可选,默认None包含优化器类和关键字参数的元组。会覆盖args中的optimoptim_args;与optimizers参数不兼容,且无需提前将模型参数放到正确设备上。
preprocess_logits_for_metricsCallable[[torch.Tensor, torch.Tensor], torch.Tensor](可选,默认None评估步骤中缓存logits前的预处理函数。需接收logits和标签(可为None),返回处理后的logits;修改会反映在compute_metrics接收的预测结果中。
peft_config[~peft.PeftConfig](可选,默认None用于封装模型的PEFT配置(如LoRA)。若为None,则不封装模型。
formatting_funcOptional[Callable]分词前应用于数据集的格式化函数。会将数据集显式转换为语言建模类型。

SFTConfig 的参数
SFTConfig 继承自 transformers.TrainingArguments

作用

SFTConfigSFTTrainer 的配置类,用于设置监督微调(SFT)的各项参数,包括模型加载、数据预处理、训练策略等细节,简化训练配置流程。通过细化模型加载、数据处理(如打包、填充)、训练策略(如损失计算范围)等参数,让监督微调更灵活适配不同场景(如对话训练、指令微调),同时继承了 TrainingArguments 的通用训练配置(如批次大小、训练轮数等),无需重复定义。

说明

参数名称类型默认值描述
覆盖父类默认值的参数
learning_ratefloat2e-5AdamW 优化器的初始学习率(父类默认值不同)。
logging_stepsfloat10每多少步记录一次日志(可设为整数或 [0,1) 之间的比例,代表总步数的占比)。
gradient_checkpointingboolTrue是否启用梯度检查点(以稍慢的反向传播为代价节省显存,父类默认值为 False)。
bf16Optional[bool]None是否使用 bf16 混合精度训练。默认根据 fp16 自动设置(fp16 为 False 则默认启用)。
average_tokens_across_devicesboolTrue是否跨设备平均 tokens 数量(用于精确计算损失,适配多设备训练)。
控制模型的参数
model_init_kwargsOptional[dict]None加载模型时的关键字参数(当 model 为字符串时,传给 AutoModelForCausalLM.from_pretrained)。
chat_template_pathOptional[str]None模型聊天模板路径(可为分词器目录、Hugging Face 模型ID或 Jinja 模板文件),用于格式化对话。
控制数据预处理的参数
dataset_text_fieldstr“text”数据集中存储文本的列名(如数据集里文本存在 text 列则无需修改)。
dataset_kwargsOptional[dict]None数据集准备的可选参数(仅支持 skip_prepare_dataset 键,用于跳过数据集预处理)。
dataset_num_procOptional[int]None处理数据集的进程数(加速数据预处理)。
eos_tokenOptional[str]None序列结束符。默认使用处理类(如分词器)的 eos_token
pad_tokenOptional[str]None填充符。默认使用处理类的 pad_token,若不存在则 fallback 到 eos_token
max_lengthOptional[int]1024token 化后的最大序列长度,超过则从右侧截断;启用打包时,此值为固定块长度。
packingboolFalse是否将多个短序列打包成固定长度块(减少填充,提升效率),长度由 max_length 定义。
packing_strategystr“bfd”打包策略:"bfd"(最佳适配递减,默认)或 "wrapped"(包裹式)。
padding_freeboolFalse是否无填充训练(将批次序列扁平化为单个连续序列,减少填充开销),需配合 FlashAttention 使用。
pad_to_multiple_ofOptional[int]None序列填充到该值的倍数(如 8、16,提升硬件效率)。
eval_packingOptional[bool]None评估数据集是否启用打包,默认与 packing 保持一致。
控制训练的参数
completion_only_lossOptional[bool]None是否只计算“补全部分”的损失:
- 对“提示-补全”型数据集,默认只算补全部分;
- 对语言建模型数据集,默认算全序列。
assistant_only_lossboolFalse是否只计算“助手回复部分”的损失(仅支持对话型数据集)。
activation_offloadingboolFalse是否将激活值卸载到 CPU(进一步节省 GPU 显存)。
http://www.dtcms.com/a/325766.html

相关文章:

  • docker是什么以及镜像命令详解
  • ROS2学习(1)—基础概念及环境搭建
  • B 树与 B + 树解析与实现
  • 北斗水文环境监测站在水库的应用
  • Linux操作系统从入门到实战(二十)进程优先级
  • 【从零开始java学习|第一篇】java中的名词概念(JDK、JVM、JRE等等)
  • 15. xhr 对象如何发起一个请求
  • VSCode右键菜单消失,修复VSCode右键菜单
  • raid10 允许最多坏几块磁盘,如何修复阵列?
  • lesson35:数据库深度解析:从概念到MySQL实战学习指南
  • 如何使用 Watchtower 实现定时更新 docker 中的镜像并自动更新容器(附 schedule 的参数详细解释)
  • 升级 ChatGPT 提示“您的银行卡被拒绝了”或者“您的信用卡被拒绝了。请尝试用借记卡支付。“如何解决?
  • FPGA+护理:跨学科发展的探索(二)
  • CVPR 2025 | 即插即用,极简数据蒸馏,速度up20倍,GPU占用仅2G
  • 【数字图像处理系列笔记】Ch09:特征提取与表示
  • YOLOv8 训练报错:PyTorch 2.6+ 模型加载兼容性问题解决
  • GPT-5 现已上线 DigitalOcean Gradient™ AI 平台!
  • 数据大集网:精准获客新引擎,助力中小企业突破推广困局
  • UKB-GWAS资源更新
  • C++ 检测 IPv4 和 IPv6 地址合法性
  • 朝花夕拾(一)-------布尔掩码(Boolean Mask)是什么?
  • npm install报错~[master] npm install npm error code ERESOLVE npm err
  • Redis 数据倾斜
  • 触想定制化工业一体机化身渔业预警终端,守望渔船安全
  • 验证二叉搜索树
  • (Arxiv-2025)Phantom:通过跨模态对齐实现主体一致性视频生成
  • 如何安装 Git (windows/mac/linux)
  • CodeBuddy IDE完全食用手册:从安装到生产力爆发的技术流解剖
  • 训推一体 | 暴雨X8848 G6服务器 x Intel®Gaudi® 2E AI加速卡
  • Android Audio实战——获取活跃音频类型(十五)