当前位置: 首页 > news >正文

知识蒸馏 - 大语言模型知识蒸馏LLM-KD-Trainer 源码分析 KnowledgeDistillationTrainer类

知识蒸馏 - 大语言模型知识蒸馏LLM-KD-Trainer 源码分析 KnowledgeDistillationTrainer类

flyfish

代码抄自
https://github.com/shaoshengsong/KDTrainer
代码分析的是LLM-KD-Trainer/LLM-KD-Trainer.py

版本
Python 版本: 3.12.9
PyTorch: 2.6.0+cu124
Transformers: 4.55.0
PEFT: 0.15.2
PyYAML: 6.0.2

环境搭建好之后 执行 python LLM-KD-Trainer.py 纯绿色版

最好先看完基础知识
知识蒸馏 - 蒸的什么

知识蒸馏 - 通过引入温度参数T调整 Softmax 的输出

知识蒸馏 - 对数函数的单调性

知识蒸馏 - 信息量的公式为什么是对数

知识蒸馏 - 根据真实事件的真实概率分布对其进行编码

知识蒸馏 - 信息熵中的平均为什么是按概率加权的平均

知识蒸馏 - 自信息量是单个事件的信息量,而平均自信息量(即信息熵)是所有事件自信息量以其概率为权重的加权平均值

知识蒸馏 - 最小化KL散度与最小化交叉熵是完全等价的

知识蒸馏 - 基于KL散度的知识蒸馏 KL散度的方向


