当前位置: 首页 > news >正文

人工智能概念:常用的模型压缩技术(剪枝、量化、知识蒸馏)

文章目录

    • 一、模型压缩概述
      • 1.1 什么是模型压缩?
      • 1.2 为什么需要模型压缩?
      • 1.3 四种主流模型压缩技术
    • 二、模型量化:用低精度换高效能
      • 2.1 量化的数学原理
      • 2.2 量化计算示例
      • 2.3 量化相关API详解
    • 三、知识蒸馏:让小模型学会大模型的“智慧”
      • 3.1 知识蒸馏的数学原理
      • 3.2 蒸馏计算示例
      • 3.3 知识蒸馏相关API详解
    • 四、模型剪枝:移除冗余参数,保留核心能力
      • 4.1 剪枝的数学原理
      • 4.2 剪枝计算示例
      • 4.3 模型剪枝相关API详解
      • 4.4 剪枝注意事项
    • 五、总结


一、模型压缩概述

1.1 什么是模型压缩?

模型压缩是一类通过减少模型参数数量、降低计算复杂度,从而在资源受限设备上高效部署深度学习模型的技术。其核心目标是在模型性能损失最小化的前提下,显著减小模型体积、降低内存占用、提升推理速度,以适应移动端、嵌入式设备等资源受限场景的需求。

1.2 为什么需要模型压缩?

随着Transformer等大模型的兴起,模型参数规模呈指数级增长。例如,原始BERT-base模型参数量约110M,推理时不仅占用大量内存,还需要较高的计算资源,难以直接部署在手机、摄像头等边缘设备上。此外,大模型的高推理延迟也无法满足实时性要求较高的业务场景(如实时推荐、语音助手)。

模型压缩的意义在于:

  • 降低存储成本:减小模型文件大小,节省存储空间。
  • 提升推理速度:减少计算量,降低延迟,满足实时性需求。
  • 降低部署门槛:使模型能够在算力有限的边缘设备上运行。
  • 减少能耗:降低推理过程中的能量消耗,适合移动设备。

1.3 四种主流模型压缩技术

目前,业界常用的模型压缩技术主要有四类:

技术名称核心思想特点
剪枝(Pruning)移除模型中冗余的参数(如权重值接近0的连接),保留关键参数。可分为结构化剪枝(移除整个通道/层)和非结构化剪枝(移除单个权重)。
量化(Quantization)用低精度数据类型(如int8)替代高精度类型(如float32)表示权重和激活值。模型体积缩小4-8倍,推理速度提升2-4倍,实现简单。
知识蒸馏(Knowledge Distillation)让小模型(学生)学习大模型(教师)的“知识”(如软标签),模仿其行为。保留大模型性能的同时,显著减小模型规模,适用于复杂模型压缩。
低秩因式分解(Low-rank Factorization)将高维权重矩阵分解为多个低维矩阵的乘积,减少参数数量。适合线性层、卷积层等矩阵运算密集的模块,压缩率较高但实现较复杂。

二、模型量化:用低精度换高效能

在这里插入图片描述

2.1 量化的数学原理

量化的核心是将高精度浮点数(如float32)映射到低精度整数(如int8),核心公式涉及缩放因子(scale)零点(zero_point) 的计算。

  1. 基本定义

    • 设浮点数范围为 [xmin,xmax][x_{\text{min}}, x_{\text{max}}][xmin,xmax],对应整数范围为 [qmin,qmax][q_{\text{min}}, q_{\text{max}}][qmin,qmax](如int8的[−128,127][-128, 127][128,127])。
    • 缩放因子 sss:控制浮点数到整数的比例映射。
    • 零点 zzz:确保映射的偏移量(使0附近的浮点数能准确映射)。
  2. 核心公式
    s=xmax−xminqmax−qmin(1)s = \frac{x_{\text{max}} - x_{\text{min}}}{q_{\text{max}} - q_{\text{min}}} \tag{1} s=qmaxqminxmaxxmin(1)
    z=qmin−round(xmins)(2)z = q_{\text{min}} - \text{round}\left(\frac{x_{\text{min}}}{s}\right) \tag{2} z=qminround(sxmin)(2)
    q=clip(round(xs+z),qmin,qmax)(3)q = \text{clip}\left(\text{round}\left(\frac{x}{s} + z\right), q_{\text{min}}, q_{\text{max}}\right) \tag{3} q=clip(round(sx+z),qmin,qmax)(3)

    • 式(1):计算缩放因子,将浮点数范围映射到整数范围。
    • 式(2):计算零点,确保xminx_{\text{min}}xmin能映射到qminq_{\text{min}}qmin
    • 式(3):将浮点数xxx量化为整数qqq,并裁剪到整数范围内。
  3. 反量化公式(推理时还原)
    xrecon=s⋅(q−z)(4)x_{\text{recon}} = s \cdot (q - z) \tag{4} xrecon=s(qz)(4)

