当前位置: 首页 > news >正文

# 模型量化(二):基于BERT的量化代码实战

1. 前言

前面我对量化的理论知识做了非常详细的介绍。并且模型三大压缩技术中的蒸馏、剪枝也做了非常详细的讲解,并附上了详细的实战代码。
量化的代码将会基于前面做的student模型来给出详细的代码。

  • 模型:剪枝+蒸馏后的student模型
  • 手断:不手写量化,用集成好的库来实现

2. 代码

2.1 代码

import os  
import time  os.environ["CUDA_VISIBLE_DEVICES"] = "0"  
from transformers import AutoTokenizer, set_seed, BitsAndBytesConfig  
import torch.nn.functional as F  
import torch  
import torch.nn as nn  
from transformers import AutoModel, BertConfig, BertPreTrainedModel, AutoConfig  
from transformers.modeling_outputs import SequenceClassifierOutput  class MeanPooler(torch.nn.Module):  def forward(self, hidden, mask):  # 长度兜底:如果 hidden 和 mask 的序列长不同,裁成一致  if mask is not None and hidden.size(1) != mask.size(1):  seq = min(hidden.size(1), mask.size(1))  hidden = hidden[:, :seq, :]  mask = mask[:, :seq]  # 只 unsqueeze 一次!并做 dtype/device 对齐  mask = mask.to(dtype=hidden.dtype, device=hidden.device).unsqueeze(-1)  # [B,S,1]  summed = (hidden * mask).sum(dim=1)  # 半精度下更稳的 eps        eps = torch.finfo(hidden.dtype).tiny if hidden.dtype.is_floating_point else 1e-6  denom = mask.sum(dim=1).clamp(min=eps)  # [B,1]  return summed / denom                    # [B,H]  class StudentHFModel(BertPreTrainedModel):  """  - 训练时通常传入已构建/已剪枝好的 backbone、teacher.pool 的副本、teacher.head 的副本  - 推理时用 from_pretrained(dir) 时,不传 backbone/pool/head,模型会按 config 构建并  根据 config.student_pruned_heads 自动剪枝,保证能正确加载剪枝后的权重。  """    config_class = BertConfig  base_model_prefix = "backbone"  supports_sdpa = True             supports_flash_attn_2 = Falsedef __init__(self, config, backbone=None, pool=None, head=None, **kwargs):  try:  if getattr(config, "attn_implementation", None) != "eager":  config.attn_implementation = "eager"  except Exception:  pass  super().__init__(config, attn_implementation="eager")  if backbone is None:  # 推理路径self.backbone = AutoModel.from_config(config)  self._maybe_prune_from_config(config)  else:  # 训练路径self.backbone = backbone  if pool is None:  self.pool = MeanPooler()  else:  self.pool = pool  # 分类头 if head is None:  h = config.hidden_size  num_labels = config.num_labels  self.head = nn.Sequential(  nn.Dropout(0.2),  nn.Linear(h, h),  nn.GELU(),  nn.Dropout(0.2),  nn.Linear(h, num_labels)  )  else:  self.head = head  self.post_init()  # HF 初始化未赋值权重  def _maybe_prune_from_config(self, config):  """  仅在“未传入外部 backbone”时(推理 from_pretrained)调用:  依据 config.student_pruned_heads 对每层执行 prune_heads,以复现训练时形状。  """        pruned = getattr(config, "student_pruned_heads", None)  if not pruned:  return  # JSON 里字典 key 可能是 str,这里统一转 int        pruned = {int(k): list(map(int, v)) for k, v in pruned.items()}  # self.backbone 获取backbone  encoder = getattr(self.backbone, "encoder", None)  if encoder is None:  return  for l, heads in pruned.items():  if not heads:  continue  attn = encoder.layer[l].attention  if hasattr(attn, "prune_heads"):  attn.prune_heads(set(heads))  elif hasattr(attn, "self") and hasattr(attn.self, "prune_heads"):  attn.self.prune_heads(set(heads))  def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):  outputs = self.backbone(  input_ids=input_ids,  attention_mask=attention_mask,  token_type_ids=token_type_ids,  return_dict=True,  )  hidden = outputs.last_hidden_state  sent = self.pool(hidden, attention_mask)  logits = self.head(sent)  loss = None  if labels is not None:  loss = F.cross_entropy(logits, labels)  return SequenceClassifierOutput(  loss=loss,  logits=logits,  hidden_states=getattr(outputs, "hidden_states", None),  attentions=getattr(outputs, "attentions", None),  )  DATASET = "banking77"  
SEED = 42  OUT_DIR= f"runs/{DATASET}_prune_kd"  
set_seed(SEED)  @torch.no_grad()  
def predict(texts, topk3, model, tok, cfg, to_dvice=False):  batch = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)  if to_dvice:  inputs = {k: v.to(model.device) for k, v in batch.items()}  else:  inputs = batch  out = model(**inputs)  probs = torch.softmax(out["logits"], dim=-1)  topv,topi = torch.topk(probs, topk3, dim=-1)  id2label = cfg.id2label  for i, text in enumerate(texts):  top_l = [id2label[id] for id in topi[i].tolist()]  top_v = topv[i].tolist()  print("\nText:", text)  print("Top-3:", (top_l, top_v))  if __name__ == "__main__":  SAVE_DIR = f"{OUT_DIR}/best"  bnb_8 = BitsAndBytesConfig(load_in_8bit=True)  bnb_4bit = BitsAndBytesConfig(  load_in_4bit=True,  bnb_4bit_quant_type="nf4",         # NF4  bnb_4bit_use_double_quant=True,  bnb_4bit_compute_dtype=torch.bfloat16  # Ampere+ 推荐 bfloat16    )  # tok cfg  tok = AutoTokenizer.from_pretrained(SAVE_DIR, use_fast=True)  cfg = AutoConfig.from_pretrained(SAVE_DIR)  # test  samples = [  "I lost my card yesterday. Can you block it?",                      # 丢卡  "Why was I charged twice for the same transaction?",               # 重复扣款  "My contactless payment isn't working at the store.",              # 闪付失效  "The transfer hasn't arrived to the recipient yet.",               # 转账未到账  "How can I change my PIN code?",                                   # 修改PIN  "Is Apple Pay supported for my card?",                             # Apple Pay 支持  "The ATM swallowed my card. What should I do?",                    # ATM 吞卡  "Balance didn’t update after my bank transfer.",                   # 余额未更新  "I want to terminate my account.",                                 # 注销账户  "What are the fees for international card payments?",              # 汇率/手续费  "Can I get a spare physical card for my account?",                 # 备用卡  "Why do I need to verify my identity again?",                      # 身份验证  ]  # test1 0.6G  # model = StudentHFModel.from_pretrained(SAVE_DIR).eval().cuda()    # predict(samples, 3, model=model, tok=tok, cfg=cfg, to_dvice=True)  # test2 0.5G    # model_8 = StudentHFModel.from_pretrained(    #     SAVE_DIR,    #     quantization_config=bnb_8,    #     device_map={"": 0},    #     attn_implementation="eager"    # ).eval()    # predict(samples, 3, model=model_8, tok=tok, cfg=cfg)  # test3 0.3G    model_4 = StudentHFModel.from_pretrained(  SAVE_DIR,  quantization_config=bnb_4bit,  device_map={"": 0},  attn_implementation="eager"  ).eval()  predict(samples, 3, model=model_4, tok=tok, cfg=cfg)

2.2 代码

分别执行 test1 test2 test3,显存分别占用 0.6G 0.5G 0.3G,可以看出量化的模型显存占用明显减少。

我所测试的模型比较小,在更大的模型上效果可能更加显著

3 代码简要说明

这里给你这份量化/推理脚本的精炼总结,便于回顾与交接👇

  • 模型定义

    • StudentHFModel(BertPreTrainedModel):

      • 兼容 HF from_pretrained / save_pretrained 的学生模型外壳。

      • 在推理路径下(未传入 backbone)可按 config.student_pruned_heads 自动剪头。

      • 强制 attn_implementation="eager",避免 SDPA/FA2 兼容性问题。

    • MeanPooler:

      • last_hidden_statemask 加权平均

      • 已处理 hidden 与 mask 序列长度不一致 的兜底 & 半精度安全的 eps。

  • 量化配置

    • BitsAndBytesConfig(load_in_8bit=True):8-bit 权重量化。

    • BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16):典型 4-bit(NF4 + double quant + bfloat16 计算)。

  • 加载与推理

    • SAVE_DIR = runs/banking77_prune_kd/best 恢复:

      • (可选)FP 模型:StudentHFModel.from_pretrained(SAVE_DIR).eval().cuda()

      • 8-bit:quantization_config=bnb_8, device_map={"":0}

      • 4-bit:quantization_config=bnb_4bit, device_map={"":0}

    • predict()

      • 分词 → 前向 → softmaxtopk → 用 cfg.id2label 映射成可读类别。
  • 快速样例

    • 给出 12 条银行客服意图文本;分别可用 FP/8bit/4bit 模型推理对比。

