当前位置: 首页 > news >正文

网站的免费空间是什么深圳宝安网站设计公司

网站的免费空间是什么,深圳宝安网站设计公司,广告制作合同范本免费,wordpress搭建会员胡说八道: 各位观众老爷,大家好,我是诗人啊_,今天和各位分享模型剪枝的相关知识和操作,一文速通~ (屏幕前的你,帅气低调有内涵,美丽大方很优雅… 所以,求个点赞、收藏、关…

胡说八道:

各位观众老爷,大家好,我是诗人啊_,今天和各位分享模型剪枝的相关知识和操作,一文速通~
屏幕前的你,帅气低调有内涵,美丽大方很优雅… 所以,求个点赞、收藏、关注呗~
正经标题模型剪枝理论入门及 PyTorch API 实战
此文讲解 torch.nn.utils.prune 模块的使用,模型剪枝的执行步骤请看 ↓↓↓↓↓

模型剪枝的概念与实践(PyTorch版)

前言

深度神经网络的大型预训练模型往往依赖庞大的参数量实现SOTA效果,但生物神经网络却通过稀疏连接完成复杂任务。模型剪枝正是受此启发,通过将稠密连接转化为稀疏连接,在保持性能的前提下压缩模型,本文基于PyTorch详细介绍模型剪枝的概念与实操。
在这里插入图片描述

一、什么是模型剪枝?

  • 核心思想:仿照生物神经网络的稀疏连接特性,移除冗余参数或结构,实现模型压缩与加速。
  • 本质:将稠密网络转化为稀疏网络,在精度损失可接受的范围内减少参数量和计算量。
  • PyTorch支持:需使用torch.nn.utils.prune模块,要求PyTorch版本≥1.4.0,支持多种剪枝方式:
    • 特定网络模块的剪枝
    • 多参数模块的剪枝
    • 全局剪枝
    • 用户自定义剪枝

在这里插入图片描述

二、剪枝的基本原理(以LeNet为例)

2.1 准备工作

先定义经典LeNet网络作为示例:

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as Fdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 6, 3)  # 输入1通道,输出6通道,3x3卷积核self.conv2 = nn.Conv2d(6, 16, 3)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, int(x.nelement() / x.shape[0]))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xmodel = LeNet().to(device=device)

2.2 剪枝核心机制:掩码(Mask)

剪枝通过掩码张量实现参数筛选,核心逻辑如下:

  1. 原始参数(如weight)被拆分为:
    • weight_orig:保留原始参数值(可训练)
    • weight_mask:掩码张量(0表示剪枝移除,1表示保留)
  2. 实际使用的参数weight = weight_orig * weight_mask(被掩码为0的参数失效)
  3. 剪枝后,weight从可训练参数(Parameter)变为普通属性(Attribute)

2.3 单模块剪枝示例

conv1层的weight参数为例,执行随机非结构化剪枝:

module = model.conv1
# 对conv1的weight参数剪枝30%
prune.random_unstructured(module, name="weight", amount=0.3)
剪枝后参数变化:
  • named_parameters()weight变为weight_orig(保留原始值)
  • named_buffers()中新增weight_mask(掩码张量)
  • module.weightweight_orig * weight_mask的结果(含0值的剪枝后参数)
# 剪枝后参数查看
print("参数列表:", list(module.named_parameters()))  # 含weight_orig、bias
print("掩码列表:", list(module.named_buffers()))      # 含weight_mask
print("剪枝后weight:\n", module.weight)               # 含0值的有效参数

2.4 剪枝永久化(remove操作)

剪枝默认是临时的,执行prune.remove()可将掩码效果永久应用到参数:

# 永久化剪枝(无法撤销)
prune.remove(module, 'weight')
永久化后变化:
  • weight_orig消失,weight恢复为可训练参数(值 = 剪枝后的有效参数)
  • weight_mask被移除(无需保留)

三、常见剪枝方式实战

3.1 特定模块剪枝

针对单个模块的特定参数(如weightbias)剪枝,支持多种策略:

剪枝函数作用适用场景
random_unstructured随机移除单个参数非结构化剪枝(单权重)
l1_unstructured移除L1范数最小的单个参数非结构化剪枝(优先移除小值)
ln_structured移除Lₙ范数最小的结构化单元结构化剪枝(通道/神经元)
示例:对bias参数执行L1剪枝
# 对conv1的bias参数剪枝3个(绝对值最小的3个)
prune.l1_unstructured(module, name="bias", amount=3)
print("剪枝后bias:", module.bias)  # 含0值的剪枝后偏置

3.2 多参数模块剪枝

对模型中多个模块批量剪枝(如所有卷积层/全连接层):

# 对所有卷积层和全连接层分别剪枝
for name, module in model.named_modules():if isinstance(module, nn.Conv2d):# 卷积层:L1非结构化剪枝20%prune.l1_unstructured(module, name="weight", amount=0.2)elif isinstance(module, nn.Linear):# 全连接层:L2结构化剪枝40%prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)
效果:
  • 所有卷积层的weight均被剪枝20%
  • 所有全连接层的weight均被剪枝40%
  • 每个模块独立生成weight_origweight_mask