class KnowledgeDistillationTrainer(Trainer):"""知识蒸馏专用的自定义Trainer类,基于师生模型框架继承自Hugging Face的Trainer类,扩展了知识蒸馏功能,使小型"学生"模型能够学习大型"教师"模型的知识。通过KL散度损失衡量师生模型的输出差异,可选结合学生模型自身的交叉熵损失,提升蒸馏效果。"""def __init__(self,student_model: PreTrainedModel,  # 待训练的学生模型(小型模型)teacher_model: PreTrainedModel,  # 提供知识的教师模型(大型模型)distillation_config: Dict[str, Any],  # 蒸馏配置参数字典use_entropy_loss: bool = False,  # 是否结合学生的交叉熵损失**kwargs  # 传递给父类Trainer的其他参数(如训练参数、数据集等)) -> None:"""初始化知识蒸馏TrainerArgs:student_model: 学生模型,通过蒸馏学习教师模型的知识teacher_model: 教师模型,固定参数作为知识来源distillation_config: 蒸馏相关配置,包含:- temperature: 温度参数,控制logits分布的平滑程度- padding_id: 填充token的ID,用于过滤填充部分的损失use_entropy_loss: 若为True,总损失=0.7*KL损失 + 0.3*学生交叉熵损失;若为False,总损失=纯KL损失**kwargs: 父类Trainer所需的参数(如training_args、train_dataset等)"""# 调用父类Trainer的初始化方法,传入学生模型和其他参数super().__init__(model=student_model,** kwargs)self.teacher_model = teacher_model  # 保存教师模型self.teacher_model.eval()  # 教师模型固定为评估模式(不更新参数,关闭dropout等)self.distillation_config = distillation_config  # 保存蒸馏配置self.use_entropy_loss = use_entropy_loss  # 保存损失组合策略# 打印蒸馏配置日志,方便调试logger.info(f"蒸馏配置:温度={distillation_config['temperature']},填充ID={distillation_config['padding_id']}")logger.info(f"损失策略:{'KL散度+学生交叉熵' if use_entropy_loss else '纯KL散度'}")@staticmethoddef compute_forward_kl_divergence(student_logits: torch.Tensor,  # 学生模型的输出logits(未经过softmax)teacher_logits: torch.Tensor,  # 教师模型的输出logits(未经过softmax)target_labels: torch.Tensor,  # 标签张量,用于识别填充位置padding_id: int,  # 填充token的ID,填充部分不计入损失reduction: str = "sum",  # 损失聚合方式("sum"表示求和)temperature: float = 1.0,  # 温度参数,控制分布平滑度(>0)) -> torch.Tensor:"""计算教师与学生输出分布之间的前向KL散度衡量学生模型的输出分布与教师模型的输出分布之间的差异,通过温度参数平滑分布,并忽略填充位置的损失(填充部分对模型学习无意义)。Args:student_logits: 学生模型输出的logits,形状为[batch_size, seq_len, vocab_size]teacher_logits: 教师模型输出的logits,形状为[batch_size, seq_len, vocab_size]target_labels: 标签张量,形状为[batch_size, seq_len],用于定位填充位置padding_id: 填充token的ID,对应位置的损失会被过滤reduction: 损失聚合方式(此处固定为"sum",对所有有效位置求和)temperature: 温度参数,值越大分布越平滑(通常设为1-10)Returns:计算得到的KL散度损失张量"""# 用温度参数缩放logits,控制分布的平滑程度(温度越高,分布越平缓)student_logits = student_logits / temperatureteacher_logits = teacher_logits / temperature# 计算缩放后logits的对数概率(log_softmax)student_log_probs = torch.log_softmax(student_logits, dim=-1, dtype=torch.float32)teacher_log_probs = torch.log_softmax(teacher_logits, dim=-1, dtype=torch.float32)# 计算教师模型的概率分布(softmax)teacher_probs = torch.softmax(teacher_logits, dim=-1, dtype=torch.float32)# 计算KL散度:KL(教师分布 || 学生分布) = 教师概率 * (教师对数概率 - 学生对数概率)kl_divergence = teacher_probs * (teacher_log_probs - student_log_probs)kl_divergence = kl_divergence.sum(dim=-1)  # 对词汇表维度求和,得到每个token的KL损失# 处理填充部分:忽略填充位置的损失if reduction == "sum":# 生成填充掩码:标签等于padding_id的位置为True(需要过滤)pad_mask = target_labels.eq(padding_id)# 将填充位置的损失设为0(不参与总损失计算)kl_divergence = kl_divergence.masked_fill_(pad_mask, 0.0)# 对所有有效位置的损失求和,得到最终KL损失kl_divergence = kl_divergence.sum()return kl_divergencedef compute_loss(self, model: PreTrainedModel,  # 学生模型(由Trainer自动传入)inputs: Dict[str, torch.Tensor],  # 输入数据字典,包含input_ids、attention_mask、labels等return_outputs: bool = False,  # 是否返回损失+模型输出(True用于需要输出的场景)num_items_in_batch: Optional[int] = None  # 批次中的样本数(未使用,兼容父类接口)) -> Tuple[torch.Tensor, Any] | torch.Tensor:"""计算知识蒸馏的总损失结合KL散度损失(师生差异)和学生模型自身的交叉熵损失(可选),得到总损失。Args:model: 学生模型(实际为self.model,由Trainer传入)inputs: 输入数据字典,包含训练所需的所有张量return_outputs: 若为True,返回(总损失, 学生模型输出);否则仅返回总损失Returns:总损失张量,或(总损失, 学生模型输出)的元组"""# 1. 学生模型前向传播(计算交叉熵损失和logits)student_outputs = model(**inputs)  # 学生模型输出,包含loss和logitsstudent_loss = student_outputs.loss  # 学生模型自身的交叉熵损失(由模型内部计算)student_logits = student_outputs.logits  # 学生模型的logits(用于计算KL损失)# 处理多卡训练场景:若损失是多维张量(如每个卡一个损失),则取平均转为标量if student_loss.dim() > 0:  # 检查损失是否有维度(标量的dim=0)student_loss = student_loss.mean()  # 多卡损失取平均,转为标量# 2. 教师模型前向传播(不计算梯度,避免更新教师参数)with torch.no_grad():  # 禁用梯度计算,节省内存并加速teacher_outputs = self.teacher_model(** inputs)  # 教师模型输出teacher_logits = teacher_outputs.logits  # 教师模型的logits(用于计算KL损失)# 3. 对齐师生模型的词汇表大小(若不同)if student_logits.shape[-1] != teacher_logits.shape[-1]:# 取较小的词汇表大小,截断较大的一方(避免维度不匹配)min_vocab = min(student_logits.shape[-1], teacher_logits.shape[-1])student_logits = student_logits[:, :, :min_vocab]  # 截断学生logitsteacher_logits = teacher_logits[:, :, :min_vocab]  # 截断教师logitslogger.debug(f"词汇表大小对齐至{min_vocab}(取师生模型中的较小值)")# 4. 计算KL散度损失(衡量师生输出差异)kl_loss = self.compute_forward_kl_divergence(student_logits=student_logits,  # 学生logitsteacher_logits=teacher_logits,  # 教师logitstarget_labels=inputs["labels"],  # 标签(用于过滤填充)padding_id=self.distillation_config["padding_id"],  # 填充IDtemperature=self.distillation_config["temperature"]  # 温度参数)# 处理多卡场景的KL损失:若为多维张量则取平均转为标量if kl_loss.dim() > 0:kl_loss = kl_loss.mean()# 扩展维度为1D张量(避免多卡聚合时的标量警告)student_loss = student_loss.unsqueeze(0)  # 标量→形状为[1]的张量kl_loss = kl_loss.unsqueeze(0)  # 标量→形状为[1]的张量# 5. 计算总损失(根据策略组合KL损失和学生交叉熵损失)if self.use_entropy_loss:total_loss = 0.7 * kl_loss + 0.3 * student_loss  # 加权组合else:total_loss = kl_loss  # 纯KL损失# 再次扩展维度(确保多卡聚合时的兼容性)total_loss = total_loss.unsqueeze(0)# 6. 定期打印损失日志(每10步打印一次,避免日志刷屏)if self.state.global_step % 10 == 0:logger.info(f"步骤 {self.state.global_step} - "f"KL损失: {kl_loss.item():.4f} | "  # .item()将张量转为Python数值f"学生交叉熵损失: {student_loss.item():.4f} | "f"总损失: {total_loss.item():.4f}")# 根据return_outputs决定返回格式return (total_loss, student_outputs) if return_outputs else total_loss

KnowledgeDistillationTrainer 类是基于 Trainer 扩展的自定义训练器,专门用于实现知识蒸馏功能。其核心作用是让小模型(学生模型)通过学习大模型(教师模型)的知识,在保持轻量化的同时逼近大模型的性能。

1. 目标

通过知识蒸馏技术,让“学生模型”(通常是较小的模型)模仿“教师模型”(通常是较大的预训练模型)的行为,将教师模型的知识迁移到学生模型中。

2. 细节

初始化配置:接收学生模型、教师模型、蒸馏参数(如温度系数、padding标识)等,初始化时将教师模型固定为评估模式(不更新参数),确保其作为“知识来源”的稳定性。
KL散度损失计算:通过 compute_forward_kl_divergence 方法,计算学生模型与教师模型输出分布的差异(KL散度)。具体来说:

  • 对两者的输出(logits)进行温度缩放(平滑分布,让模仿更稳定);
  • 忽略padding位置的损失(避免无效计算);
  • 用KL散度衡量学生分布与教师分布的差距,作为蒸馏的核心损失。
  • 总损失组合:在 compute_loss 方法中,将KL散度损失(学生模仿教师的损失)与学生模型自身的交叉熵损失(学生拟合真实标签的损失)结合:
    若启用 use_entropy_loss,总损失为“70% KL散度损失 + 30% 学生交叉熵损失”;
    若不启用,则仅用KL散度损失(学生完全以教师为学习目标)。
  • 兼容基础训练功能:继承 Trainer 类的所有基础能力(如分布式训练、日志记录、模型保存等),仅定制损失计算逻辑,无需重复实现训练流程。

compute_loss 函数的功能

compute_loss函数是 Trainer 类中计算模型训练损失的方法,负责将模型输出与标签(或其他监督信号)结合,生成训练所需的损失值。

  1. 标签预处理
    若存在标签平滑器(label_smoother)或用户自定义损失函数(compute_loss_func),且输入数据中包含 labels,则将 labels 从输入中提取出来单独处理(避免模型前向传播时重复使用)。

  2. 模型前向传播
    调用模型的 forward 方法(outputs = model(**inputs)),得到模型输出(包含 logits、隐藏状态等,具体结构因模型类型而异)。

  3. 损失计算逻辑
    根据是否存在 labels 及模型类型,选择不同的损失计算方式:

    • 若存在 labels
      • 优先使用用户自定义的 compute_loss_func(如果设置);
      • 若为因果语言模型(如 GPT 类),使用标签平滑器并右移标签(避免泄露未来信息);
      • 其他模型(如分类、序列标注)直接使用标签平滑器计算损失。
    • 若不存在 labels
      • 直接从模型输出中提取损失(适用于自监督学习等场景,模型自身会计算损失)。
  4. 分布式训练适配
    在多设备分布式训练中,若开启了 average_tokens_across_devices,会对损失进行缩放(乘以进程数),确保损失计算的正确性。

  5. 返回结果
    根据 return_outputs 控制返回值:仅返回损失,或同时返回损失和模型输出(方便后续步骤如评估指标计算使用)。

KnowledgeDistillationTrainer 自定义 compute_loss

compute_lossTrainer 类计算损失的核心入口,而知识蒸馏需要特殊的损失计算逻辑(引入教师模型、计算 KL 散度、组合多损失)。因此,KnowledgeDistillationTrainer 重写 compute_loss 函数,实现蒸馏功能。

原函数不支持知识蒸馏逻辑
原生 compute_loss 仅处理“模型-标签”的损失(如交叉熵),而知识蒸馏需要额外计算“学生模型-教师模型”的损失(如 KL 散度),并将两者结合(如加权求和)。原生函数没有这种逻辑,必须通过重写扩展。

蒸馏需要引入教师模型
知识蒸馏的核心是让学生模型模仿教师模型的输出,因此损失计算需要同时用到学生模型和教师模型的输出。原 compute_loss 仅接收学生模型(model 参数),无法访问教师模型,必须在重写时加入教师模型的前向传播和损失计算。

自定义损失组合方式
蒸馏通常需要组合“学生-标签”的交叉熵损失和“学生-教师”的 KL 散度损失,这种组合逻辑是原生函数不具备的,必须通过重写实现。

Trainer的成员函数training_step里面会执行compute_loss

with self.compute_loss_context_manager():loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

软化分布KL散度

代码中实现的KL散度(Kullback-Leibler Divergence)公式是带温度参数的教师-学生概率分布差异度量,具体对应知识蒸馏中常用的“软化分布KL散度”

公式定义

设教师模型的输出logits为 TTT,学生模型的输出logits为 SSS,温度参数为 τ\tauτ(通常 τ≥1\tau \geq 1τ1,用于软化概率分布),则:

  1. 首先对教师和学生的logits进行温度软化:
    教师的软化概率分布:Pτ=softmax(Tτ)P_\tau = \text{softmax}\left(\frac{T}{\tau}\right)Pτ=softmax(τT)
    学生的软化概率分布:Qτ=softmax(Sτ)Q_\tau = \text{softmax}\left(\frac{S}{\tau}\right)Qτ=softmax(τS)

  2. 计算KL散度(衡量 PτP_\tauPτQτQ_\tauQτ 的差异):
    KL(Pτ∥Qτ)=∑xPτ(x)⋅(log⁡Pτ(x)−log⁡Qτ(x))\text{KL}(P_\tau \parallel Q_\tau) = \sum_{x} P_\tau(x) \cdot \left( \log P_\tau(x) - \log Q_\tau(x) \right) KL(PτQτ)=xPτ(x)(logPτ(x)logQτ(x))

代码对应关系

compute_forward_kl_divergence 方法中,上述公式的实现步骤如下:

  1. 对logits除以温度:student_logits = student_logits / temperatureteacher_logits = teacher_logits / temperature
  2. 计算软化分布的对数概率:
    • 学生的对数概率:student_log_probs = log_softmax(student_logits)(即 log⁡Qτ\log Q_\taulogQτ
    • 教师的对数概率:teacher_log_probs = log_softmax(teacher_logits)(即 log⁡Pτ\log P_\taulogPτ
  3. 教师的软化概率:teacher_probs = softmax(teacher_logits)(即 PτP_\tauPτ
  4. 计算KL散度:kl_divergence = teacher_probs * (teacher_log_probs - student_log_probs).sum(dim=-1),完全对应上述公式的数学表达。

温度 τ\tauτ 的作用:当 τ>1\tau > 1τ>1 时,softmax输出的概率分布更平缓(“软化”),学生模型能从教师模型中学习到更细粒度的类别间关系(不仅是最大概率类别)。
公式中KL散度的方向是 Pτ∥QτP_\tau \parallel Q_\tauPτQτ(以教师分布为“基准”),目标是让学生分布 QτQ_\tauQτ 尽可能接近教师分布 PτP_\tauPτ

transformers库的 Trainer
Trainertransformers 库中一个高度封装的训练工具类,核心作用是将深度学习模型的训练、评估过程标准化并自动化,隐藏了底层训练循环的复杂细节(如梯度计算、参数更新、分布式通信等),让开发者无需手动编写冗长的训练代码。

Trainer 封装了“几乎整个训练过程”

  1. 训练循环自动化
    自动处理 epoch 迭代、batch 加载、模型前向传播(forward)、损失计算(compute_loss)、反向传播(backward)、参数更新(optimizer.step)、学习率调度(scheduler.step)等核心步骤,无需手动编写 for 循环。

  2. 分布式训练支持
    自动适配单卡、多卡(数据并行)等训练环境,处理设备间的数据同步、梯度聚合等底层逻辑,开发者无需手动调用 torch.distributed 相关接口。

  3. 训练配置管理
    通过 TrainingArguments 类接收各种训练参数(如批次大小、学习率、训练轮数、日志频率、保存策略等),统一管理训练过程的细节,无需手动配置优化器、调度器。

  4. 日志与监控
    自动记录训练过程中的损失、学习率等指标,支持 TensorBoard、W&B 等工具可视化,同时生成训练日志(如步数、耗时、内存占用等)。

  5. 模型保存与断点续训
    按配置自动保存模型 checkpoint(包括权重、优化器状态、训练参数等),支持从 checkpoint 恢复训练,无需手动处理状态保存逻辑。

  6. 评估与预测集成
    内置 evaluate() 方法用于在验证集上评估模型,predict() 方法用于对新数据生成预测结果,自动计算评估指标(需配合 compute_metrics 函数)。

Trainer 的灵活性

Trainer 允许通过继承并重写关键方法实现自定义逻辑

  • 重写 compute_loss:自定义损失计算(如知识蒸馏的 KL 散度 + 交叉熵损失,如你之前实现的 KnowledgeDistillationTrainer);
  • 重写 training_step/evaluation_step:自定义单步训练/评估的逻辑(如加入正则化、特殊数据处理);
  • 通过 callbacks 机制:插入自定义钩子(如训练中途修改学习率、保存额外信息)。

一、参数(Args)

名称类型说明
model[PreTrainedModel] 或 torch.nn.Module(可选)用于训练、评估或预测的模型。若未提供,必须传入 model_init
💡 对 transformers 库的 PreTrainedModel 优化最佳,也支持自定义 torch.nn.Module(需与 transformers 模型工作方式一致)。
args[TrainingArguments](可选)训练参数配置。若未提供,默认使用基础实例,output_dir 设为当前目录下的 tmp_trainer 文件夹。
data_collatorDataCollator(可选)用于将 train_dataseteval_dataset 的样本列表组合成批次的函数。
默认逻辑:若未提供 processing_class,使用 default_data_collator;若 processing_class 是特征提取器或分词器,使用 DataCollatorWithPadding
train_dataset多种数据集类型(可选)训练数据集。若为 datasets.Dataset,会自动移除模型 forward 方法不接受的列。
⚠️ 分布式训练中,若为带随机化的 IterableDataset,需有 generator 属性(统一随机种子)或 set_epoch() 方法(控制随机数生成器)。
eval_dataset多种数据集类型或字典(可选)评估数据集。若为 datasets.Dataset,自动移除模型不接受的列;若为字典,会在每个子数据集上评估,并在指标名前加字典键作为前缀。
processing_class多种处理器类(可选)用于数据处理的类(如分词器、特征提取器)。会自动处理模型输入,并与模型一起保存,方便重启训练或复用模型。
⚠️ 替代已弃用的 tokenizer 参数。
model_initCallable[[], PreTrainedModel](可选)实例化模型的函数。若提供,每次调用 train() 都会从该函数返回的新模型实例开始训练。
可接收超参数试验对象(如 optuna),用于根据超参数选择不同模型架构。
compute_loss_funcCallable(可选)计算损失的函数。接收模型原始输出、标签、累积批次大小(batch_size * gradient_accumulation_steps),返回损失。可参考 Trainer 默认损失函数实现。
compute_metricsCallable[[EvalPrediction], Dict](可选)评估时计算指标的函数。接收 EvalPrediction,返回指标字典(键为指标名,值为指标值)。
⚠️ 若 TrainingArguments.batch_eval_metrics=True,需接受 compute_result 参数(用于最后一批次计算全局统计量)。
callbacks列表(TrainerCallback)(可选)自定义训练循环的回调列表。会添加到默认回调中,可通过 remove_callback 移除默认回调。
optimizers元组(优化器, 调度器)(可选,默认 (None, None)包含优化器和学习率调度器的元组。默认使用 AdamW 优化器和 get_linear_schedule_with_warmup 调度器(由 args 控制)。
optimizer_cls_and_kwargs元组(优化器类, 参数字典)(可选)包含优化器类和关键字参数的元组。覆盖 args 中的 optimoptim_args,与 optimizers 不兼容。
无需在初始化 Trainer 前将模型参数放到设备上(优于 optimizers)。
preprocess_logits_for_metricsCallable[[torch.Tensor, torch.Tensor], torch.Tensor](可选)评估时缓存 logits 前的预处理函数。接收 logits 和标签,返回处理后的 logits,修改会影响 compute_metrics 的输入。
⚠️ 若数据集无标签,第二个参数(标签)为 None

二、重要属性(Important attributes)

名称说明
model始终指向核心模型。若为 transformers 模型,是 PreTrainedModel 的子类。
model_wrapped始终指向最外层的模型(当原始模型被其他模块包装时)。用于前向传播。
例如:DeepSpeed 下,内部模型会被 DeepSpeedDistributedDataParallel 包装,此时 model_wrapped 指向最外层;未包装时与 model 相同。
is_model_parallel模型是否启用模型并行模式(与数据并行不同,指模型层拆分到不同 GPU 上)。
place_model_on_device是否自动将模型放到设备上。若使用模型并行、DeepSpeed,或 TrainingArguments.place_model_on_device=False,则为 False
is_in_train模型是否正在执行 train(例如,训练过程中调用 evaluate 时为 True)。
http://www.dtcms.com/a/323265.html

相关文章:

  • 【动态数据源】⭐️@DS注解实现项目中多数据源的配置
  • 【QT】常⽤控件详解(六)多元素控件 QListWidget Table Widget Tree Widget
  • 【Avalonia】无开发者账号使用iOS真机调试跨平台应用
  • C++四种类型转换
  • Tiger任务管理系统-12
  • SpringBoot学习日记(二)
  • Day38 Dataset和Dataloader类
  • Git 核心概念与操作全指南(含工作区、暂存区、版本库详解)
  • VisionMoE本地部署的创新设计:从架构演进到高效实现
  • python的format易混淆的细节
  • Java 实现企业级服务器资源监控系统(含 SSH 执行 + 邮件通知 + Excel 报表)
  • 欧拉公式的意义
  • 202506 电子学会青少年等级考试机器人六级器人理论真题
  • 通用AGI到来,记忆仍需要一点旧颜色
  • 【狂飙AGI】2025年上半年中文大模型综合性测评
  • [已解决]VSCode右键菜单消失恢复
  • 用户需求调研后的信息如何整理
  • 大语言模型提示工程与应用:LLMs文本生成与数据标注实践
  • 需求管理流程规范
  • 强化学习概论(1)
  • Android 锁屏图标的大小修改
  • android15哪些广播可以会走冷启动或者用于保活呢?
  • 探索Trae:使用Trae CN爬取 Gitbook 电子书
  • 【Doris】实时分析型数据库
  • 走遍美国5 The Right Magic 钓鱼秘决
  • 【Python 语法糖小火锅 · 第 3 涮】
  • 【RabbitMQ】高级特性—TTL、延迟队列详解
  • Java 中的编译与反编译:全面解析与实践指南
  • drippingblues靶机
  • 四边形(梯形、平行四边形、矩形、菱形和正方形)