当前位置: 首页 > news >正文

AI、人工智能基础: 模型剪枝的概念与实践(PyTorch版)

胡说八道:

各位观众老爷,大家好,我是诗人啊_,今天和各位分享模型剪枝的相关知识和操作,一文速通~
屏幕前的你,帅气低调有内涵,美丽大方很优雅… 所以,求个点赞、收藏、关注呗~
正经标题模型剪枝理论入门及 PyTorch API 实战
此文讲解 torch.nn.utils.prune 模块的使用,模型剪枝的执行步骤请看 ↓↓↓↓↓

模型剪枝的概念与实践(PyTorch版)

前言

深度神经网络的大型预训练模型往往依赖庞大的参数量实现SOTA效果,但生物神经网络却通过稀疏连接完成复杂任务。模型剪枝正是受此启发,通过将稠密连接转化为稀疏连接,在保持性能的前提下压缩模型,本文基于PyTorch详细介绍模型剪枝的概念与实操。
在这里插入图片描述

一、什么是模型剪枝?

  • 核心思想:仿照生物神经网络的稀疏连接特性,移除冗余参数或结构,实现模型压缩与加速。
  • 本质:将稠密网络转化为稀疏网络,在精度损失可接受的范围内减少参数量和计算量。
  • PyTorch支持:需使用torch.nn.utils.prune模块,要求PyTorch版本≥1.4.0,支持多种剪枝方式:
    • 特定网络模块的剪枝
    • 多参数模块的剪枝
    • 全局剪枝
    • 用户自定义剪枝

在这里插入图片描述

二、剪枝的基本原理(以LeNet为例)

2.1 准备工作

先定义经典LeNet网络作为示例:

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as Fdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, 3)  # 输入1通道,输出6通道,3x3卷积核self.conv2 = nn.Conv2d(6, 16, 3)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, int(x.nelement() / x.shape[0]))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xmodel = LeNet().to(device=device)

2.2 剪枝核心机制:掩码(Mask)

剪枝通过掩码张量实现参数筛选,核心逻辑如下:

  1. 原始参数(如weight)被拆分为:
    • weight_orig:保留原始参数值(可训练)
    • weight_mask:掩码张量(0表示剪枝移除,1表示保留)
  2. 实际使用的参数weight = weight_orig * weight_mask(被掩码为0的参数失效)
  3. 剪枝后,weight从可训练参数(Parameter)变为普通属性(Attribute)

2.3 单模块剪枝示例

conv1层的weight参数为例,执行随机非结构化剪枝:

module = model.conv1
# 对conv1的weight参数剪枝30%
prune.random_unstructured(module, name="weight", amount=0.3)
剪枝后参数变化:
  • named_parameters()weight变为weight_orig(保留原始值)
  • named_buffers()中新增weight_mask(掩码张量)
  • module.weightweight_orig * weight_mask的结果(含0值的剪枝后参数)
# 剪枝后参数查看
print("参数列表:", list(module.named_parameters()))  # 含weight_orig、bias
print("掩码列表:", list(module.named_buffers()))      # 含weight_mask
print("剪枝后weight:\n", module.weight)               # 含0值的有效参数

2.4 剪枝永久化(remove操作)

剪枝默认是临时的,执行prune.remove()可将掩码效果永久应用到参数:

# 永久化剪枝(无法撤销)
prune.remove(module, 'weight')
永久化后变化:
  • weight_orig消失,weight恢复为可训练参数(值 = 剪枝后的有效参数)
  • weight_mask被移除(无需保留)

三、常见剪枝方式实战

3.1 特定模块剪枝

针对单个模块的特定参数(如weightbias)剪枝,支持多种策略:

剪枝函数作用适用场景
random_unstructured随机移除单个参数非结构化剪枝(单权重)
l1_unstructured移除L1范数最小的单个参数非结构化剪枝(优先移除小值)
ln_structured移除Lₙ范数最小的结构化单元结构化剪枝(通道/神经元)
示例:对bias参数执行L1剪枝
# 对conv1的bias参数剪枝3个(绝对值最小的3个)
prune.l1_unstructured(module, name="bias", amount=3)
print("剪枝后bias:", module.bias)  # 含0值的剪枝后偏置

3.2 多参数模块剪枝

对模型中多个模块批量剪枝(如所有卷积层/全连接层):

# 对所有卷积层和全连接层分别剪枝
for name, module in model.named_modules():if isinstance(module, nn.Conv2d):# 卷积层:L1非结构化剪枝20%prune.l1_unstructured(module, name="weight", amount=0.2)elif isinstance(module, nn.Linear):# 全连接层:L2结构化剪枝40%prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)
效果:
  • 所有卷积层的weight均被剪枝20%
  • 所有全连接层的weight均被剪枝40%
  • 每个模块独立生成weight_origweight_mask