4 蒸馏、剪枝、量化小结

不必过于关注

4.1 总览:三件套各自解决什么

  • 蒸馏(KD):把教师模型的“知识”迁移给更小的学生,主要保精度。

  • 剪枝(Pruning):删掉不重要的参数/结构,主要减算量和延迟。

  • 量化(Quantization):用低比特(INT8/4、FP16/BF16/NF4 等)表示权重/激活,主要减显存/内存并提速。

推荐顺序:先蒸馏得到合适的小架构 → 结构化剪枝(通道/头/层) → 最后上量化(PTQ / QAT)。
经验上,先缩小架构再做剪枝和量化,精度恢复更稳


4.2 蒸馏(Knowledge Distillation)

常用损失组合

  • 分类任务:L = α * CE(student, hard_label) + (1-α) * T^2 * KL(softmax(s/T), softmax(t/T))

  • 回归/序列预测(如 Audio→Motion):在上式基础上加中间特征/注意力/隐藏层 L2/L1,或直接对最终连续输出加 L1/L2 蒸馏(学生对齐教师输出的“形状/动态”)。

稳健超参

  • 温度 T ∈ [2, 4](NLP 分类常用 2 或 3)

  • 权重 α ∈ [0.2, 0.5](类别多/数据少时,适当降低 α、提高 KL 权重)

  • 训练 3–10 个 epoch,与学习率调度绑总步数