2.2 量化计算示例

以float32到int8的量化为例,假设某层权重的浮点数范围为 [−1.2,3.6][-1.2, 3.6][1.2,3.6],计算过程如下:

步骤1:确定范围

  • 浮点数:xmin=−1.2x_{\text{min}} = -1.2xmin=1.2xmax=3.6x_{\text{max}} = 3.6xmax=3.6
  • int8整数:qmin=−128q_{\text{min}} = -128qmin=128qmax=127q_{\text{max}} = 127qmax=127,范围长度 127−(−128)=255127 - (-128) = 255127(128)=255

步骤2:计算缩放因子sss
s=3.6−(−1.2)255=4.8255≈0.0188s = \frac{3.6 - (-1.2)}{255} = \frac{4.8}{255} \approx 0.0188 s=2553.6(1.2)=2554.80.0188

步骤3:计算零点zzz
z=−128−round(−1.20.0188)=−128−round(−63.83)=−128+64=−64z = -128 - \text{round}\left(\frac{-1.2}{0.0188}\right) = -128 - \text{round}(-63.83) = -128 + 64 = -64 z=128round(0.01881.2)=128round(63.83)=128+64=64

步骤4:量化单个浮点数
例如量化 x=0.5x = 0.5x=0.5
q=round(0.50.0188+(−64))=round(26.59−64)=round(−37.41)=−37q = \text{round}\left(\frac{0.5}{0.0188} + (-64)\right) = \text{round}(26.59 - 64) = \text{round}(-37.41) = -37 q=round(0.01880.5+(64))=round(26.5964)=round(37.41)=37

  • 量化结果:q=−37q = -37q=37(在int8范围内)。

步骤5:反量化验证
xrecon=0.0188×(−37−(−64))=0.0188×27≈0.5076≈0.5x_{\text{recon}} = 0.0188 \times (-37 - (-64)) = 0.0188 \times 27 \approx 0.5076 \approx 0.5 xrecon=0.0188×(37(64))=0.0188×270.50760.5

  • 误差:∣0.5076−0.5∣=0.0076|0.5076 - 0.5| = 0.0076∣0.50760.5∣=0.0076,精度损失较小。

2.3 量化相关API详解

  1. PyTorch量化API
