当前位置: 首页 > news >正文

模型优化-------模型压缩

模型压缩是一种优化技术,目标是在尽量保留模型性能的前提下,减少模型的体积、计算成本和内存占用。特别适合模型部署在边缘设备、移动端、嵌入式系统等资源受限环境中。
其中,“剪枝(Pruning)、量化(Quantization)和知识蒸馏(Knowledge Distillation)”是最常用且研究最深入的三种方法。


一、剪枝(Pruning)

原理:

剪枝的核心思想是去掉对模型输出影响较小的参数或结构,使得模型更加稀疏或紧凑。

类型:

  1. 非结构化剪枝(Unstructured Pruning):

    • 精确到单个权重级别;

    • 比如将小于某阈值的权重置为 0;

    • 会产生稀疏矩阵,需要专门的硬件或库支持加速。

  2. 结构化剪枝(Structured Pruning):

    • 删除整个神经元、通道(channel)、卷积核或 transformer 中的 attention head;

    • 对硬件更友好,能获得实际的加速效果。

  3. 动态剪枝

    • 运行时根据输入动态决定哪些结构被跳过;

    • 典型例子如动态卷积、Early Exit 网络等。

一般流程:

训练完整模型 → 剪除低贡献参数 → 微调恢复性能


二、量化(Quantization)

原理:

量化是将模型参数和计算从高精度(如 FP32)转换为低精度(如 INT8、FP16),以减少模型大小、内存访问量和计算需求。

类型:

  1. 后训练量化(Post-Training Quantization, PTQ)

    • 不需要重新训练;

    • 简单高效,但可能引入较大精度损失;

    • 如 PyTorch 中 torch.quantization.quantize_dynamic()

  2. 量化感知训练(Quantization-Aware Training, QAT)

    • 在训练中加入量化模拟;

    • 准确率更高,但训练成本略高。

  3. 静态 vs 动态量化

    • 静态量化会量化权重和激活;

    • 动态量化只量化权重,激活在运行时动态计算。

优势:

  • 精度损失可控(尤其是 QAT);

  • 模型大小缩小约 4 倍;

  • 推理速度明显加快(特别在支持低精度硬件上);


三、知识蒸馏(Knowledge Distillation)

原理:

将一个大模型(teacher)的知识迁移给一个小模型(student)。student 模型不仅学习真实标签,还学习 teacher 输出的“软标签”(概率分布)。

优势:

  • 学习到类与类之间的相对关系(从 softmax 输出中学习);

  • student 模型可以比单独训练更轻量、更有效;

  • 常用于压缩 BERT、GPT 等大型模型,如:DistilBERT、TinyBERT、MiniLM。

蒸馏方式:

  • logits 蒸馏:使用 teacher 输出的 softmax 概率作为目标;

  • 中间层蒸馏:对比 student 和 teacher 的隐层输出;

  • 多任务蒸馏:同时结合标签监督和模型蒸馏。


四、其他压缩技术(补充)

1. 参数共享(Weight Sharing)

  • 用哈希函数或其它规则让多个参数共享同一值;

  • 应用于 BERT-of-Theseus、ALBERT 等。

2. 低秩分解(Low-rank Factorization)

  • 将大的矩阵表示分解成多个小矩阵的乘积;

  • 常见如 SVD 分解,用于减少全连接层、注意力矩阵的计算。

3. 神经架构搜索(NAS)

  • 自动搜索性能与效率兼顾的模型结构;

  • 如 MobileNet、EfficientNet、FBNet 等都是通过 NAS 获得的紧凑模型。

4. 混合精度训练(Mixed Precision Training)

  • 训练时同时使用 FP16 和 FP32;

  • 可减少显存占用、提高训练速度,同时保持数值稳定性。


五、总结对比

方法是否需再训练压缩目标精度影响实际部署效率应用场景
剪枝是(微调)参数/结构可控中~高CNN、Transformer
量化选配(PTQ/QAT)位宽/存储可控~小移动端、边缘设备
蒸馏模型结构小~可提高教学模型、小模型
参数共享参数冗余可控多层结构模型
低秩分解大矩阵全连接/注意力模块
NAS模型结构不确定自动模型压缩
http://www.dtcms.com/a/293048.html

相关文章:

  • Python之格式化Conda中生成的requirements.txt
  • timesFM安装记录
  • JavaWeb学习打卡10(HttpServletRequest详解应用、获取参数,请求转发实例)
  • PyTorch常用工具
  • 我的第一个开源项目 -- 实时语音识别工具
  • C++中的list(2)简单复现list中的关键逻辑
  • 水电站自动化升级:Modbus TCP与DeviceNet的跨协议协同应用
  • CMake实践:CMake3.30版本之前和之后链接boost的方式差异
  • 渗透部分总结
  • 从 COLMAP 到 3D Gaussian Splatting
  • vue2的scoped 原理
  • Flex/Bison(腾讯元宝)
  • 开源AI智能客服、AI智能名片与S2B2C商城小程序在客户复购与转介绍中的协同效应研究
  • 禁食时长与关键生物反应的相对强度对照表
  • syscall函数用法
  • Java 中 String 类的常用方法
  • JavaScript的进阶学习--函数和基本对象的解析
  • 16-MSTP
  • 加速度计输出值的正负号与坐标系正方向相反
  • 基于 Agent 的股票分析工具
  • Windows Server 设置MySQL自动备份任务(每日凌晨2点执行)
  • 洛谷刷题7..22
  • 贪心算法Day3学习心得
  • VBScript 拖拽文件显示路径及特殊字符处理
  • gitlab私服搭建
  • 根据数据,判断神经网络所需的最小参数量
  • 如何搭建appium工具环境?
  • 嵌入式学习-土堆目标检测(2)-day26
  • 浏览器解码顺序xss
  • UE5 UI WarpBox 包裹框