
Knowledge Distillation (KD) Explained, Part 3: Hands-On Knowledge Distillation Code with BERT


Preface

Dataset: we use "banking77" from huggingface.
Task: build and train an intent-recognition model with BERT.
The BERT model and the theory behind knowledge distillation were covered in detail in the previous two parts; refer to them if needed.
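Before training, it helps to take a quick look at the dataset. A minimal sketch (assuming the huggingface datasets library is installed and the dataset can be downloaded):

from datasets import load_dataset

# banking77 ships with "train" and "test" splits;
# each row has a "text" string and an integer "label" (77 intent classes in total)
ds = load_dataset("banking77")
print(ds)                                        # split sizes
print(ds["train"][0])                            # e.g. {'text': '...', 'label': 11}
print(len(ds["train"].features["label"].names))  # 77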


1. Teacher model training

1.1 Training

The teacher is trained with a BERT backbone plus a custom classification head. The complete code is as follows:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import evaluate
from datasets import load_dataset, DatasetDict  # huggingface
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer, set_seed, DataCollatorWithPadding, AutoConfig  # huggingface
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

# ======================== Configuration ==================================
DATASET = "banking77"  # options: "banking77" | "clinc150" | "massive_en" | "massive_zh"
BACKBONE = "bert-base-uncased"
MAX_LEN = 256
SEED = 42
LR = 2e-5
EPOCHS = 10
BSZ_TRAIN = 32
BSZ_EVAL = 32
OUT_DIR = f"runs/{DATASET}_teacher_custom"

set_seed(SEED)

# 1. ==================== Dataset loading ====================
def load_intent_dataset(dataset_name):
    if dataset_name == "banking77":
        ds = load_dataset("banking77")
        text_col, label_col = "text", "label"
        num_labels = len(ds["train"].unique(label_col))
    else:
        raise ValueError("Unknown dataset")
    if "validation" not in ds:
        if "test" in ds:
            ds = DatasetDict({"train": ds["train"], "validation": ds["test"]})
        else:
            split = ds["train"].train_test_split(test_size=0.1, seed=SEED)
            ds = DatasetDict({"train": split["train"], "validation": split["test"]})
    return ds, text_col, label_col, num_labels

# 2. ======================= Data preprocessing ====================
def tok_fn(batch, tok, max_len, TEXT):
    return tok(batch[TEXT], truncation=True, max_length=max_len)

# 4. ======================= Build the Teacher ====================
class MeanPooler(torch.nn.Module):
    """
    hidden: usually last_hidden_state, shape [B, S, H] (batch / sequence length / hidden size)
    mask:   usually attention_mask, shape [B, S]; valid tokens are 1, padding is 0.
    """
    def forward(self, hidden, mask):
        mask = mask.unsqueeze(-1)  # [B, S, 1]
        summed = (hidden * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-6)
        return summed / denom

class TeacherBackboneHead(torch.nn.Module):
    def __init__(self, backbone_name, num_labels, config=None):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_name)
        h = self.backbone.config.hidden_size
        self.pool = MeanPooler()
        self.head = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(h, h),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(h, num_labels)
        )
        self.config = config

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        assert input_ids is not None and attention_mask is not None and token_type_ids is not None, \
            "input_ids, attention_mask and token_type_ids must not be None"
        out = self.backbone(input_ids=input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, return_dict=True)
        hidden = out.last_hidden_state
        sent = self.pool(hidden, attention_mask)  # (B, H)
        logits = self.head(sent)                  # (B, C)
        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}

if __name__ == "__main__":
    cfg = AutoConfig.from_pretrained(BACKBONE)
    ds, TEXT, LABEL, N_LABELS = load_intent_dataset(DATASET)
    tok = AutoTokenizer.from_pretrained(BACKBONE, use_fast=True)

    # ds:     {'train': ['text', 'label'], 'validation': ['text', 'label']}
    # ds_enc: {'train': ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
    #          'validation': ['label', 'input_ids', 'token_type_ids', 'attention_mask']}
    ds_enc = ds.map(tok_fn, batched=True, remove_columns=[TEXT],
                    fn_kwargs={"tok": tok, "max_len": MAX_LEN, "TEXT": TEXT})
    """
    ds_enc["train"][0] =>
    {
        "input_ids": [101, 2026, 3899, ..., 102],
        "attention_mask": [1, 1, 1, ..., 1],
        "token_type_ids": [0, 0, 0, ..., 0],
        "labels": 17
    }
    """
    labels = ds["train"].features["label"].names
    id2label = {i: name for i, name in enumerate(labels)}
    cfg.id2label = id2label

    # Padding is done per mini-batch: we don't pad in tok_fn because the collator pads dynamically
    collator = DataCollatorWithPadding(tok)

    # 3. ======================= Metrics ====================
    acc = evaluate.load("accuracy")  # fraction of correct predictions
    # Macro-averaged F1 over all classes reflects how well every class is predicted
    f1 = evaluate.load("f1")  # F1 = 2 * (precision * recall) / (precision + recall), precision = TP / (TP + FP), recall = TP / (TP + FN)

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = logits.argmax(-1)
        acc_result = acc.compute(predictions=preds, references=labels)
        f1_result = f1.compute(predictions=preds, references=labels, average="macro")
        return {"accuracy": acc_result["accuracy"], "f1": f1_result["f1"]}

    # 5. ======================= Train the Teacher ====================
    teacher = TeacherBackboneHead(BACKBONE, N_LABELS, config=cfg)

    # Print the model structure
    examples = [ds_enc["train"][i] for i in range(2)]
    batch_ex = collator(examples)
    inputs_ex = {key: batch_ex[key] for key in ["input_ids", "attention_mask", "token_type_ids"]}

    class Wrapper(nn.Module):
        def __init__(self, core):
            super().__init__()
            self.core = core

        def forward(self, input_ids, attention_mask=None, token_type_ids=None):
            out = self.core(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
            return out["logits"]  # [B, num_labels]

    wrapped = Wrapper(teacher)
    summary(
        wrapped,
        input_data=inputs_ex,
        depth=3,
        col_names=("input_size", "output_size", "num_params", "trainable"),
        verbose=2
    )

    # args
    args = TrainingArguments(
        output_dir=OUT_DIR,
        learning_rate=LR,
        per_device_train_batch_size=BSZ_TRAIN,
        per_device_eval_batch_size=BSZ_EVAL,
        num_train_epochs=EPOCHS,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        seed=SEED,
        fp16=torch.cuda.is_available(),
        report_to="none"
    )
    trainer = Trainer(
        model=teacher,
        args=args,
        train_dataset=ds_enc["train"],
        eval_dataset=ds_enc["validation"],
        tokenizer=tok,
        data_collator=collator,  # dynamic padding
        compute_metrics=compute_metrics
    )
    trainer.train()
    metrics = trainer.evaluate()
    print(f"Final metrics: {metrics}")

    # Save (including the custom head)
    os.makedirs(f"{OUT_DIR}/best", exist_ok=True)
    trainer.save_model(f"{OUT_DIR}/best")
    tok.save_pretrained(f"{OUT_DIR}/best")  # the Trainer already saves the tokenizer, so this step is optional
    cfg.save_pretrained(f"{OUT_DIR}/best")
    print(f"Saved teacher to {OUT_DIR}/best")

After running the script, we first see the model structure (partial screenshot below):
[Screenshot: torchinfo summary of the teacher model]

We can see that the teacher model has roughly 110 million parameters.
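If you want to double-check that figure without reading the torchinfo table, you can count the parameters directly. A small sketch, assuming it is run in the same script right after `teacher` is created; the total is dominated by the roughly 110M parameters of the bert-base-uncased backbone:

# quick parameter count for the teacher (backbone + custom head)
n_params = sum(p.numel() for p in teacher.parameters())
print(f"teacher parameters: {n_params / 1e6:.1f}M")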


After training finishes, the results look like this:

[Screenshots: training and evaluation logs]

The final accuracy is about 0.92, and the macro F1 is also about 0.92.


1.2 Testing

Let's pick a few texts and check how the model does.
Test code:

from transformers import AutoTokenizer, AutoConfig
import os, torch, numpy as np
from bert_kd import TeacherBackboneHead, MAX_LEN
from safetensors.torch import load_file

@torch.no_grad()
def predict(texts, topk3, model, tok, cfg):
    batch = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LEN)
    inputs = {k: v.to(device) for k, v in batch.items()}
    out = model(**inputs)
    probs = torch.softmax(out["logits"], dim=-1)
    topv, topi = torch.topk(probs, topk3, dim=-1)
    id2label = cfg.id2label
    for i, text in enumerate(texts):
        top_l = [id2label[id] for id in topi[i].tolist()]
        top_v = topv[i].tolist()
        print("\nText:", text)
        print("Top-3:", (top_l, top_v))

if __name__ == "__main__":
    TEACHER_DIR = r"runs/banking77_teacher_custom/best"
    cfg_t = AutoConfig.from_pretrained(TEACHER_DIR)
    tok_t = AutoTokenizer.from_pretrained(TEACHER_DIR)
    model_t = TeacherBackboneHead(TEACHER_DIR, num_labels=77, config=cfg_t)
    device = torch.device("cuda:0")
    state_dict_t = load_file(os.path.join(TEACHER_DIR, "model.safetensors"))
    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(device)
    model_t.eval()
    samples = [
        "I lost my card yesterday. Can you block it?",              # lost card
        "Why was I charged twice for the same transaction?",        # duplicate charge
        "My contactless payment isn't working at the store.",       # contactless not working
        "The transfer hasn't arrived to the recipient yet.",        # transfer not arrived
        "How can I change my PIN code?",                            # change PIN
        "Is Apple Pay supported for my card?",                      # Apple Pay support
        "The ATM swallowed my card. What should I do?",             # ATM swallowed card
        "Balance didn’t update after my bank transfer.",            # balance not updated
        "I want to terminate my account.",                          # close account
        "What are the fees for international card payments?",       # exchange rate / fees
        "Can I get a spare physical card for my account?",          # spare card
        "Why do I need to verify my identity again?",               # identity verification
    ]
    predict(samples, 3, model=model_t, tok=tok_t, cfg=cfg_t)

Test results:
[Screenshot: top-3 predictions for the sample texts]

From the output, all of the predicted intents are correct.


2. Student model training

2.1 Code

The code below trains not only a student model with KD but also a student model without KD, so we can compare distilled vs. non-distilled results.
The teacher is the model trained in Section 1.
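As a quick recap before the code: the DistillerTrainer below optimizes the standard Hinton-style logit-distillation objective, where z_s and z_t are the student and teacher logits, y is the hard label, and T and alpha come from KDConfig (defaults T = 4.0, alpha = 0.7):

$$
L = \alpha \, T^{2}\,\mathrm{KL}\!\left(\mathrm{softmax}(z_t/T)\,\big\|\,\mathrm{softmax}(z_s/T)\right) + (1-\alpha)\,\mathrm{CE}(z_s,\,y)
$$

The T² factor keeps the gradient magnitude of the soft-target term comparable to the cross-entropy term as the temperature grows.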

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import evaluate
from datasets import load_dataset, DatasetDict  # huggingface
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer, set_seed, DataCollatorWithPadding, AutoConfig, AutoModelForSequenceClassification  # huggingface
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
from bert_kd import TeacherBackboneHead
from safetensors.torch import load_file
from dataclasses import dataclass
from typing import Optional, Union, Callable, Dict, Any, List, Tuple

# ======================== Configuration ==================================
DATASET = "banking77"
TEACHER_DIR = f"runs/{DATASET}_teacher_custom"  # path to the trained teacher model
MAX_LEN = 256
SEED = 42
LR = 3e-4
EPOCHS = 20
BSZ_TRAIN = 32
BSZ_EVAL = 32
STUDENT_CKPT = "huawei-noah/TinyBERT_General_4L_312D"
OUT_DIR = f"runs/{DATASET}_tinybert_kd"

set_seed(SEED)

def load_intent_dataset(dataset_name):
    if dataset_name == "banking77":
        ds = load_dataset("banking77")
        text_col, label_col = "text", "label"
        num_labels = len(ds["train"].unique(label_col))
    else:
        raise ValueError("Unknown dataset")
    if "validation" not in ds:
        if "test" in ds:
            ds = DatasetDict({"train": ds["train"], "validation": ds["test"]})
        else:
            split = ds["train"].train_test_split(test_size=0.1, seed=SEED)
            ds = DatasetDict({"train": split["train"], "validation": split["test"]})
    return ds, text_col, label_col, num_labels

def tok_fn(batch, tok, max_len, TEXT):
    return tok(batch[TEXT], truncation=True, max_length=max_len)

def dual_tok_fn(batch, tok_s, tok_t, max_len, TEXT):
    enc_s = tok_s(batch[TEXT], truncation=True, max_length=max_len)
    enc_t = tok_t(batch[TEXT], truncation=True, max_length=max_len)
    out = {
        "input_ids": enc_s["input_ids"],
        "attention_mask": enc_s["attention_mask"],
        "token_type_ids": enc_s["token_type_ids"],
        "t_input_ids": enc_t["input_ids"],
        "t_attention_mask": enc_t["attention_mask"],
        "t_token_type_ids": enc_t["token_type_ids"],
    }
    return out

@dataclass
class KDConfig:
    T: float = 4.0
    alpha: float = 0.7
    mid_weight: float = 0.0  # set > 0 for intermediate-layer distillation (optional, see below)

# Trainer that performs knowledge-distillation training
class DistillerTrainer(Trainer):
    def __init__(self, *ar, teacher=None, kd=None, **kw):
        super().__init__(*ar, **kw)
        self.teacher = teacher
        self.kd = kd
        self.T = kd.T
        self.alpha = kd.alpha

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs["labels"]
        # student forward
        out_s = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs["token_type_ids"],
            return_dict=True
        )
        z_s = out_s.logits
        # teacher forward
        with torch.no_grad():
            out_t = self.teacher.backbone(
                input_ids=inputs["t_input_ids"],
                attention_mask=inputs["t_attention_mask"],
                token_type_ids=inputs["t_token_type_ids"],
                return_dict=True
            )
            hidden = out_t.last_hidden_state
            sent = self.teacher.pool(hidden, inputs["t_attention_mask"])  # (B, H)
            z_t = self.teacher.head(sent)  # (B, C)
        loss_kd = F.kl_div(F.log_softmax(z_s / self.T, dim=-1),
                           F.softmax(z_t / self.T, dim=-1),
                           reduction="batchmean") * (self.T ** 2)
        loss_ce = F.cross_entropy(z_s, labels)
        loss = self.alpha * loss_kd + (1 - self.alpha) * loss_ce
        return (loss, out_s) if return_outputs else loss

class DualDataCollatorWithPadding:
    def __init__(self, tok_s, tok_t):
        super().__init__()
        self.collator_s = DataCollatorWithPadding(tok_s)
        self.collator_t = DataCollatorWithPadding(tok_t)
        self.keys_s = ["input_ids", "attention_mask", "token_type_ids", "label"]
        self.keys_t = ["t_input_ids", "t_attention_mask", "t_token_type_ids"]

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # features: list of dicts
        features_s = [{k: v for k, v in f.items() if k in self.keys_s} for f in features]
        features_t = [{k[2:]: v for k, v in f.items() if k in self.keys_t} for f in features]  # strip the "t_" prefix
        batch_s = self.collator_s(features_s)
        batch_t = self.collator_t(features_t)
        # merge batch_t back into batch_s under the "t_" prefix
        for k, v in batch_t.items():
            batch_s[f"t_{k}"] = v
        return batch_s

if __name__ == "__main__":
    # load the teacher model
    TEACHER_DIR = r"runs/banking77_teacher_custom/best"
    cfg_t = AutoConfig.from_pretrained(TEACHER_DIR)
    tok_t = AutoTokenizer.from_pretrained(TEACHER_DIR, use_fast=True)
    model_t = TeacherBackboneHead(TEACHER_DIR, num_labels=77, config=cfg_t)
    device = torch.device("cuda:0")
    state_dict_t = load_file(os.path.join(TEACHER_DIR, "model.safetensors"))
    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(device)
    model_t.eval()

    # load the dataset
    ds, TEXT, LABEL, N_LABELS = load_intent_dataset(DATASET)
    tok_s = AutoTokenizer.from_pretrained(STUDENT_CKPT, use_fast=True)
    ds_enc_s = ds.map(tok_fn, batched=True, remove_columns=[TEXT],
                      fn_kwargs={"tok": tok_s, "max_len": MAX_LEN, "TEXT": TEXT})
    collator_s = DataCollatorWithPadding(tok_s)
    collator_t = DataCollatorWithPadding(tok_t)
    ds_enc = ds.map(dual_tok_fn, batched=True, remove_columns=[TEXT],
                    fn_kwargs={"tok_s": tok_s, "tok_t": tok_t, "max_len": MAX_LEN, "TEXT": TEXT})
    collator = DualDataCollatorWithPadding(tok_s, tok_t)

    # student model
    student = AutoModelForSequenceClassification.from_pretrained(STUDENT_CKPT, num_labels=N_LABELS)

    # cfg
    cfg = AutoConfig.from_pretrained(STUDENT_CKPT, num_labels=N_LABELS)
    labels = ds["train"].features["label"].names
    id2label = {i: name for i, name in enumerate(labels)}
    cfg.id2label = id2label

    # acc / f1
    acc = evaluate.load("accuracy")  # fraction of correct predictions
    f1 = evaluate.load("f1")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = logits.argmax(-1)
        acc_result = acc.compute(predictions=preds, references=labels)
        f1_result = f1.compute(predictions=preds, references=labels, average="macro")
        return {"accuracy": acc_result["accuracy"], "f1": f1_result["f1"]}

    # print the model structure
    examples = [ds_enc_s["train"][i] for i in range(2)]
    batch_ex = collator_s(examples)
    inputs_ex = {key: batch_ex[key] for key in ["input_ids", "attention_mask", "token_type_ids"]}

    class Wrapper(nn.Module):
        def __init__(self, core):
            super().__init__()
            self.core = core

        def forward(self, input_ids, attention_mask=None, token_type_ids=None):
            out = self.core(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
            return out["logits"]  # [B, num_labels]

    wrapped = Wrapper(student)
    summary(
        wrapped,
        input_data=inputs_ex,
        depth=3,
        col_names=("input_size", "output_size", "num_params", "trainable"),
        verbose=2
    )

    # kd
    KD = KDConfig()

    # args
    args_ = TrainingArguments(
        output_dir=OUT_DIR,
        learning_rate=LR,
        per_device_train_batch_size=BSZ_TRAIN,
        per_device_eval_batch_size=BSZ_EVAL,
        num_train_epochs=EPOCHS,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_steps=500,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        seed=SEED,
        fp16=torch.cuda.is_available(),
        report_to="none",
        remove_unused_columns=False
    )

    # trainer
    trainer = DistillerTrainer(
        model=student,
        args=args_,
        train_dataset=ds_enc["train"],
        eval_dataset=ds_enc["validation"],
        tokenizer=tok_s,
        data_collator=collator,  # dynamic padding
        compute_metrics=compute_metrics,
        teacher=model_t,
        kd=KD
    )

    # train
    trainer.train()
    metrics = trainer.evaluate()
    print(f"Final metrics for student: {metrics}")

    # save (including the custom head)
    os.makedirs(f"{OUT_DIR}/best", exist_ok=True)
    trainer.save_model(f"{OUT_DIR}/best")
    tok_s.save_pretrained(f"{OUT_DIR}/best")  # the Trainer already saves the tokenizer, so this step is optional
    cfg.save_pretrained(f"{OUT_DIR}/best")
    print(f"Saved student to {OUT_DIR}/best")

    # baseline: train the student directly, without KD
    student2 = AutoModelForSequenceClassification.from_pretrained(STUDENT_CKPT, num_labels=N_LABELS)
    trainer2 = Trainer(
        model=student2,
        args=args_,
        train_dataset=ds_enc_s["train"],
        eval_dataset=ds_enc_s["validation"],
        tokenizer=tok_s,
        data_collator=collator_s,  # dynamic padding
        compute_metrics=compute_metrics
    )
    trainer2.train()
    metrics2 = trainer2.evaluate()
    print(f"Final metrics for student no kd: {metrics2}")
    os.makedirs(f"{OUT_DIR}_no_kd/best", exist_ok=True)
    trainer2.save_model(f"{OUT_DIR}_no_kd/best")
    tok_s.save_pretrained(f"{OUT_DIR}_no_kd/best")
    cfg.save_pretrained(f"{OUT_DIR}_no_kd/best")
    print(f"Saved student to {OUT_DIR}_no_kd/best")

2.2 Results

Running the code above, we first see the student model's structure and parameter count. It has far fewer parameters than the teacher, roughly 10% of the teacher's count.
[Screenshot: torchinfo summary of the student model]
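For reference, TinyBERT_General_4L_312D uses 4 transformer layers with a hidden size of 312, which is where most of the reduction comes from. A rough comparison, assuming model_t and student are instantiated as in the code above:

def count_params(m):
    return sum(p.numel() for p in m.parameters())

n_t, n_s = count_params(model_t), count_params(student)
print(f"teacher: {n_t / 1e6:.1f}M, student: {n_s / 1e6:.1f}M, ratio: {n_s / n_t:.1%}")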

Finally, the acc and F1 results:
KD Student:
[Screenshot: eval metrics of the KD student]

No KD Student:
[Screenshot: eval metrics of the student trained without KD]


Result: the KD student improves over the non-KD student by about 4%.


Conclusion

  1. With a pretrained BERT backbone plus a custom classification head, an intent-recognition model can be trained in very few epochs and already performs well. This shows that pretrained BERT is indeed a strong language-feature extractor.
  2. With KD, a small model can get close to the teacher's performance and clearly outperforms the same small model trained directly without KD. This shows that KD genuinely works.
  3. We only used the most basic KD technique here; for example, no intermediate-layer loss was used, and the hyperparameters were barely tuned.
  4. The KD hyperparameters strongly affect the result; the values used here took several tries to reach a reasonably good setting. Ideally you would also add a hidden-layer loss, which should work a bit better; a minimal sketch is given after this list, and the rest is left for the reader to try.
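For readers who want to try the hidden-layer loss mentioned in point 4, here is a minimal sketch of one common variant: pool the student's last hidden state with the same MeanPooler, project its 312-dim sentence vector up to the teacher's 768-dim space with a small learnable linear layer, and add an MSE term weighted by mid_weight. The projection layer proj and where it gets registered are illustrative assumptions, not part of the code above; inside DistillerTrainer.compute_loss you would call the student with output_hidden_states=True and add this term to the existing loss.

import torch.nn as nn
import torch.nn.functional as F

# Hypothetical projection from the student's hidden size (312) to the teacher's (768).
# Register it on the student (e.g. student.kd_proj = proj) so the optimizer updates it.
proj = nn.Linear(312, 768)

def mid_layer_loss(student_hidden, student_mask, teacher_sent, pool, mid_weight=0.1):
    """MSE between the projected student sentence vector and the teacher's.

    student_hidden: student last_hidden_state, shape (B, S, 312)
    student_mask:   student attention_mask, shape (B, S)
    teacher_sent:   teacher pooled sentence vector, shape (B, 768)
    pool:           the same MeanPooler the teacher uses
    """
    sent_s = pool(student_hidden, student_mask)   # (B, 312)
    sent_s = proj(sent_s)                         # (B, 768)
    return mid_weight * F.mse_loss(sent_s, teacher_sent.detach())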