API名称功能描述关键参数说明适用场景
torch.quantization.quantize_dynamic动态量化模型,推理时实时计算scale和zero_point- model:待量化模型
- qconfig_spec:指定需量化的层类型(如{torch.nn.Linear}
- dtype:目标数据类型(如torch.qint8
快速部署、动态输入场景
torch.quantization.prepare为静态量化准备模型(插入观测器)- model:待准备模型
- qconfig:量化配置(如torch.quantization.get_default_qconfig('fbgemm')
静态量化(需校准数据)
torch.quantization.convert将准备好的模型转换为量化模型- model:经prepare处理的模型静态量化(精度更高)
torch.quantization.QConfig定义量化配置(如激活和权重的量化方式)- activation:激活量化方式(如FakeQuantize.with_args(observer=MovingAverageMinMaxObserver)
- weight:权重量化方式
自定义量化策略
  1. TensorFlow量化API
API名称功能描述关键参数说明
tf.quantization.quantize对张量进行量化(支持动态范围量化)- input:待量化张量
- min_range/max_range:输入范围
- T:目标类型(如tf.int8
tf.keras.layers.experimental.QuantizationAwareTraining量化感知训练(模拟量化过程,提升量化后精度)- input_shape:输入形状
- num_bits:量化位数
  1. ONNX Runtime量化API
API名称功能描述关键参数说明
onnxruntime.quantization.quantize_dynamic动态量化ONNX模型- input_model:输入ONNX模型路径
- output_model:输出量化模型路径
- op_types_to_quantize:需量化的算子类型(如['MatMul', 'Add']
  1. 量化注意事项
    • 动态量化适合CPU端部署,GPU量化建议使用TensorRT的INT8校准工具。
    • 量化对模型精度的影响与任务相关:图像分类通常比目标检测更耐量化,文本分类比NER更耐量化。
    • 混合精度量化(如部分层用float16,部分用int8)可在精度和速度间取得更好平衡。

三、知识蒸馏:让小模型学会大模型的“智慧”

在这里插入图片描述

3.1 知识蒸馏的数学原理

知识蒸馏的核心是通过KL散度衡量学生模型与教师模型的输出差异,结合硬标签损失优化学生模型。

  1. 软标签生成
    教师模型的logits经过温度TTT调整后生成软标签:
    pi=exp⁡(zi/T)∑jexp⁡(zj/T)(5)p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \tag{5} pi=jexp(zj/T)exp(zi/T)(5)

    • ziz_izi:教师模型对第iii类的logits输出。
    • TTT:温度参数(T>1T>1T>1使分布更平滑,保留更多知识)。
  2. KL散度损失(软标签损失)
    衡量学生软标签qqq与教师软标签ppp的差异:
    LKL=∑ipilog⁡(piqi)(6)L_{\text{KL}} = \sum_i p_i \log\left(\frac{p_i}{q_i}\right) \tag{6} LKL=ipilog(qipi)(6)

    • T=1T=1T=1时,KL散度退化为交叉熵损失。
  3. 总损失函数
    Ltotal=α⋅LKL+(1−α)⋅LCE(7)L_{\text{total}} = \alpha \cdot L_{\text{KL}} + (1-\alpha) \cdot L_{\text{CE}} \tag{7} Ltotal=αLKL+(1α)LCE(7)

    • LCEL_{\text{CE}}LCE:学生模型与真实标签的交叉熵损失(硬标签损失)。
    • α\alphaα:软标签损失的权重(通常取0.5-0.9)。

3.2 蒸馏计算示例

以三分类任务为例,演示损失计算过程:

步骤1:模型输出

  • 教师模型logits:zteacher=[3.0,1.0,0.2]z_{\text{teacher}} = [3.0, 1.0, 0.2]zteacher=[3.0,1.0,0.2]
  • 学生模型logits:zstudent=[2.5,0.8,0.1]z_{\text{student}} = [2.5, 0.8, 0.1]zstudent=[2.5,0.8,0.1]
  • 真实标签:y=[1,0,0]y = [1, 0, 0]y=[1,0,0](第0类)

步骤2:生成软标签(T=2.0T=2.0T=2.0

  • 教师软标签:
    p=[exp⁡(3/2)∑,exp⁡(1/2)∑,exp⁡(0.2/2)∑]≈[0.721,0.215,0.064]p = \left[ \frac{\exp(3/2)}{\sum}, \frac{\exp(1/2)}{\sum}, \frac{\exp(0.2/2)}{\sum} \right] \approx [0.721, 0.215, 0.064] p=[exp(3/2),exp(1/2),exp(0.2/2)][0.721,0.215,0.064]
  • 学生软标签:
    q=[exp⁡(2.5/2)∑,exp⁡(0.8/2)∑,exp⁡(0.1/2)∑]≈[0.659,0.257,0.084]q = \left[ \frac{\exp(2.5/2)}{\sum}, \frac{\exp(0.8/2)}{\sum}, \frac{\exp(0.1/2)}{\sum} \right] \approx [0.659, 0.257, 0.084] q=[exp(2.5/2),exp(0.8/2),exp(0.1/2)][0.659,0.257,0.084]

步骤3:计算损失

  • KL散度损失(LKLL_{\text{KL}}LKL):
    LKL=0.721log⁡(0.721/0.659)+0.215log⁡(0.215/0.257)+0.064log⁡(0.064/0.084)≈0.018L_{\text{KL}} = 0.721\log(0.721/0.659) + 0.215\log(0.215/0.257) + 0.064\log(0.064/0.084) \approx 0.018 LKL=0.721log(0.721/0.659)+0.215log(0.215/0.257)+0.064log(0.064/0.084)0.018
  • 硬标签损失(LCEL_{\text{CE}}LCE):
    LCE=−log⁡(q0)≈−log⁡(0.659)≈0.418L_{\text{CE}} = -\log(q_0) \approx -\log(0.659) \approx 0.418 LCE=log(q0)log(0.659)0.418
  • 总损失(α=0.7\alpha=0.7α=0.7):
    Ltotal=0.7×0.018+0.3×0.418≈0.138L_{\text{total}} = 0.7 \times 0.018 + 0.3 \times 0.418 \approx 0.138 Ltotal=0.7×0.018+0.3×0.4180.138

3.3 知识蒸馏相关API详解

  1. Hugging Face Transformers API
API名称/工具功能描述关键参数说明
transformers.Trainer通过自定义损失函数实现蒸馏(教师模型固定,学生模型训练)- model:学生模型
- args:训练参数(如TrainingArguments
- compute_loss:自定义损失函数(融合KL散度和交叉熵)
transformers.DistilBertForSequenceClassification预训练蒸馏模型(如DistilBERT,学生模型)- 继承自PreTrainedModel,可直接加载预训练权重(如distilbert-base-uncased
  1. PyTorch蒸馏工具
API名称功能描述关键参数说明
torch.nn.KLDivLoss计算KL散度损失(软标签损失)- reduction:损失聚合方式(如'batchmean'
- log_target:是否目标为对数形式
torch.nn.CrossEntropyLoss计算交叉熵损失(硬标签损失)- weight:类别权重
- reduction:损失聚合方式
  1. 专用蒸馏库
库名称功能描述核心API示例
HuggingFace/transformers中的蒸馏工具提供DistilBERT、DistilRoBERTa等蒸馏模型的训练逻辑from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
knowledge-distillation-pytorch轻量级蒸馏库(支持多种蒸馏策略)from kd import KnowledgeDistillationLoss(融合KL散度和硬标签损失)

四、模型剪枝:移除冗余参数,保留核心能力

在这里插入图片描述

4.1 剪枝的数学原理

剪枝通过评估参数重要性移除冗余权重,常用L1范数衡量重要性(值越小越冗余)。

  1. L1范数重要性评估
    对于权重矩阵W∈Rm×nW \in \mathbb{R}^{m \times n}WRm×n,单个权重wijw_{ij}wij的重要性为:
    I(wij)=∣wij∣(8)\mathcal{I}(w_{ij}) = |w_{ij}| \tag{8} I(wij)=wij(8)

  2. 全局剪枝阈值计算
    若剪枝比例为rrr,则阈值θ\thetaθ满足:
    ∑i,jI(∣wij∣<θ)总参数数=r(9)\frac{\sum_{i,j} \mathbb{I}(|w_{ij}| < \theta)}{\text{总参数数}} = r \tag{9} 总参数数i,jI(wij<θ)=r(9)

    • I(⋅)\mathbb{I}(\cdot)I()为指示函数,满足条件时取1。

4.2 剪枝计算示例

以3x3权重矩阵为例,剪枝30%的参数:

步骤1:原始权重矩阵
W=[0.1−0.020.05−0.30.010.20.03−0.040.08]W = \begin{bmatrix} 0.1 & -0.02 & 0.05 \\ -0.3 & 0.01 & 0.2 \\ 0.03 & -0.04 & 0.08 \end{bmatrix} W=0.10.30.030.020.010.040.050.20.08

步骤2:计算L1范数(重要性)
∣I(W)∣=[0.10.020.050.30.010.20.030.040.08]|\mathcal{I}(W)| = \begin{bmatrix} 0.1 & 0.02 & 0.05 \\ 0.3 & 0.01 & 0.2 \\ 0.03 & 0.04 & 0.08 \end{bmatrix} I(W)=0.10.30.030.020.010.040.050.20.08

步骤3:排序并确定阈值
将所有权重按L1范数升序排列:0.01,0.02,0.03,0.04,0.05,0.08,0.1,0.2,0.30.01, 0.02, 0.03, 0.04, 0.05, 0.08, 0.1, 0.2, 0.30.01,0.02,0.03,0.04,0.05,0.08,0.1,0.2,0.3
总参数9个,剪枝30%即移除3个参数,阈值θ=0.03\theta=0.03θ=0.03(第3小的值)。

步骤4:剪枝后矩阵(小于θ\thetaθ的权重置0)
Wpruned=[0.100.05−0.300.20−0.040.08]W_{\text{pruned}} = \begin{bmatrix} 0.1 & 0 & 0.05 \\ -0.3 & 0 & 0.2 \\ 0 & -0.04 & 0.08 \end{bmatrix} Wpruned=0.10.30000.040.050.20.08

  • 稀疏度:3/9=33.3%(接近目标30%)。

4.3 模型剪枝相关API详解

  1. PyTorch剪枝API
API名称功能描述关键参数说明
torch.nn.utils.prune.l1_unstructured对单个模块进行L1非结构化剪枝- module:待剪枝模块(如model.bert.encoder.layer[0].attention.self.query
- name:待剪枝参数名(如'weight'
- amount:剪枝比例(如0.3
torch.nn.utils.prune.global_unstructured对多个模块进行全局非结构化剪枝(统一阈值)- parameters:待剪枝参数列表(如[(module, 'weight')]
- pruning_method:剪枝方法(如prune.L1Unstructured
- amount:剪枝比例
torch.nn.utils.prune.remove永久移除剪枝掩码(将0值权重保留在参数中)- module:已剪枝模块
- name:剪枝参数名
torch.nn.utils.prune.ln_structured对模块进行结构化剪枝(如按通道剪枝)- n:剪枝维度(如0表示按输出通道剪枝)
- amount:剪枝比例
- pruning_method:重要性评估方法(如'l1_unstructured'
  1. TensorFlow剪枝API
API名称功能描述关键参数说明
tfmot.sparsity.keras.prune_low_magnitude对Keras模型进行 magnitude-based 剪枝- model:待剪枝模型
- pruning_schedule:剪枝调度(如PolynomialDecay
tfmot.sparsity.keras.PolynomialDecay定义剪枝比例随训练步数的变化策略- initial_sparsity:初始稀疏度
- final_sparsity:目标稀疏度
- num_steps:总步数
  1. 第三方剪枝工具
工具名称功能描述核心特点
TorchPrune支持PyTorch模型的结构化和非结构化剪枝提供剪枝后模型微调工具,支持可视化剪枝效果
PruneTorch轻量级剪枝库(支持Transformer、ResNet等主流模型)实现简单,适合快速验证剪枝效果

4.4 剪枝注意事项

  • 非结构化剪枝生成稀疏矩阵,需硬件支持(如NVIDIA的Sparse Tensor Core)才能加速,否则可能变慢。
  • 结构化剪枝(如按通道)生成密集矩阵,无需特殊硬件,但剪枝比例过高会导致精度大幅下降。
  • 剪枝后需微调模型(fine-tuning),恢复因剪枝丢失的性能(通常微调3-5个epoch即可)。

五、总结

模型压缩技术通过数学原理与工程实现的结合,在精度与效率间取得平衡,其核心API为工业界部署提供了便捷工具:

  • 量化:通过torch.quantization.quantize_dynamic等API实现低精度转换,适合追求极致部署效率的场景,API使用简单但需注意精度权衡。
  • 蒸馏:基于KLDivLossCrossEntropyLoss组合,或使用DistilBERT等预训练蒸馏模型,适合需要保留高精度的小模型场景。
  • 剪枝:通过global_unstructured等API移除冗余参数,适合对模型大小敏感且可接受一定部署复杂度的场景。

实际应用中,可组合多种技术(如“剪枝+量化”)进一步提升压缩效果,例如先剪枝移除30%冗余参数,再量化为int8,可在精度损失5%以内实现模型体积缩减80%以上,推动大模型在边缘设备的落地。

http://www.dtcms.com/a/296813.html

相关文章:

  • 一篇文章了解HashMap和ConcurrentHashMap的扩容机制
  • ESP32入门实战:PC远程控制LED灯完整指南
  • pandas库的数据导入导出,缺失值,重复值处理和数据筛选,matplotlib库 简单图绘制
  • AD一张原理图分成多张原理图
  • iview Select的Option边框显示不全(DatePicker也会出现此类问题)
  • rust-参考与借用
  • 爬虫逆向--Day12--DrissionPage案例分析【小某书评价数据某东评价数据】
  • MySQL零基础教程增删改查实战
  • java后端
  • mujoco playground
  • DBA常用数据库查询语句
  • DevOps 完整实现指南:从理论到实践
  • 论文阅读:《Many-Objective Evolutionary Algorithms: A Survey. 》多目标优化问题的优化目标评估的相关内容介绍
  • Android LiveData 全面解析:原理、使用与最佳实践
  • Rust生态中的LLM实践全解析
  • 【C# 找最大值、最小值和平均值及大于个数和值】2022-9-23
  • 项目质量如何提升?
  • 教育培训系统源码如何赋能企业培训学习?功能设计与私有化部署实战
  • 使用 Vue 实现移动端视频录制与自动截图功能
  • MySQL---索引、事务
  • Docker 打包Vue3项目镜像
  • 互联网广告中的Header Bidding与瀑布流的解析与比较
  • 性能测试-groovy语言1
  • 使用 LLaMA 3 8B 微调一个 Reward Model:从入门到实践
  • 【硬件-笔试面试题】硬件/电子工程师,笔试面试题-19,(知识点:PCB布局布线的设计要点)
  • 类和包的可见性
  • 勾芡 3 步诀:家庭挂汁不翻车
  • Spring Data JPA 中的一个注解NoRepositoryBean
  • Edwards爱德华干泵报警信息表适用于iXH, iXL, iXS, iHand pXH
  • 机器学习的基础知识