简洁 PyTorch 片段(分类 KD)

def kd_loss(student_logits, teacher_logits, hard_labels, T=3.0, alpha=0.3):ce = torch.nn.functional.cross_entropy(student_logits, hard_labels)p_s = torch.log_softmax(student_logits / T, dim=-1)p_t = torch.softmax(teacher_logits / T, dim=-1)kl = torch.nn.functional.kl_div(p_s, p_t, reduction="batchmean") * (T * T)return alpha * ce + (1 - alpha) * kl

中间层蒸馏(可选)

  • 匹配 hidden_statesattentions,层对层或多层汇总;TinyBERT/DistilBERT 路线都强调这一点,对小学生模型尤为有效

4.3 剪枝(Pruning)

类型

  • 非结构化:按权重幅值裁剪单参数(稀疏化)。对 GPU 推理益处有限,需特殊稀疏内核才明显提速。

  • 结构化:裁剪通道/注意力头/整层,对延迟和显存最有用

  • 粒度建议:优先考虑注意力头剪枝FFN 中间维度通道剪枝整层(layer)蒸馏式减层

策略与节奏

  1. 先做全局幅值初筛(给出大致稀疏目标,比如 20–40%)。

  2. 对 Transformer:对多头注意力 (num_heads)FFN 隐层维度 (intermediate_size) 做结构化裁剪。

  3. 剪枝后微调 + KD 恢复(2–5 epoch),恢复幅度明显。

