python transformers笔记(Trainer类)
Trainer类
Trainer是Hugging Face Transformers库中用于简化模型训练和评估的核心工具类。它封装了标准的训练循环(如批次处理、反向传播、优化器更新等),支持分布式训练、混合精度计算和自动日志记录,极大减少了重复代码。
通过Trainer,可以用极简代码实现从训练到部署的全流程。如需处理特定任务(如多模态、大模型训练),可以进一步扩展功能。
1、核心功能
(1)自动化训练循环:处理前向传播、损失计算、反向传播、优化器更新。
(2)分布式训练:开箱即用的多GPU/TPU训练(无需修改代码)。
(3)混合精度训练:支持FP16(NVIDIA GPU)和BF16(AMD/Intel GPU/TPU)。
(4)灵活的评估策略:按epoch/steps触发验证集评估。
(5)模型保存与恢复:自动保存检查点,支持从中断处恢复训练。
(6)丰富的回调系统:可插入自定义逻辑(如早停、学习率调整)。
2、核心方法
(1)train():启动训练
(2)evaluate():在验证集上评估模型
(3)predict():生成预测结果
(4)save_model():保存模型和分词器
(5)push_to_hub():上传模型到Hugging Face Hub
from transformers import Trainer, TrainingArgumentstrainer = Trainer(model=model, # 待训练的模型实例args=TrainingArguments(...), # 训练配置train_dataset=train_data, # 训练集(需实现__len__和__getitem__)eval_dataset=eval_data, # 验证集(可选)compute_metrics=compute_metrics, # 自定义指标计算函数data_collator=data_collator, # 动态批次填充(默认为DataCollatorWithPadding)tokenizer=tokenizer, # 用于日志记录和保存callbacks=[callback1, ...] # 自定义回调
)