知识蒸馏 - 大语言模型知识蒸馏LLM-KD-Trainer 源码分析 KnowledgeDistillationTrainer类
知识蒸馏 - 大语言模型知识蒸馏LLM-KD-Trainer 源码分析 KnowledgeDistillationTrainer类
flyfish
代码抄自
https://github.com/shaoshengsong/KDTrainer
代码分析的是LLM-KD-Trainer/LLM-KD-Trainer.py
版本
Python 版本: 3.12.9
PyTorch: 2.6.0+cu124
Transformers: 4.55.0
PEFT: 0.15.2
PyYAML: 6.0.2
环境搭建好之后 执行 python LLM-KD-Trainer.py 纯绿色版
最好先看完基础知识
知识蒸馏 - 蒸的什么
知识蒸馏 - 通过引入温度参数T调整 Softmax 的输出
知识蒸馏 - 对数函数的单调性
知识蒸馏 - 信息量的公式为什么是对数
知识蒸馏 - 根据真实事件的真实概率分布对其进行编码
知识蒸馏 - 信息熵中的平均为什么是按概率加权的平均
知识蒸馏 - 自信息量是单个事件的信息量,而平均自信息量(即信息熵)是所有事件自信息量以其概率为权重的加权平均值
知识蒸馏 - 最小化KL散度与最小化交叉熵是完全等价的
知识蒸馏 - 基于KL散度的知识蒸馏 KL散度的方向
class KnowledgeDistillationTrainer(Trainer):"""知识蒸馏专用的自定义Trainer类,基于师生模型框架继承自Hugging Face的Trainer类,扩展了知识蒸馏功能,使小型"学生"模型能够学习大型"教师"模型的知识。通过KL散度损失衡量师生模型的输出差异,可选结合学生模型自身的交叉熵损失,提升蒸馏效果。"""def __init__(self,student_model: PreTrainedModel, # 待训练的学生模型(小型模型)teacher_model: PreTrainedModel, # 提供知识的教师模型(大型模型)distillation_config: Dict[str, Any], # 蒸馏配置参数字典use_entropy_loss: bool = False, # 是否结合学生的交叉熵损失**kwargs # 传递给父类Trainer的其他参数(如训练参数、数据集等)) -> None:"""初始化知识蒸馏TrainerArgs:student_model: 学生模型,通过蒸馏学习教师模型的知识teacher_model: 教师模型,固定参数作为知识来源distillation_config: 蒸馏相关配置,包含:- temperature: 温度参数,控制logits分布的平滑程度- padding_id: 填充token的ID,用于过滤填充部分的损失use_entropy_loss: 若为True,总损失=0.7*KL损失 + 0.3*学生交叉熵损失;若为False,总损失=纯KL损失**kwargs: 父类Trainer所需的参数(如training_args、train_dataset等)"""# 调用父类Trainer的初始化方法,传入学生模型和其他参数super().__init__(model=student_model,** kwargs)self.teacher_model = teacher_model # 保存教师模型self.teacher_model.eval() # 教师模型固定为评估模式(不更新参数,关闭dropout等)self.distillation_config = distillation_config # 保存蒸馏配置self.use_entropy_loss = use_entropy_loss # 保存损失组合策略# 打印蒸馏配置日志,方便调试logger.info(f"蒸馏配置:温度={distillation_config['temperature']},填充ID={distillation_config['padding_id']}")logger.info(f"损失策略:{'KL散度+学生交叉熵' if use_entropy_loss else '纯KL散度'}")@staticmethoddef compute_forward_kl_divergence(student_logits: torch.Tensor, # 学生模型的输出logits(未经过softmax)teacher_logits: torch.Tensor, # 教师模型的输出logits(未经过softmax)target_labels: torch.Tensor, # 标签张量,用于识别填充位置padding_id: int, # 填充token的ID,填充部分不计入损失reduction: str = "sum", # 损失聚合方式("sum"表示求和)temperature: float = 1.0, # 温度参数,控制分布平滑度(>0)) -> torch.Tensor:"""计算教师与学生输出分布之间的前向KL散度衡量学生模型的输出分布与教师模型的输出分布之间的差异,通过温度参数平滑分布,并忽略填充位置的损失(填充部分对模型学习无意义)。Args:student_logits: 学生模型输出的logits,形状为[batch_size, seq_len, vocab_size]teacher_logits: 教师模型输出的logits,形状为[batch_size, seq_len, vocab_size]target_labels: 标签张量,形状为[batch_size, seq_len],用于定位填充位置padding_id: 填充token的ID,对应位置的损失会被过滤reduction: 损失聚合方式(此处固定为"sum",对所有有效位置求和)temperature: 温度参数,值越大分布越平滑(通常设为1-10)Returns:计算得到的KL散度损失张量"""# 用温度参数缩放logits,控制分布的平滑程度(温度越高,分布越平缓)student_logits = student_logits / temperatureteacher_logits = teacher_logits / temperature# 计算缩放后logits的对数概率(log_softmax)student_log_probs = torch.log_softmax(student_logits, dim=-1, dtype=torch.float32)teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1, dtype=torch.float32)# 计算教师模型的概率分布(softmax)teacher_probs = torch.softmax(teacher_logits, dim=-1, dtype=torch.float32)# 计算KL散度:KL(教师分布 || 学生分布) = 教师概率 * (教师对数概率 - 学生对数概率)kl_divergence = teacher_probs * (teacher_log_probs - student_log_probs)kl_divergence = kl_divergence.sum(dim=-1) # 对词汇表维度求和,得到每个token的KL损失# 处理填充部分:忽略填充位置的损失if reduction == "sum":# 生成填充掩码:标签等于padding_id的位置为True(需要过滤)pad_mask = target_labels.eq(padding_id)# 将填充位置的损失设为0(不参与总损失计算)kl_divergence = kl_divergence.masked_fill_(pad_mask, 0.0)# 对所有有效位置的损失求和,得到最终KL损失kl_divergence = kl_divergence.sum()return kl_divergencedef compute_loss(self, model: PreTrainedModel, # 学生模型(由Trainer自动传入)inputs: Dict[str, torch.Tensor], # 输入数据字典,包含input_ids、attention_mask、labels等return_outputs: bool = False, # 是否返回损失+模型输出(True用于需要输出的场景)num_items_in_batch: Optional[int] = None # 批次中的样本数(未使用,兼容父类接口)) -> Tuple[torch.Tensor, Any] | torch.Tensor:"""计算知识蒸馏的总损失结合KL散度损失(师生差异)和学生模型自身的交叉熵损失(可选),得到总损失。Args:model: 学生模型(实际为self.model,由Trainer传入)inputs: 输入数据字典,包含训练所需的所有张量return_outputs: 若为True,返回(总损失, 学生模型输出);否则仅返回总损失Returns:总损失张量,或(总损失, 学生模型输出)的元组"""# 1. 学生模型前向传播(计算交叉熵损失和logits)student_outputs = model(**inputs) # 学生模型输出,包含loss和logitsstudent_loss = student_outputs.loss # 学生模型自身的交叉熵损失(由模型内部计算)student_logits = student_outputs.logits # 学生模型的logits(用于计算KL损失)# 处理多卡训练场景:若损失是多维张量(如每个卡一个损失),则取平均转为标量if student_loss.dim() > 0: # 检查损失是否有维度(标量的dim=0)student_loss = student_loss.mean() # 多卡损失取平均,转为标量# 2. 教师模型前向传播(不计算梯度,避免更新教师参数)with torch.no_grad(): # 禁用梯度计算,节省内存并加速teacher_outputs = self.teacher_model(** inputs) # 教师模型输出teacher_logits = teacher_outputs.logits # 教师模型的logits(用于计算KL损失)# 3. 对齐师生模型的词汇表大小(若不同)if student_logits.shape[-1] != teacher_logits.shape[-1]:# 取较小的词汇表大小,截断较大的一方(避免维度不匹配)min_vocab = min(student_logits.shape[-1], teacher_logits.shape[-1])student_logits = student_logits[:, :, :min_vocab] # 截断学生logitsteacher_logits = teacher_logits[:, :, :min_vocab] # 截断教师logitslogger.debug(f"词汇表大小对齐至{min_vocab}(取师生模型中的较小值)")# 4. 计算KL散度损失(衡量师生输出差异)kl_loss = self.compute_forward_kl_divergence(student_logits=student_logits, # 学生logitsteacher_logits=teacher_logits, # 教师logitstarget_labels=inputs["labels"], # 标签(用于过滤填充)padding_id=self.distillation_config["padding_id"], # 填充IDtemperature=self.distillation_config["temperature"] # 温度参数)# 处理多卡场景的KL损失:若为多维张量则取平均转为标量if kl_loss.dim() > 0:kl_loss = kl_loss.mean()# 扩展维度为1D张量(避免多卡聚合时的标量警告)student_loss = student_loss.unsqueeze(0) # 标量→形状为[1]的张量kl_loss = kl_loss.unsqueeze(0) # 标量→形状为[1]的张量# 5. 计算总损失(根据策略组合KL损失和学生交叉熵损失)if self.use_entropy_loss:total_loss = 0.7 * kl_loss + 0.3 * student_loss # 加权组合else:total_loss = kl_loss # 纯KL损失# 再次扩展维度(确保多卡聚合时的兼容性)total_loss = total_loss.unsqueeze(0)# 6. 定期打印损失日志(每10步打印一次,避免日志刷屏)if self.state.global_step % 10 == 0:logger.info(f"步骤 {self.state.global_step} - "f"KL损失: {kl_loss.item():.4f} | " # .item()将张量转为Python数值f"学生交叉熵损失: {student_loss.item():.4f} | "f"总损失: {total_loss.item():.4f}")# 根据return_outputs决定返回格式return (total_loss, student_outputs) if return_outputs else total_loss
KnowledgeDistillationTrainer
类是基于 Trainer
扩展的自定义训练器,专门用于实现知识蒸馏功能。其核心作用是让小模型(学生模型)通过学习大模型(教师模型)的知识,在保持轻量化的同时逼近大模型的性能。
1. 目标
通过知识蒸馏技术,让“学生模型”(通常是较小的模型)模仿“教师模型”(通常是较大的预训练模型)的行为,将教师模型的知识迁移到学生模型中。
2. 细节
初始化配置:接收学生模型、教师模型、蒸馏参数(如温度系数、padding标识)等,初始化时将教师模型固定为评估模式(不更新参数),确保其作为“知识来源”的稳定性。
KL散度损失计算:通过 compute_forward_kl_divergence
方法,计算学生模型与教师模型输出分布的差异(KL散度)。具体来说:
- 对两者的输出(logits)进行温度缩放(平滑分布,让模仿更稳定);
- 忽略padding位置的损失(避免无效计算);
- 用KL散度衡量学生分布与教师分布的差距,作为蒸馏的核心损失。
- 总损失组合:在
compute_loss
方法中,将KL散度损失(学生模仿教师的损失)与学生模型自身的交叉熵损失(学生拟合真实标签的损失)结合:
若启用use_entropy_loss
,总损失为“70% KL散度损失 + 30% 学生交叉熵损失”;
若不启用,则仅用KL散度损失(学生完全以教师为学习目标)。 - 兼容基础训练功能:继承
Trainer
类的所有基础能力(如分布式训练、日志记录、模型保存等),仅定制损失计算逻辑,无需重复实现训练流程。
compute_loss
函数的功能
compute_loss函数是 Trainer
类中计算模型训练损失的方法,负责将模型输出与标签(或其他监督信号)结合,生成训练所需的损失值。
-
标签预处理
若存在标签平滑器(label_smoother
)或用户自定义损失函数(compute_loss_func
),且输入数据中包含labels
,则将labels
从输入中提取出来单独处理(避免模型前向传播时重复使用)。 -
模型前向传播
调用模型的forward
方法(outputs = model(**inputs)
),得到模型输出(包含 logits、隐藏状态等,具体结构因模型类型而异)。 -
损失计算逻辑
根据是否存在labels
及模型类型,选择不同的损失计算方式:- 若存在
labels
:- 优先使用用户自定义的
compute_loss_func
(如果设置); - 若为因果语言模型(如 GPT 类),使用标签平滑器并右移标签(避免泄露未来信息);
- 其他模型(如分类、序列标注)直接使用标签平滑器计算损失。
- 优先使用用户自定义的
- 若不存在
labels
:- 直接从模型输出中提取损失(适用于自监督学习等场景,模型自身会计算损失)。
- 若存在
-
分布式训练适配
在多设备分布式训练中,若开启了average_tokens_across_devices
,会对损失进行缩放(乘以进程数),确保损失计算的正确性。 -
返回结果
根据return_outputs
控制返回值:仅返回损失,或同时返回损失和模型输出(方便后续步骤如评估指标计算使用)。
KnowledgeDistillationTrainer
自定义 compute_loss
compute_loss
是 Trainer
类计算损失的核心入口,而知识蒸馏需要特殊的损失计算逻辑(引入教师模型、计算 KL 散度、组合多损失)。因此,KnowledgeDistillationTrainer
重写 compute_loss
函数,实现蒸馏功能。
原函数不支持知识蒸馏逻辑:
原生 compute_loss
仅处理“模型-标签”的损失(如交叉熵),而知识蒸馏需要额外计算“学生模型-教师模型”的损失(如 KL 散度),并将两者结合(如加权求和)。原生函数没有这种逻辑,必须通过重写扩展。
蒸馏需要引入教师模型:
知识蒸馏的核心是让学生模型模仿教师模型的输出,因此损失计算需要同时用到学生模型和教师模型的输出。原 compute_loss
仅接收学生模型(model
参数),无法访问教师模型,必须在重写时加入教师模型的前向传播和损失计算。
自定义损失组合方式:
蒸馏通常需要组合“学生-标签”的交叉熵损失和“学生-教师”的 KL 散度损失,这种组合逻辑是原生函数不具备的,必须通过重写实现。
Trainer的成员函数training_step里面会执行compute_loss
with self.compute_loss_context_manager():loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
软化分布KL散度
代码中实现的KL散度(Kullback-Leibler Divergence)公式是带温度参数的教师-学生概率分布差异度量,具体对应知识蒸馏中常用的“软化分布KL散度”
公式定义
设教师模型的输出logits为 TTT,学生模型的输出logits为 SSS,温度参数为 τ\tauτ(通常 τ≥1\tau \geq 1τ≥1,用于软化概率分布),则:
-
首先对教师和学生的logits进行温度软化:
教师的软化概率分布:Pτ=softmax(Tτ)P_\tau = \text{softmax}\left(\frac{T}{\tau}\right)Pτ=softmax(τT)
学生的软化概率分布:Qτ=softmax(Sτ)Q_\tau = \text{softmax}\left(\frac{S}{\tau}\right)Qτ=softmax(τS) -
计算KL散度(衡量 PτP_\tauPτ 与 QτQ_\tauQτ 的差异):
KL(Pτ∥Qτ)=∑xPτ(x)⋅(logPτ(x)−logQτ(x))\text{KL}(P_\tau \parallel Q_\tau) = \sum_{x} P_\tau(x) \cdot \left( \log P_\tau(x) - \log Q_\tau(x) \right) KL(Pτ∥Qτ)=x∑Pτ(x)⋅(logPτ(x)−logQτ(x))
代码对应关系
在 compute_forward_kl_divergence
方法中,上述公式的实现步骤如下:
- 对logits除以温度:
student_logits = student_logits / temperature
,teacher_logits = teacher_logits / temperature
- 计算软化分布的对数概率:
- 学生的对数概率:
student_log_probs = log_softmax(student_logits)
(即 logQτ\log Q_\taulogQτ) - 教师的对数概率:
teacher_log_probs = log_softmax(teacher_logits)
(即 logPτ\log P_\taulogPτ)
- 学生的对数概率:
- 教师的软化概率:
teacher_probs = softmax(teacher_logits)
(即 PτP_\tauPτ) - 计算KL散度:
kl_divergence = teacher_probs * (teacher_log_probs - student_log_probs).sum(dim=-1)
,完全对应上述公式的数学表达。
温度 τ\tauτ 的作用:当 τ>1\tau > 1τ>1 时,softmax输出的概率分布更平缓(“软化”),学生模型能从教师模型中学习到更细粒度的类别间关系(不仅是最大概率类别)。
公式中KL散度的方向是 Pτ∥QτP_\tau \parallel Q_\tauPτ∥Qτ(以教师分布为“基准”),目标是让学生分布 QτQ_\tauQτ 尽可能接近教师分布 PτP_\tauPτ。
transformers库的 Trainer
类
Trainer
是 transformers
库中一个高度封装的训练工具类,核心作用是将深度学习模型的训练、评估过程标准化并自动化,隐藏了底层训练循环的复杂细节(如梯度计算、参数更新、分布式通信等),让开发者无需手动编写冗长的训练代码。
Trainer 封装了“几乎整个训练过程”
-
训练循环自动化
自动处理 epoch 迭代、batch 加载、模型前向传播(forward
)、损失计算(compute_loss
)、反向传播(backward
)、参数更新(optimizer.step
)、学习率调度(scheduler.step
)等核心步骤,无需手动编写for
循环。 -
分布式训练支持
自动适配单卡、多卡(数据并行)等训练环境,处理设备间的数据同步、梯度聚合等底层逻辑,开发者无需手动调用torch.distributed
相关接口。 -
训练配置管理
通过TrainingArguments
类接收各种训练参数(如批次大小、学习率、训练轮数、日志频率、保存策略等),统一管理训练过程的细节,无需手动配置优化器、调度器。 -
日志与监控
自动记录训练过程中的损失、学习率等指标,支持 TensorBoard、W&B 等工具可视化,同时生成训练日志(如步数、耗时、内存占用等)。 -
模型保存与断点续训
按配置自动保存模型 checkpoint(包括权重、优化器状态、训练参数等),支持从 checkpoint 恢复训练,无需手动处理状态保存逻辑。 -
评估与预测集成
内置evaluate()
方法用于在验证集上评估模型,predict()
方法用于对新数据生成预测结果,自动计算评估指标(需配合compute_metrics
函数)。
Trainer 的灵活性
Trainer
允许通过继承并重写关键方法实现自定义逻辑
- 重写
compute_loss
:自定义损失计算(如知识蒸馏的 KL 散度 + 交叉熵损失,如你之前实现的KnowledgeDistillationTrainer
); - 重写
training_step
/evaluation_step
:自定义单步训练/评估的逻辑(如加入正则化、特殊数据处理); - 通过
callbacks
机制:插入自定义钩子(如训练中途修改学习率、保存额外信息)。
一、参数(Args)
名称 | 类型 | 说明 |
---|---|---|
model | [PreTrainedModel ] 或 torch.nn.Module (可选) | 用于训练、评估或预测的模型。若未提供,必须传入 model_init 。💡 对 transformers 库的 PreTrainedModel 优化最佳,也支持自定义 torch.nn.Module (需与 transformers 模型工作方式一致)。 |
args | [TrainingArguments ](可选) | 训练参数配置。若未提供,默认使用基础实例,output_dir 设为当前目录下的 tmp_trainer 文件夹。 |
data_collator | DataCollator (可选) | 用于将 train_dataset 或 eval_dataset 的样本列表组合成批次的函数。默认逻辑:若未提供 processing_class ,使用 default_data_collator ;若 processing_class 是特征提取器或分词器,使用 DataCollatorWithPadding 。 |
train_dataset | 多种数据集类型(可选) | 训练数据集。若为 datasets.Dataset ,会自动移除模型 forward 方法不接受的列。⚠️ 分布式训练中,若为带随机化的 IterableDataset ,需有 generator 属性(统一随机种子)或 set_epoch() 方法(控制随机数生成器)。 |
eval_dataset | 多种数据集类型或字典(可选) | 评估数据集。若为 datasets.Dataset ,自动移除模型不接受的列;若为字典,会在每个子数据集上评估,并在指标名前加字典键作为前缀。 |
processing_class | 多种处理器类(可选) | 用于数据处理的类(如分词器、特征提取器)。会自动处理模型输入,并与模型一起保存,方便重启训练或复用模型。 ⚠️ 替代已弃用的 tokenizer 参数。 |
model_init | Callable[[], PreTrainedModel] (可选) | 实例化模型的函数。若提供,每次调用 train() 都会从该函数返回的新模型实例开始训练。可接收超参数试验对象(如 optuna),用于根据超参数选择不同模型架构。 |
compute_loss_func | Callable (可选) | 计算损失的函数。接收模型原始输出、标签、累积批次大小(batch_size * gradient_accumulation_steps ),返回损失。可参考 Trainer 默认损失函数实现。 |
compute_metrics | Callable[[EvalPrediction], Dict] (可选) | 评估时计算指标的函数。接收 EvalPrediction ,返回指标字典(键为指标名,值为指标值)。⚠️ 若 TrainingArguments.batch_eval_metrics=True ,需接受 compute_result 参数(用于最后一批次计算全局统计量)。 |
callbacks | 列表(TrainerCallback )(可选) | 自定义训练循环的回调列表。会添加到默认回调中,可通过 remove_callback 移除默认回调。 |
optimizers | 元组(优化器, 调度器)(可选,默认 (None, None) ) | 包含优化器和学习率调度器的元组。默认使用 AdamW 优化器和 get_linear_schedule_with_warmup 调度器(由 args 控制)。 |
optimizer_cls_and_kwargs | 元组(优化器类, 参数字典)(可选) | 包含优化器类和关键字参数的元组。覆盖 args 中的 optim 和 optim_args ,与 optimizers 不兼容。无需在初始化 Trainer 前将模型参数放到设备上(优于 optimizers )。 |
preprocess_logits_for_metrics | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] (可选) | 评估时缓存 logits 前的预处理函数。接收 logits 和标签,返回处理后的 logits,修改会影响 compute_metrics 的输入。⚠️ 若数据集无标签,第二个参数(标签)为 None 。 |
二、重要属性(Important attributes)
名称 | 说明 |
---|---|
model | 始终指向核心模型。若为 transformers 模型,是 PreTrainedModel 的子类。 |
model_wrapped | 始终指向最外层的模型(当原始模型被其他模块包装时)。用于前向传播。 例如:DeepSpeed 下,内部模型会被 DeepSpeed 和 DistributedDataParallel 包装,此时 model_wrapped 指向最外层;未包装时与 model 相同。 |
is_model_parallel | 模型是否启用模型并行模式(与数据并行不同,指模型层拆分到不同 GPU 上)。 |
place_model_on_device | 是否自动将模型放到设备上。若使用模型并行、DeepSpeed,或 TrainingArguments.place_model_on_device=False ,则为 False 。 |
is_in_train | 模型是否正在执行 train (例如,训练过程中调用 evaluate 时为 True )。 |