人工智能概念:常用的模型压缩技术(剪枝、量化、知识蒸馏)
文章目录
- 一、模型压缩概述
- 1.1 什么是模型压缩?
- 1.2 为什么需要模型压缩?
- 1.3 四种主流模型压缩技术
- 二、模型量化:用低精度换高效能
- 2.1 量化的数学原理
- 2.2 量化计算示例
- 2.3 量化相关API详解
- 三、知识蒸馏:让小模型学会大模型的“智慧”
- 3.1 知识蒸馏的数学原理
- 3.2 蒸馏计算示例
- 3.3 知识蒸馏相关API详解
- 四、模型剪枝:移除冗余参数,保留核心能力
- 4.1 剪枝的数学原理
- 4.2 剪枝计算示例
- 4.3 模型剪枝相关API详解
- 4.4 剪枝注意事项
- 五、总结
一、模型压缩概述
1.1 什么是模型压缩?
模型压缩是一类通过减少模型参数数量、降低计算复杂度,从而在资源受限设备上高效部署深度学习模型的技术。其核心目标是在模型性能损失最小化的前提下,显著减小模型体积、降低内存占用、提升推理速度,以适应移动端、嵌入式设备等资源受限场景的需求。
1.2 为什么需要模型压缩?
随着Transformer等大模型的兴起,模型参数规模呈指数级增长。例如,原始BERT-base模型参数量约110M,推理时不仅占用大量内存,还需要较高的计算资源,难以直接部署在手机、摄像头等边缘设备上。此外,大模型的高推理延迟也无法满足实时性要求较高的业务场景(如实时推荐、语音助手)。
模型压缩的意义在于:
- 降低存储成本:减小模型文件大小,节省存储空间。
- 提升推理速度:减少计算量,降低延迟,满足实时性需求。
- 降低部署门槛:使模型能够在算力有限的边缘设备上运行。
- 减少能耗:降低推理过程中的能量消耗,适合移动设备。
1.3 四种主流模型压缩技术
目前,业界常用的模型压缩技术主要有四类:
技术名称 | 核心思想 | 特点 |
---|---|---|
剪枝(Pruning) | 移除模型中冗余的参数(如权重值接近0的连接),保留关键参数。 | 可分为结构化剪枝(移除整个通道/层)和非结构化剪枝(移除单个权重)。 |
量化(Quantization) | 用低精度数据类型(如int8)替代高精度类型(如float32)表示权重和激活值。 | 模型体积缩小4-8倍,推理速度提升2-4倍,实现简单。 |
知识蒸馏(Knowledge Distillation) | 让小模型(学生)学习大模型(教师)的“知识”(如软标签),模仿其行为。 | 保留大模型性能的同时,显著减小模型规模,适用于复杂模型压缩。 |
低秩因式分解(Low-rank Factorization) | 将高维权重矩阵分解为多个低维矩阵的乘积,减少参数数量。 | 适合线性层、卷积层等矩阵运算密集的模块,压缩率较高但实现较复杂。 |
二、模型量化:用低精度换高效能
2.1 量化的数学原理
量化的核心是将高精度浮点数(如float32)映射到低精度整数(如int8),核心公式涉及缩放因子(scale) 和零点(zero_point) 的计算。
-
基本定义
- 设浮点数范围为 [xmin,xmax][x_{\text{min}}, x_{\text{max}}][xmin,xmax],对应整数范围为 [qmin,qmax][q_{\text{min}}, q_{\text{max}}][qmin,qmax](如int8的[−128,127][-128, 127][−128,127])。
- 缩放因子 sss:控制浮点数到整数的比例映射。
- 零点 zzz:确保映射的偏移量(使0附近的浮点数能准确映射)。
-
核心公式
s=xmax−xminqmax−qmin(1)s = \frac{x_{\text{max}} - x_{\text{min}}}{q_{\text{max}} - q_{\text{min}}} \tag{1} s=qmax−qminxmax−xmin(1)
z=qmin−round(xmins)(2)z = q_{\text{min}} - \text{round}\left(\frac{x_{\text{min}}}{s}\right) \tag{2} z=qmin−round(sxmin)(2)
q=clip(round(xs+z),qmin,qmax)(3)q = \text{clip}\left(\text{round}\left(\frac{x}{s} + z\right), q_{\text{min}}, q_{\text{max}}\right) \tag{3} q=clip(round(sx+z),qmin,qmax)(3)- 式(1):计算缩放因子,将浮点数范围映射到整数范围。
- 式(2):计算零点,确保xminx_{\text{min}}xmin能映射到qminq_{\text{min}}qmin。
- 式(3):将浮点数xxx量化为整数qqq,并裁剪到整数范围内。
-
反量化公式(推理时还原)
xrecon=s⋅(q−z)(4)x_{\text{recon}} = s \cdot (q - z) \tag{4} xrecon=s⋅(q−z)(4)
2.2 量化计算示例
以float32到int8的量化为例,假设某层权重的浮点数范围为 [−1.2,3.6][-1.2, 3.6][−1.2,3.6],计算过程如下:
步骤1:确定范围
- 浮点数:xmin=−1.2x_{\text{min}} = -1.2xmin=−1.2,xmax=3.6x_{\text{max}} = 3.6xmax=3.6
- int8整数:qmin=−128q_{\text{min}} = -128qmin=−128,qmax=127q_{\text{max}} = 127qmax=127,范围长度 127−(−128)=255127 - (-128) = 255127−(−128)=255
步骤2:计算缩放因子sss
s=3.6−(−1.2)255=4.8255≈0.0188s = \frac{3.6 - (-1.2)}{255} = \frac{4.8}{255} \approx 0.0188 s=2553.6−(−1.2)=2554.8≈0.0188
步骤3:计算零点zzz
z=−128−round(−1.20.0188)=−128−round(−63.83)=−128+64=−64z = -128 - \text{round}\left(\frac{-1.2}{0.0188}\right) = -128 - \text{round}(-63.83) = -128 + 64 = -64 z=−128−round(0.0188−1.2)=−128−round(−63.83)=−128+64=−64
步骤4:量化单个浮点数
例如量化 x=0.5x = 0.5x=0.5:
q=round(0.50.0188+(−64))=round(26.59−64)=round(−37.41)=−37q = \text{round}\left(\frac{0.5}{0.0188} + (-64)\right) = \text{round}(26.59 - 64) = \text{round}(-37.41) = -37 q=round(0.01880.5+(−64))=round(26.59−64)=round(−37.41)=−37
- 量化结果:q=−37q = -37q=−37(在int8范围内)。
步骤5:反量化验证
xrecon=0.0188×(−37−(−64))=0.0188×27≈0.5076≈0.5x_{\text{recon}} = 0.0188 \times (-37 - (-64)) = 0.0188 \times 27 \approx 0.5076 \approx 0.5 xrecon=0.0188×(−37−(−64))=0.0188×27≈0.5076≈0.5
- 误差:∣0.5076−0.5∣=0.0076|0.5076 - 0.5| = 0.0076∣0.5076−0.5∣=0.0076,精度损失较小。
2.3 量化相关API详解
- PyTorch量化API
API名称 | 功能描述 | 关键参数说明 | 适用场景 |
---|---|---|---|
torch.quantization.quantize_dynamic | 动态量化模型,推理时实时计算scale和zero_point | - model :待量化模型- qconfig_spec :指定需量化的层类型(如{torch.nn.Linear} )- dtype :目标数据类型(如torch.qint8 ) | 快速部署、动态输入场景 |
torch.quantization.prepare | 为静态量化准备模型(插入观测器) | - model :待准备模型- qconfig :量化配置(如torch.quantization.get_default_qconfig('fbgemm') ) | 静态量化(需校准数据) |
torch.quantization.convert | 将准备好的模型转换为量化模型 | - model :经prepare 处理的模型 | 静态量化(精度更高) |
torch.quantization.QConfig | 定义量化配置(如激活和权重的量化方式) | - activation :激活量化方式(如FakeQuantize.with_args(observer=MovingAverageMinMaxObserver) )- weight :权重量化方式 | 自定义量化策略 |
- TensorFlow量化API
API名称 | 功能描述 | 关键参数说明 |
---|---|---|
tf.quantization.quantize | 对张量进行量化(支持动态范围量化) | - input :待量化张量- min_range /max_range :输入范围- T :目标类型(如tf.int8 ) |
tf.keras.layers.experimental.QuantizationAwareTraining | 量化感知训练(模拟量化过程,提升量化后精度) | - input_shape :输入形状- num_bits :量化位数 |
- ONNX Runtime量化API
API名称 | 功能描述 | 关键参数说明 |
---|---|---|
onnxruntime.quantization.quantize_dynamic | 动态量化ONNX模型 | - input_model :输入ONNX模型路径- output_model :输出量化模型路径- op_types_to_quantize :需量化的算子类型(如['MatMul', 'Add'] ) |
- 量化注意事项
- 动态量化适合CPU端部署,GPU量化建议使用TensorRT的INT8校准工具。
- 量化对模型精度的影响与任务相关:图像分类通常比目标检测更耐量化,文本分类比NER更耐量化。
- 混合精度量化(如部分层用float16,部分用int8)可在精度和速度间取得更好平衡。
三、知识蒸馏:让小模型学会大模型的“智慧”
3.1 知识蒸馏的数学原理
知识蒸馏的核心是通过KL散度衡量学生模型与教师模型的输出差异,结合硬标签损失优化学生模型。
-
软标签生成
教师模型的logits经过温度TTT调整后生成软标签:
pi=exp(zi/T)∑jexp(zj/T)(5)p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \tag{5} pi=∑jexp(zj/T)exp(zi/T)(5)- ziz_izi:教师模型对第iii类的logits输出。
- TTT:温度参数(T>1T>1T>1使分布更平滑,保留更多知识)。
-
KL散度损失(软标签损失)
衡量学生软标签qqq与教师软标签ppp的差异:
LKL=∑ipilog(piqi)(6)L_{\text{KL}} = \sum_i p_i \log\left(\frac{p_i}{q_i}\right) \tag{6} LKL=i∑pilog(qipi)(6)- 当T=1T=1T=1时,KL散度退化为交叉熵损失。
-
总损失函数
Ltotal=α⋅LKL+(1−α)⋅LCE(7)L_{\text{total}} = \alpha \cdot L_{\text{KL}} + (1-\alpha) \cdot L_{\text{CE}} \tag{7} Ltotal=α⋅LKL+(1−α)⋅LCE(7)- LCEL_{\text{CE}}LCE:学生模型与真实标签的交叉熵损失(硬标签损失)。
- α\alphaα:软标签损失的权重(通常取0.5-0.9)。
3.2 蒸馏计算示例
以三分类任务为例,演示损失计算过程:
步骤1:模型输出
- 教师模型logits:zteacher=[3.0,1.0,0.2]z_{\text{teacher}} = [3.0, 1.0, 0.2]zteacher=[3.0,1.0,0.2]
- 学生模型logits:zstudent=[2.5,0.8,0.1]z_{\text{student}} = [2.5, 0.8, 0.1]zstudent=[2.5,0.8,0.1]
- 真实标签:y=[1,0,0]y = [1, 0, 0]y=[1,0,0](第0类)
步骤2:生成软标签(T=2.0T=2.0T=2.0)
- 教师软标签:
p=[exp(3/2)∑,exp(1/2)∑,exp(0.2/2)∑]≈[0.721,0.215,0.064]p = \left[ \frac{\exp(3/2)}{\sum}, \frac{\exp(1/2)}{\sum}, \frac{\exp(0.2/2)}{\sum} \right] \approx [0.721, 0.215, 0.064] p=[∑exp(3/2),∑exp(1/2),∑exp(0.2/2)]≈[0.721,0.215,0.064] - 学生软标签:
q=[exp(2.5/2)∑,exp(0.8/2)∑,exp(0.1/2)∑]≈[0.659,0.257,0.084]q = \left[ \frac{\exp(2.5/2)}{\sum}, \frac{\exp(0.8/2)}{\sum}, \frac{\exp(0.1/2)}{\sum} \right] \approx [0.659, 0.257, 0.084] q=[∑exp(2.5/2),∑exp(0.8/2),∑exp(0.1/2)]≈[0.659,0.257,0.084]
步骤3:计算损失
- KL散度损失(LKLL_{\text{KL}}LKL):
LKL=0.721log(0.721/0.659)+0.215log(0.215/0.257)+0.064log(0.064/0.084)≈0.018L_{\text{KL}} = 0.721\log(0.721/0.659) + 0.215\log(0.215/0.257) + 0.064\log(0.064/0.084) \approx 0.018 LKL=0.721log(0.721/0.659)+0.215log(0.215/0.257)+0.064log(0.064/0.084)≈0.018 - 硬标签损失(LCEL_{\text{CE}}LCE):
LCE=−log(q0)≈−log(0.659)≈0.418L_{\text{CE}} = -\log(q_0) \approx -\log(0.659) \approx 0.418 LCE=−log(q0)≈−log(0.659)≈0.418 - 总损失(α=0.7\alpha=0.7α=0.7):
Ltotal=0.7×0.018+0.3×0.418≈0.138L_{\text{total}} = 0.7 \times 0.018 + 0.3 \times 0.418 \approx 0.138 Ltotal=0.7×0.018+0.3×0.418≈0.138
3.3 知识蒸馏相关API详解
- Hugging Face Transformers API
API名称/工具 | 功能描述 | 关键参数说明 |
---|---|---|
transformers.Trainer | 通过自定义损失函数实现蒸馏(教师模型固定,学生模型训练) | - model :学生模型- args :训练参数(如TrainingArguments )- compute_loss :自定义损失函数(融合KL散度和交叉熵) |
transformers.DistilBertForSequenceClassification | 预训练蒸馏模型(如DistilBERT,学生模型) | - 继承自PreTrainedModel ,可直接加载预训练权重(如distilbert-base-uncased ) |
- PyTorch蒸馏工具
API名称 | 功能描述 | 关键参数说明 |
---|---|---|
torch.nn.KLDivLoss | 计算KL散度损失(软标签损失) | - reduction :损失聚合方式(如'batchmean' )- log_target :是否目标为对数形式 |
torch.nn.CrossEntropyLoss | 计算交叉熵损失(硬标签损失) | - weight :类别权重- reduction :损失聚合方式 |
- 专用蒸馏库
库名称 | 功能描述 | 核心API示例 |
---|---|---|
HuggingFace/transformers 中的蒸馏工具 | 提供DistilBERT、DistilRoBERTa等蒸馏模型的训练逻辑 | from transformers import DistilBertTokenizer, DistilBertForSequenceClassification |
knowledge-distillation-pytorch | 轻量级蒸馏库(支持多种蒸馏策略) | from kd import KnowledgeDistillationLoss (融合KL散度和硬标签损失) |
四、模型剪枝:移除冗余参数,保留核心能力
4.1 剪枝的数学原理
剪枝通过评估参数重要性移除冗余权重,常用L1范数衡量重要性(值越小越冗余)。
-
L1范数重要性评估
对于权重矩阵W∈Rm×nW \in \mathbb{R}^{m \times n}W∈Rm×n,单个权重wijw_{ij}wij的重要性为:
I(wij)=∣wij∣(8)\mathcal{I}(w_{ij}) = |w_{ij}| \tag{8} I(wij)=∣wij∣(8) -
全局剪枝阈值计算
若剪枝比例为rrr,则阈值θ\thetaθ满足:
∑i,jI(∣wij∣<θ)总参数数=r(9)\frac{\sum_{i,j} \mathbb{I}(|w_{ij}| < \theta)}{\text{总参数数}} = r \tag{9} 总参数数∑i,jI(∣wij∣<θ)=r(9)- I(⋅)\mathbb{I}(\cdot)I(⋅)为指示函数,满足条件时取1。
4.2 剪枝计算示例
以3x3权重矩阵为例,剪枝30%的参数:
步骤1:原始权重矩阵
W=[0.1−0.020.05−0.30.010.20.03−0.040.08]W = \begin{bmatrix} 0.1 & -0.02 & 0.05 \\ -0.3 & 0.01 & 0.2 \\ 0.03 & -0.04 & 0.08 \end{bmatrix} W=0.1−0.30.03−0.020.01−0.040.050.20.08
步骤2:计算L1范数(重要性)
∣I(W)∣=[0.10.020.050.30.010.20.030.040.08]|\mathcal{I}(W)| = \begin{bmatrix} 0.1 & 0.02 & 0.05 \\ 0.3 & 0.01 & 0.2 \\ 0.03 & 0.04 & 0.08 \end{bmatrix} ∣I(W)∣=0.10.30.030.020.010.040.050.20.08
步骤3:排序并确定阈值
将所有权重按L1范数升序排列:0.01,0.02,0.03,0.04,0.05,0.08,0.1,0.2,0.30.01, 0.02, 0.03, 0.04, 0.05, 0.08, 0.1, 0.2, 0.30.01,0.02,0.03,0.04,0.05,0.08,0.1,0.2,0.3
总参数9个,剪枝30%即移除3个参数,阈值θ=0.03\theta=0.03θ=0.03(第3小的值)。
步骤4:剪枝后矩阵(小于θ\thetaθ的权重置0)
Wpruned=[0.100.05−0.300.20−0.040.08]W_{\text{pruned}} = \begin{bmatrix} 0.1 & 0 & 0.05 \\ -0.3 & 0 & 0.2 \\ 0 & -0.04 & 0.08 \end{bmatrix} Wpruned=0.1−0.3000−0.040.050.20.08
- 稀疏度:3/9=33.3%(接近目标30%)。
4.3 模型剪枝相关API详解
- PyTorch剪枝API
API名称 | 功能描述 | 关键参数说明 |
---|---|---|
torch.nn.utils.prune.l1_unstructured | 对单个模块进行L1非结构化剪枝 | - module :待剪枝模块(如model.bert.encoder.layer[0].attention.self.query )- name :待剪枝参数名(如'weight' )- amount :剪枝比例(如0.3 ) |
torch.nn.utils.prune.global_unstructured | 对多个模块进行全局非结构化剪枝(统一阈值) | - parameters :待剪枝参数列表(如[(module, 'weight')] )- pruning_method :剪枝方法(如prune.L1Unstructured )- amount :剪枝比例 |
torch.nn.utils.prune.remove | 永久移除剪枝掩码(将0值权重保留在参数中) | - module :已剪枝模块- name :剪枝参数名 |
torch.nn.utils.prune.ln_structured | 对模块进行结构化剪枝(如按通道剪枝) | - n :剪枝维度(如0 表示按输出通道剪枝)- amount :剪枝比例- pruning_method :重要性评估方法(如'l1_unstructured' ) |
- TensorFlow剪枝API
API名称 | 功能描述 | 关键参数说明 |
---|---|---|
tfmot.sparsity.keras.prune_low_magnitude | 对Keras模型进行 magnitude-based 剪枝 | - model :待剪枝模型- pruning_schedule :剪枝调度(如PolynomialDecay ) |
tfmot.sparsity.keras.PolynomialDecay | 定义剪枝比例随训练步数的变化策略 | - initial_sparsity :初始稀疏度- final_sparsity :目标稀疏度- num_steps :总步数 |
- 第三方剪枝工具
工具名称 | 功能描述 | 核心特点 |
---|---|---|
TorchPrune | 支持PyTorch模型的结构化和非结构化剪枝 | 提供剪枝后模型微调工具,支持可视化剪枝效果 |
PruneTorch | 轻量级剪枝库(支持Transformer、ResNet等主流模型) | 实现简单,适合快速验证剪枝效果 |
4.4 剪枝注意事项
- 非结构化剪枝生成稀疏矩阵,需硬件支持(如NVIDIA的Sparse Tensor Core)才能加速,否则可能变慢。
- 结构化剪枝(如按通道)生成密集矩阵,无需特殊硬件,但剪枝比例过高会导致精度大幅下降。
- 剪枝后需微调模型(fine-tuning),恢复因剪枝丢失的性能(通常微调3-5个epoch即可)。
五、总结
模型压缩技术通过数学原理与工程实现的结合,在精度与效率间取得平衡,其核心API为工业界部署提供了便捷工具:
- 量化:通过
torch.quantization.quantize_dynamic
等API实现低精度转换,适合追求极致部署效率的场景,API使用简单但需注意精度权衡。 - 蒸馏:基于
KLDivLoss
与CrossEntropyLoss
组合,或使用DistilBERT
等预训练蒸馏模型,适合需要保留高精度的小模型场景。 - 剪枝:通过
global_unstructured
等API移除冗余参数,适合对模型大小敏感且可接受一定部署复杂度的场景。
实际应用中,可组合多种技术(如“剪枝+量化”)进一步提升压缩效果,例如先剪枝移除30%冗余参数,再量化为int8,可在精度损失5%以内实现模型体积缩减80%以上,推动大模型在边缘设备的落地。