# 模型量化(二):基于BERT的量化代码实战
1. 前言
前面我对量化的理论知识做了非常详细的介绍。并且模型三大压缩技术中的蒸馏、剪枝也做了非常详细的讲解,并附上了详细的实战代码。
量化的代码将会基于前面做的student模型来给出详细的代码。
- 模型:剪枝+蒸馏后的student模型
- 手断:不手写量化,用集成好的库来实现
2. 代码
2.1 代码
import os
import time os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from transformers import AutoTokenizer, set_seed, BitsAndBytesConfig
import torch.nn.functional as F
import torch
import torch.nn as nn
from transformers import AutoModel, BertConfig, BertPreTrainedModel, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput class MeanPooler(torch.nn.Module): def forward(self, hidden, mask): # 长度兜底:如果 hidden 和 mask 的序列长不同,裁成一致 if mask is not None and hidden.size(1) != mask.size(1): seq = min(hidden.size(1), mask.size(1)) hidden = hidden[:, :seq, :] mask = mask[:, :seq] # 只 unsqueeze 一次!并做 dtype/device 对齐 mask = mask.to(dtype=hidden.dtype, device=hidden.device).unsqueeze(-1) # [B,S,1] summed = (hidden * mask).sum(dim=1) # 半精度下更稳的 eps eps = torch.finfo(hidden.dtype).tiny if hidden.dtype.is_floating_point else 1e-6 denom = mask.sum(dim=1).clamp(min=eps) # [B,1] return summed / denom # [B,H] class StudentHFModel(BertPreTrainedModel): """ - 训练时通常传入已构建/已剪枝好的 backbone、teacher.pool 的副本、teacher.head 的副本 - 推理时用 from_pretrained(dir) 时,不传 backbone/pool/head,模型会按 config 构建并 根据 config.student_pruned_heads 自动剪枝,保证能正确加载剪枝后的权重。 """ config_class = BertConfig base_model_prefix = "backbone" supports_sdpa = True supports_flash_attn_2 = Falsedef __init__(self, config, backbone=None, pool=None, head=None, **kwargs): try: if getattr(config, "attn_implementation", None) != "eager": config.attn_implementation = "eager" except Exception: pass super().__init__(config, attn_implementation="eager") if backbone is None: # 推理路径self.backbone = AutoModel.from_config(config) self._maybe_prune_from_config(config) else: # 训练路径self.backbone = backbone if pool is None: self.pool = MeanPooler() else: self.pool = pool # 分类头 if head is None: h = config.hidden_size num_labels = config.num_labels self.head = nn.Sequential( nn.Dropout(0.2), nn.Linear(h, h), nn.GELU(), nn.Dropout(0.2), nn.Linear(h, num_labels) ) else: self.head = head self.post_init() # HF 初始化未赋值权重 def _maybe_prune_from_config(self, config): """ 仅在“未传入外部 backbone”时(推理 from_pretrained)调用: 依据 config.student_pruned_heads 对每层执行 prune_heads,以复现训练时形状。 """ pruned = getattr(config, "student_pruned_heads", None) if not pruned: return # JSON 里字典 key 可能是 str,这里统一转 int pruned = {int(k): list(map(int, v)) for k, v in pruned.items()} # self.backbone 获取backbone encoder = getattr(self.backbone, "encoder", None) if encoder is None: return for l, heads in pruned.items(): if not heads: continue attn = encoder.layer[l].attention if hasattr(attn, "prune_heads"): attn.prune_heads(set(heads)) elif hasattr(attn, "self") and hasattr(attn.self, "prune_heads"): attn.self.prune_heads(set(heads)) def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs): outputs = self.backbone( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True, ) hidden = outputs.last_hidden_state sent = self.pool(hidden, attention_mask) logits = self.head(sent) loss = None if labels is not None: loss = F.cross_entropy(logits, labels) return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=getattr(outputs, "hidden_states", None), attentions=getattr(outputs, "attentions", None), ) DATASET = "banking77"
SEED = 42 OUT_DIR= f"runs/{DATASET}_prune_kd"
set_seed(SEED) @torch.no_grad()
def predict(texts, topk3, model, tok, cfg, to_dvice=False): batch = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=256) if to_dvice: inputs = {k: v.to(model.device) for k, v in batch.items()} else: inputs = batch out = model(**inputs) probs = torch.softmax(out["logits"], dim=-1) topv,topi = torch.topk(probs, topk3, dim=-1) id2label = cfg.id2label for i, text in enumerate(texts): top_l = [id2label[id] for id in topi[i].tolist()] top_v = topv[i].tolist() print("\nText:", text) print("Top-3:", (top_l, top_v)) if __name__ == "__main__": SAVE_DIR = f"{OUT_DIR}/best" bnb_8 = BitsAndBytesConfig(load_in_8bit=True) bnb_4bit = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", # NF4 bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16 # Ampere+ 推荐 bfloat16 ) # tok cfg tok = AutoTokenizer.from_pretrained(SAVE_DIR, use_fast=True) cfg = AutoConfig.from_pretrained(SAVE_DIR) # test samples = [ "I lost my card yesterday. Can you block it?", # 丢卡 "Why was I charged twice for the same transaction?", # 重复扣款 "My contactless payment isn't working at the store.", # 闪付失效 "The transfer hasn't arrived to the recipient yet.", # 转账未到账 "How can I change my PIN code?", # 修改PIN "Is Apple Pay supported for my card?", # Apple Pay 支持 "The ATM swallowed my card. What should I do?", # ATM 吞卡 "Balance didn’t update after my bank transfer.", # 余额未更新 "I want to terminate my account.", # 注销账户 "What are the fees for international card payments?", # 汇率/手续费 "Can I get a spare physical card for my account?", # 备用卡 "Why do I need to verify my identity again?", # 身份验证 ] # test1 0.6G # model = StudentHFModel.from_pretrained(SAVE_DIR).eval().cuda() # predict(samples, 3, model=model, tok=tok, cfg=cfg, to_dvice=True) # test2 0.5G # model_8 = StudentHFModel.from_pretrained( # SAVE_DIR, # quantization_config=bnb_8, # device_map={"": 0}, # attn_implementation="eager" # ).eval() # predict(samples, 3, model=model_8, tok=tok, cfg=cfg) # test3 0.3G model_4 = StudentHFModel.from_pretrained( SAVE_DIR, quantization_config=bnb_4bit, device_map={"": 0}, attn_implementation="eager" ).eval() predict(samples, 3, model=model_4, tok=tok, cfg=cfg)
2.2 代码
分别执行 test1 test2 test3,显存分别占用 0.6G 0.5G 0.3G,可以看出量化的模型显存占用明显减少。
我所测试的模型比较小,在更大的模型上效果可能更加显著
3 代码简要说明
这里给你这份量化/推理脚本的精炼总结,便于回顾与交接👇
-
模型定义
-
StudentHFModel(BertPreTrainedModel)
:-
兼容 HF
from_pretrained
/save_pretrained
的学生模型外壳。 -
在推理路径下(未传入 backbone)可按
config.student_pruned_heads
自动剪头。 -
强制
attn_implementation="eager"
,避免 SDPA/FA2 兼容性问题。
-
-
MeanPooler
:-
对
last_hidden_state
做 mask 加权平均。 -
已处理 hidden 与 mask 序列长度不一致 的兜底 & 半精度安全的 eps。
-
-
-
量化配置
-
BitsAndBytesConfig(load_in_8bit=True)
:8-bit 权重量化。 -
BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
:典型 4-bit(NF4 + double quant + bfloat16 计算)。
-
-
加载与推理
-
从
SAVE_DIR = runs/banking77_prune_kd/best
恢复:-
(可选)FP 模型:
StudentHFModel.from_pretrained(SAVE_DIR).eval().cuda()
-
8-bit:
quantization_config=bnb_8, device_map={"":0}
-
4-bit:
quantization_config=bnb_4bit, device_map={"":0}
-
-
predict()
:- 分词 → 前向 →
softmax
→topk
→ 用cfg.id2label
映射成可读类别。
- 分词 → 前向 →
-
-
快速样例
- 给出 12 条银行客服意图文本;分别可用 FP/8bit/4bit 模型推理对比。
4 蒸馏、剪枝、量化小结
不必过于关注
4.1 总览:三件套各自解决什么
-
蒸馏(KD):把教师模型的“知识”迁移给更小的学生,主要保精度。
-
剪枝(Pruning):删掉不重要的参数/结构,主要减算量和延迟。
-
量化(Quantization):用低比特(INT8/4、FP16/BF16/NF4 等)表示权重/激活,主要减显存/内存并提速。
推荐顺序:先蒸馏得到合适的小架构 → 结构化剪枝(通道/头/层) → 最后上量化(PTQ / QAT)。
经验上,先缩小架构再做剪枝和量化,精度恢复更稳。
4.2 蒸馏(Knowledge Distillation)
常用损失组合
-
分类任务:
L = α * CE(student, hard_label) + (1-α) * T^2 * KL(softmax(s/T), softmax(t/T))
-
回归/序列预测(如 Audio→Motion):在上式基础上加中间特征/注意力/隐藏层 L2/L1,或直接对最终连续输出加 L1/L2 蒸馏(学生对齐教师输出的“形状/动态”)。
稳健超参
-
温度
T ∈ [2, 4]
(NLP 分类常用 2 或 3) -
权重
α ∈ [0.2, 0.5]
(类别多/数据少时,适当降低 α、提高 KL 权重) -
训练 3–10 个 epoch,与学习率调度绑总步数
简洁 PyTorch 片段(分类 KD)
def kd_loss(student_logits, teacher_logits, hard_labels, T=3.0, alpha=0.3):ce = torch.nn.functional.cross_entropy(student_logits, hard_labels)p_s = torch.log_softmax(student_logits / T, dim=-1)p_t = torch.softmax(teacher_logits / T, dim=-1)kl = torch.nn.functional.kl_div(p_s, p_t, reduction="batchmean") * (T * T)return alpha * ce + (1 - alpha) * kl
中间层蒸馏(可选)
- 匹配
hidden_states
、attentions
,层对层或多层汇总;TinyBERT/DistilBERT 路线都强调这一点,对小学生模型尤为有效。
4.3 剪枝(Pruning)
类型
-
非结构化:按权重幅值裁剪单参数(稀疏化)。对 GPU 推理益处有限,需特殊稀疏内核才明显提速。
-
结构化:裁剪通道/注意力头/整层,对延迟和显存最有用。
-
粒度建议:优先考虑注意力头剪枝、FFN 中间维度通道剪枝、整层(layer)蒸馏式减层。
策略与节奏
-
先做全局幅值初筛(给出大致稀疏目标,比如 20–40%)。
-
对 Transformer:对多头注意力 (num_heads) 和 FFN 隐层维度 (intermediate_size) 做结构化裁剪。
-
剪枝后微调 + KD 恢复(2–5 epoch),恢复幅度明显。
PyTorch 非结构化示例
import torch.nn.utils.prune as prune
for name, module in model.named_modules():if isinstance(module, torch.nn.Linear):prune.l1_unstructured(module, name='weight', amount=0.3)
# 将稀疏掩码固化
for m in model.modules():if isinstance(m, torch.nn.Linear) and hasattr(m, 'weight_orig'):prune.remove(m, 'weight')
结构化通道剪需要自定义或借助库(比如按通道 L1/L2 重要性排序后重构权重),实践中常与“减层蒸馏”结合更稳。
4.4 量化(Quantization)
选择指南
-
GPU 服务器:首选 FP16/BF16(几乎“无痛”),
torch.set_float32_matmul_precision('high'|'medium')
可进一步吃满 Tensor Core。 -
CPU/边缘:优先 INT8(PTQ 快,QAT 精度更高)。
-
大语言或大 Transformer 的推理省显存:权重仅量化(8-bit / 4-bit,如 bitsandbytes 8bit/NF4)非常常见。
PTQ vs QAT
-
PTQ(后训练量化):快;用几百~几千条校准样本跑一遍,若精度损失 ≤ 1~2%,通常就够了。
-
QAT(量化感知训练):在训练里插入假量化节点,精度恢复最佳,成本更高,适合对精度要求严苛的场景。
PyTorch(CPU 动态量化)
import torch
from torch.ao.quantization import quantize_dynamic
qmodel = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
Transformers + bitsandbytes(GPU 权重量化)
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(load_in_8bit=True) # 或 load_in_4bit=True, bnb_4bit_quant_type="nf4"
model = AutoModelForSequenceClassification.from_pretrained("your-student",quantization_config=bnb_config,device_map="auto"
)
校准要点(PTQ)
-
覆盖真实分布:长度、类别频次、说话风格/姿态变化(对 Audio→Motion 很关键)。
-
LayerNorm/Softmax 附近量化最敏感,必要时对这些层 跳过量化 或仅权重量化。
-
典型校准规模:512~2k 条样本即可看到趋势。