3.3 全局剪枝(Global Pruning)

局部剪枝(单模块/多模块)要求每层剪枝比例固定,而全局剪枝以整个网络为单位分配剪枝比例(总剪枝量固定,每层比例自适应)。

示例:全局剪枝20%参数
# 定义参与剪枝的模块和参数
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),(model.fc2, 'weight'),(model.fc3, 'weight')
)# 全局剪枝20%(总参数量的20%)
prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=0.2
)
特点:
  • 总剪枝比例固定(如20%),但每层剪枝比例不同
  • 重要性低的层(参数值小)会被剪枝更多
# 查看各层剪枝比例
print("conv1稀疏度:{:.2f}%".format(100 * torch.sum(model.conv1.weight == 0) / model.conv1.weight.nelement()
))
print("全局总稀疏度:{:.2f}%".format(100 * (torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + ...) / (model.conv1.weight.nelement() + model.conv2.weight.nelement() + ...)
))

3.4 用户自定义剪枝

通过继承BasePruningMethod实现自定义剪枝规则,只需重写__init__compute_mask方法。

示例:每隔一个参数剪枝一个(50%比例)
class MyPruningMethod(prune.BasePruningMethod):PRUNING_TYPE = "unstructured"  # 非结构化剪枝(单参数)def compute_mask(self, t, default_mask):mask = default_mask.clone()# 自定义规则:每隔一个参数剪枝一个(索引为偶数的置0)mask.view(-1)[::2] = 0return mask# 封装为剪枝函数
def my_unstructured_pruning(module, name):MyPruningMethod.apply(module, name)return module# 对fc3的bias参数应用自定义剪枝
my_unstructured_pruning(model.fc3, name="bias")
print("自定义剪枝掩码:", model.fc3.bias_mask)  # 0和1交替出现

四、剪枝模型的序列化

剪枝后的模型状态字典(state_dict)会保留:

  • 原始参数:weight_origbias_orig
  • 掩码张量:weight_maskbias_mask
# 剪枝前后状态字典对比
print("剪枝前:", model.state_dict().keys())
# 执行剪枝...
print("剪枝后:", model.state_dict().keys())  # 含orig和mask

总结

  1. 核心逻辑:通过掩码张量筛选参数,实现模型稀疏化
  2. 关键操作:单模块剪枝→多模块批量剪枝→全局剪枝→自定义剪枝
  3. 实用技巧
    • 非结构化剪枝(单权重)适合压缩模型,结构化剪枝(通道/神经元)适合加速推理
    • 剪枝后建议微调模型,恢复精度损失
    • 永久化剪枝(remove)可减小模型存储体积

通过合理的剪枝策略,可在保持模型性能的同时显著降低参数量和计算成本,是模型部署的重要优化手段。

我是诗人啊_程序员,致力于分享人工智能方面的知识,近期 NLP 自然语言处理系列文章发布中,如果感兴趣,来个关注呗~

http://www.dtcms.com/a/547609.html

相关文章:

  • 创新的企业网站建设校园网登录入口
  • 做的网站怎么打开是白板网站开发基于什么平台
  • 网站建设费用用常用的seo工具的是有哪些
  • 免费站长工具网站推广现状
  • 最低成本做企业网站高校对网站建设的重视
  • 网站做优化有用吗wordpress小说主题
  • 毕业设计模板网站上海广告牌制作公司
  • 网站设计 优帮云潍坊网站建设 绮畅
  • 在win10下建设网站做logo宣传语的网站
  • app 网站平台建设实施方案单页面seo优化
  • 网站开发seo电商网站 设计方案
  • 苏州吴江建设局招投标网站昆明做网站做的好的公司有哪些
  • 聚合页做的比较好的教育网站wordpress 笔记本主题
  • 建设银行跨行转账网站做企业网站应该注意什么
  • 网站建设搜索代码杰讯山西网站建设
  • 营销型集团网站上海工商网企业查询网
  • 可以自己做网站做宣传吗长沙优化官网收费标准
  • 展示网站如何做公司营销外包
  • 面料 做网站网络营销组织是什么
  • 盐山网站大连科技网站制作
  • 北海手机网站制作网站建设张家港
  • 苏州网站开发服务对接标准做好门户网站建设
  • 建造电商网站大通网站建设
  • 安全员怎么网站中做备案企业手机网站建设策划书
  • 校园二手信息网站建设wordpress爆破软件
  • 网站建设分金手指专业十成都快速建站公司
  • 汕头做网站优化公司河南app软件开发价位
  • 权威网站有哪些畜牧养殖企业网站源码
  • 资兴网站设计wordpress创建配置文件
  • 石家庄网站设计公司排名莆田做网站建设