3.3 全局剪枝(Global Pruning)

局部剪枝(单模块/多模块)要求每层剪枝比例固定,而全局剪枝以整个网络为单位分配剪枝比例(总剪枝量固定,每层比例自适应)。

示例:全局剪枝20%参数
# 定义参与剪枝的模块和参数
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),(model.fc2, 'weight'),(model.fc3, 'weight')
)# 全局剪枝20%(总参数量的20%)
prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=0.2
)
特点:
  • 总剪枝比例固定(如20%),但每层剪枝比例不同
  • 重要性低的层(参数值小)会被剪枝更多
# 查看各层剪枝比例
print("conv1稀疏度:{:.2f}%".format(100 * torch.sum(model.conv1.weight == 0) / model.conv1.weight.nelement()
))
print("全局总稀疏度:{:.2f}%".format(100 * (torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + ...) / (model.conv1.weight.nelement() + model.conv2.weight.nelement() + ...)
))

3.4 用户自定义剪枝

通过继承BasePruningMethod实现自定义剪枝规则,只需重写__init__compute_mask方法。

示例:每隔一个参数剪枝一个(50%比例)
class MyPruningMethod(prune.BasePruningMethod):PRUNING_TYPE = "unstructured"  # 非结构化剪枝(单参数)def compute_mask(self, t, default_mask):mask = default_mask.clone()# 自定义规则:每隔一个参数剪枝一个(索引为偶数的置0)mask.view(-1)[::2] = 0return mask# 封装为剪枝函数
def my_unstructured_pruning(module, name):MyPruningMethod.apply(module, name)return module# 对fc3的bias参数应用自定义剪枝
my_unstructured_pruning(model.fc3, name="bias")
print("自定义剪枝掩码:", model.fc3.bias_mask)  # 0和1交替出现

四、剪枝模型的序列化

剪枝后的模型状态字典(state_dict)会保留:

  • 原始参数:weight_origbias_orig
  • 掩码张量:weight_maskbias_mask
# 剪枝前后状态字典对比
print("剪枝前:", model.state_dict().keys())
# 执行剪枝...
print("剪枝后:", model.state_dict().keys())  # 含orig和mask

总结

  1. 核心逻辑:通过掩码张量筛选参数,实现模型稀疏化
  2. 关键操作:单模块剪枝→多模块批量剪枝→全局剪枝→自定义剪枝
  3. 实用技巧
    • 非结构化剪枝(单权重)适合压缩模型,结构化剪枝(通道/神经元)适合加速推理
    • 剪枝后建议微调模型,恢复精度损失
    • 永久化剪枝(remove)可减小模型存储体积

通过合理的剪枝策略,可在保持模型性能的同时显著降低参数量和计算成本,是模型部署的重要优化手段。

我是诗人啊_程序员,致力于分享人工智能方面的知识,近期 NLP 自然语言处理系列文章发布中,如果感兴趣,来个关注呗~

http://www.dtcms.com/a/354276.html

相关文章:

  • uvloop深度实践:从原理到高性能异步应用实战
  • 死锁产生的条件是什么? 如何进行死锁诊断?
  • 本地部署DeepSeek大模型的基本方法
  • 自定义命令行补全机制的核心工具之compgen
  • web服务组件
  • MII的原理
  • 软件设计师备考-(三)操作系统基本原理
  • leetcode28. 找出字符串中第一个匹配项的下标
  • VR党建工作站-红色教育基地
  • 路由基础(三):静态路由、动态路由、默认路由
  • Linux系统 -- 线程(pthread)核心知识整理
  • 【golang长途旅行第33站】常量------补充知识点
  • 学习游戏制作记录(数据加密以及主菜单和画面优化)8.27
  • 运算电源抑制比(PSRR)测量及设计注意事项
  • 去哪里学AI?2025年AI培训机构推荐!
  • 部署k8s-efk日志收集服务(小白的“升级打怪”成长之路)
  • 数据库:缓冲池和磁盘I/O
  • 让组件“活”起来:使用 `useState` Hook 管理组件状态
  • 【苍穹外卖项目】Day12
  • Android中的SELinux
  • vue3 字符 居中显示
  • HyperMesh许可证过期?
  • 北京国标:专业高效的数据采集和分析服务
  • 【深入理解 Linux 网络】配置调优与性能优化
  • 官宣,2026第二届郑州国际台球产业展览会,展位开启招商
  • 解决网站图片加载慢:从架构原理到实践
  • Ubuntu系统中查看内存、CPU、GPU的使用情况以及它们之间的连接情况
  • TypeScript实战:轻松实现数字序号转中文大写数字
  • 什么是宏观和微观仿真
  • Wed 自动化测试常用函数实践(二)