PyTorch 非结构化示例

import torch.nn.utils.prune as prune
for name, module in model.named_modules():if isinstance(module, torch.nn.Linear):prune.l1_unstructured(module, name='weight', amount=0.3)
# 将稀疏掩码固化
for m in model.modules():if isinstance(m, torch.nn.Linear) and hasattr(m, 'weight_orig'):prune.remove(m, 'weight')

结构化通道剪需要自定义或借助库(比如按通道 L1/L2 重要性排序后重构权重),实践中常与“减层蒸馏”结合更稳。


4.4 量化(Quantization)

选择指南

  • GPU 服务器:首选 FP16/BF16(几乎“无痛”),torch.set_float32_matmul_precision('high'|'medium') 可进一步吃满 Tensor Core。

  • CPU/边缘:优先 INT8(PTQ 快,QAT 精度更高)。

  • 大语言或大 Transformer 的推理省显存:权重仅量化(8-bit / 4-bit,如 bitsandbytes 8bit/NF4)非常常见。

PTQ vs QAT

  • PTQ(后训练量化):快;用几百~几千条校准样本跑一遍,若精度损失 ≤ 1~2%,通常就够了。

  • QAT(量化感知训练):在训练里插入假量化节点,精度恢复最佳,成本更高,适合对精度要求严苛的场景。

PyTorch(CPU 动态量化)

import torch
from torch.ao.quantization import quantize_dynamic
qmodel = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

Transformers + bitsandbytes(GPU 权重量化)

from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(load_in_8bit=True) # 或 load_in_4bit=True, bnb_4bit_quant_type="nf4"
model = AutoModelForSequenceClassification.from_pretrained("your-student",quantization_config=bnb_config,device_map="auto"
)

校准要点(PTQ)

  • 覆盖真实分布:长度、类别频次、说话风格/姿态变化(对 Audio→Motion 很关键)。

  • LayerNorm/Softmax 附近量化最敏感,必要时对这些层 跳过量化 或仅权重量化。

  • 典型校准规模:512~2k 条样本即可看到趋势。


http://www.dtcms.com/a/515134.html

相关文章:

  • 网站没有备案会怎样资源网官网
  • 【C++:继承】面向对象编程精要:C++继承机制深度解析与最佳实践
  • Python访问者模式实战指南:从基础到高级应用
  • 《数组和函数的实践游戏---扫雷游戏(基础版附源码)》
  • 专门做网站的软件是网站着陆页怎么做
  • 南京专业网站制作公司如何申请免费网站空间
  • 【乌班图】远程连接(向日葵/ToDesk)显示成功却无桌面的问题解析与解决
  • 异或的应用
  • c++语法——字符串(10.23讲课)
  • AI大事记13:GPT 与 BERT 的范式之争(上)
  • wordpress安装后查看站点失败网站创建多少年了
  • 文件指针控制函数
  • 【JavaEE初阶】 多线程编程核心:解锁线程创建、方法与状态的创新实践密码
  • JavaEE初阶——HTTP/HTTPS 核心原理:从协议格式到加密传输
  • Linux 内存 get_user_pages_remote 函数
  • 【图像处理】图像滤波
  • CSS 列表详解
  • 建设工程规范下载网站商城网站开发的完整流程
  • 同德县网站建设公司海南网站建设及维护
  • 广西送变电建设公司网站深圳市建设工程造价站官网
  • 网站获取访问者qq号码专业的网页设计和网站制作公司
  • 网站建设费账务处理a站下载
  • 哈尔滨网站建设丿薇建立短语
  • 徐州seo网站推广网站开发 页面功能布局
  • 用extjs做的网站wps如何做网站
  • 青羊区建设局网站怎样入驻微信小程序
  • 网站标题几个字合适学生个人网页制作html代码
  • 网站设计规划信息技术教案枣庄三合一网站开发
  • 提升网站流量电子邮件免费注册
  • 广东住房和城乡建设厅官方网站运维工程师累吗