当前位置: 首页 > news >正文

深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(1)

一、背景:为什么需要模型剪枝?

随着深度学习的发展,模型参数量和计算量呈指数级增长。以ResNet18为例,其在ImageNet上的参数量约为1100万,虽然在服务器端运行流畅,但在移动端或嵌入式设备上部署时,内存和计算资源的限制使得直接使用大模型变得困难。模型剪枝(Model Pruning)作为模型压缩的核心技术之一,通过删除冗余的神经元或通道,在保持模型性能的前提下显著降低模型大小和计算量,是解决这一问题的关键手段。
在前面一篇文章我们也提到了模型压缩的一些基本定义和核心原理:《深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏》。

本文将基于PyTorch框架,以ResNet18在CIFAR-10数据集上的分类任务为例,详细讲解结构化通道剪枝的完整实现流程,包括模型训练、剪枝策略、剪枝后结构调整、微调及效果评估。

二、整体流程概览

本文代码的核心流程可总结为以下6步:

  1. 环境初始化与数据集加载
  2. 原始模型训练与评估
  3. 卷积层结构化剪枝(以conv1层为例)
  4. 剪枝后模型结构调整(BN层、残差下采样层等)
  5. 剪枝模型微调
  6. 剪枝前后模型效果对比
    特地说明:在这里选择conv1层作为例子,不是因为选择这个就会效果更好。

三、关键步骤代码解析

3.1 环境初始化与数据集准备

首先需要配置计算设备(GPU/CPU),并加载CIFAR-10数据集。CIFAR-10包含10类32x32的彩色图像,训练集5万张,测试集1万张。

def setup_device():return torch.device("cuda" if torch.cuda.is_available() else "cpu")def load_dataset():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 归一化到[-1,1]])train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)return train_dataset, test_dataset

3.2 原始模型训练

使用预训练的ResNet18模型,修改全连接层输出为10类(匹配CIFAR-10的类别数),并进行5轮训练:

def create_model(device):model = models.resnet18(pretrained=True)  # 加载ImageNet预训练权重model.fc = nn.Linear(512, 10)  # 修改输出层为10类return model.to(device)def train_model(model, train_loader, criterion, optimizer, device, epochs=3):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in tqdm(train_loader):images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")return model

3.3 结构化通道剪枝核心实现

本文重点是对卷积层进行结构化剪枝(按通道剪枝),具体步骤如下:

3.3.1 计算通道重要性

通过计算卷积核的L2范数评估通道重要性。假设卷积层权重维度为[out_channels, in_channels, kernel_h, kernel_w],将每个输出通道的权重展平为一维向量,计算其L2范数,范数越小表示该通道对模型性能贡献越低,越应被剪枝。

layer = dict(model.named_modules())[layer_name]  # 获取目标卷积层
weight = layer.weight.data
channel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)  # 计算每个输出通道的L2范数
3.3.2 生成剪枝掩码

根据剪枝比例(如20%),选择范数最小的通道生成掩码:

num_channels = weight.shape[0]  # 原始输出通道数(如ResNet18的conv1层为64)
num_prune = int(num_channels * amount)  # 需剪枝的通道数(如64*0.2=12)
_, indices = torch.topk(channel_norm, k=num_prune, largest=False)  # 找到最不重要的12个通道mask = torch.ones(num_channels, dtype=torch.bool)
mask[indices] = False  # 掩码:保留的通道标记为True(52个),剪枝的标记为False(12个)
3.3.3 替换卷积层

创建新的卷积层,仅保留掩码为True的通道:

new_conv = nn.Conv2d(in_channels=layer.in_channels,out_channels=num_channels - num_prune,  # 剪枝后输出通道数(52)kernel_size=layer.kernel_size,stride=layer.stride,padding=layer.padding,bias=layer.bias is not None
).to(device)  # 移动到模型所在设备new_conv.weight.data = layer.weight.data[mask]  # 保留掩码为True的通道权重
if layer.bias is not None:new_conv.bias.data = layer.bias.data[mask]  # 偏置同理
3.3.4 关键:剪枝后结构调整

直接剪枝会导致后续层(如BN层、残差连接中的下采样层)的输入/输出通道不匹配,必须同步调整:

(1) 调整BN层
卷积层后通常接BN层,BN的num_features需与卷积输出通道数一致:

if 'conv1' in layer_name:bn1 = model.bn1new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)  # 新BN层通道数52with torch.no_grad():# 同步原始BN层的参数(仅保留未被剪枝的通道)new_bn1.weight.data = bn1.weight.data[mask].clone()new_bn1.bias.data = bn1.bias.data[mask].clone()new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()new_bn1.running_var.data = bn1.running_var.data[mask].clone()model.bn1 = new_bn1

(2) 调整残差下采样层
ResNet的残差块(如layer1.0)中,若主路径的通道数被剪枝,需要通过1x1卷积的下采样层(downsample)匹配 shortcut 的通道数:

block = model.layer1[0]
if not hasattr(block, 'downsample') or block.downsample is None:# 原始无downsample,创建新的1x1卷积+BNdownsample_conv = nn.Conv2d(in_channels=new_conv.out_channels,  # 52(剪枝后的conv1输出)out_channels=block.conv2.out_channels,  # 64(主路径conv2的输出)kernel_size=1,stride=1,bias=False).to(device)torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')  # 初始化权重downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)block.downsample = nn.Sequential(downsample_conv, downsample_bn)  # 添加downsample层
else:# 原有downsample层,调整输入通道downsample_conv = block.downsample[0]downsample_conv.in_channels = new_conv.out_channels  # 输入通道改为52downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 输入通道用掩码筛选

(3) 前向传播验证
调整后需验证模型能否正常前向传播,避免通道不匹配导致的错误:

with torch.no_grad():test_input = torch.randn(1, 3, 32, 32).to(device)  # 测试输入(B, C, H, W)try:model(test_input)print("✅ 前向传播验证通过")except Exception as e:print(f"❌ 验证失败: {str(e)}")raise

3.3的总结,直接上代码

def prune_conv_layer(model, layer_name, amount=0.2):# 获取模型当前所在设备device = next(model.parameters()).device  # 新增:获取设备layer = dict(model.named_modules())[layer_name]weight = layer.weight.datachannel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)num_channels = weight.shape[0]  # 原始通道数(如 64)num_prune = int(num_channels * amount)_, indices = torch.topk(channel_norm, k=num_prune, largest=False)mask = torch.ones(num_channels, dtype=torch.bool)mask[indices] = False  # 生成剪枝掩码(长度 64,52 个 True)new_conv = nn.Conv2d(in_channels=layer.in_channels,out_channels=num_channels - num_prune,  # 剪枝后通道数(如 52)kernel_size=layer.kernel_size,stride=layer.stride,padding=layer.padding,bias=layer.bias is not None)new_conv = new_conv.to(device)  # 新增:移动到模型所在设备new_conv.weight.data = layer.weight.data[mask]  # 保留 mask 为 True 的通道if layer.bias is not None:new_conv.bias.data = layer.bias.data[mask]# 替换原始卷积层parent_name, sep, name = layer_name.rpartition('.')parent = model.get_submodule(parent_name)setattr(parent, name, new_conv)if 'conv1' in layer_name:# 1. 更新与 conv1 直接关联的 BN1 层bn1 = model.bn1new_bn1 = nn.BatchNorm2d(new_conv.out_channels)  # 新 BN 层通道数 52new_bn1 = new_bn1.to(device)  # 新增:移动到模型所在设备with torch.no_grad():new_bn1.weight.data = bn1.weight.data[mask].clone()new_bn1.bias.data = bn1.bias.data[mask].clone()new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()new_bn1.running_var.data = bn1.running_var.data[mask].clone()model.bn1 = new_bn1# 2. 处理残差连接中的 downsample(关键修正:添加缺失的 downsample)block = model.layer1[0]if not hasattr(block, 'downsample') or block.downsample is None:# 原始无 downsample,需创建新的 1x1 卷积+BN 来匹配通道downsample_conv = nn.Conv2d(in_channels=new_conv.out_channels,  # 52out_channels=block.conv2.out_channels,  # 64(主路径输出通道数)kernel_size=1,stride=1,bias=False)downsample_conv = downsample_conv.to(device)  # 新增:移动到模型所在设备# 初始化 1x1 卷积权重(这里简单复制原模型可能的统计量,实际可根据需求调整)torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)downsample_bn = downsample_bn.to(device)  # 新增:移动到模型所在设备with torch.no_grad():# 初始化 BN 参数(可保持默认,或根据原模型统计量调整)downsample_bn.weight.fill_(1.0)downsample_bn.bias.zero_()downsample_bn.running_mean.zero_()downsample_bn.running_var.fill_(1.0)block.downsample = nn.Sequential(downsample_conv, downsample_bn)print("✅ 为 layer1.0 添加新的 downsample 层")else:# 原有 downsample 层,调整输入通道downsample_conv = block.downsample[0]downsample_conv.in_channels = new_conv.out_channels  # 输入通道调整为 52downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 输入通道用 mask 筛选downsample_conv = downsample_conv.to(device)  # 新增:移动到模型所在设备downsample_bn = block.downsample[1]new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)new_downsample_bn = new_downsample_bn.to(device)  # 新增:移动到模型所在设备with torch.no_grad():new_downsample_bn.weight.data = downsample_bn.weight.data.clone()new_downsample_bn.bias.data = downsample_bn.bias.data.clone()new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()block.downsample[1] = new_downsample_bn# 3. 同步 layer1.0.conv1 的输入通道(保持原有逻辑)next_convs = ['layer1.0.conv1']for conv_path in next_convs:try:conv = model.get_submodule(conv_path)if conv.in_channels != new_conv.out_channels:print(f"同步输入通道: {conv.in_channels}{new_conv.out_channels}")conv.in_channels = new_conv.out_channelsconv.weight = nn.Parameter(conv.weight.data[:, mask, :, :].clone())conv = conv.to(device)  # 新增:移动到模型所在设备except AttributeError as e:print(f"⚠️ 卷积层调整失败: {conv_path} ({str(e)})")# 验证前向传播with torch.no_grad():test_input = torch.randn(1, 3, 32, 32).to(device)  # 确保测试输入也在相同设备try:model(test_input)print("✅ 前向传播验证通过")except Exception as e:print(f"❌ 验证失败: {str(e)}")raisereturn model

3.4 剪枝模型微调

剪枝后模型的部分参数被删除,需要通过微调恢复性能。一开始,我们只是在微调时冻结了除 fc 层外的所有参数,但是效果并不好,当然分析原因,除了动了conv1的原因(conv1 是模型的第一个卷积层,负责提取最基础的图像特征(如边缘、纹理、颜色等)。这些底层特征对后续所有层的特征提取至关重要。),最重要的是裁剪后,需要对裁剪的层进行微调,确保参数适应新的特征维度。
微调时冻结了除 fc 层外的所有参数的代码和结果:

for name, param in pruned_model.named_parameters():if 'fc' not in name:param.requires_grad = Falseoptimizer = optim.Adam(pruned_model.fc.parameters(), lr=0.001)print("微调剪枝后的模型")pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device,epochs=5)
原始模型准确率: 80.07%
剪枝后模型准确率: 37.80%

可以看到这个相差很大
本文选择解冻被剪枝的层(如conv1bn1)及相关层(如layer1.0.conv1downsample)进行参数更新:

print("开始微调剪枝后的模型")
for name, param in pruned_model.named_parameters():# 仅解冻与剪枝相关的层if 'conv1' in name or 'bn1' in name or 'layer1.0.conv1' in name or 'layer1.0.downsample' in name or 'fc' in name:param.requires_grad = Trueelse:param.requires_grad = False
optimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.001)
pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=5)
原始模型准确率: 78.94%
剪枝后模型准确率:  81.30%

重新微调了裁剪后的层后,结果有了很大改变。

四、实验结果与分析

通过代码中的evaluate_model函数评估剪枝前后的模型准确率:

def evaluate_model(model, device, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()acc = 100 * correct / totalreturn acc

假设原始模型准确率为88.5%,剪枝20%通道后(模型大小降低约20%),通过微调可恢复至87.2%,验证了剪枝策略的有效性。

五、总结与改进方向

本文实现了基于通道L2范数的结构化剪枝,重点解决了剪枝后模型结构不一致的问题(如BN层、残差下采样层的调整),并通过微调恢复了模型性能。
在这个例子中,仅裁剪 conv1 层的影响
仅裁剪 conv1 层对模型的影响极大,原因如下:

  • 底层特征的重要性 : conv1 输出的是最基础的图像特征,所有后续层的特征均基于此生成。裁剪 conv1 会直接限制后续所有层的特征表达能力。
  • 结构连锁反应 : conv1 的输出通道减少会触发 bn1 、 layer1.0.conv1 、 downsample 等多个模块的调整,任何一个模块的调整失误(如通道数不匹配、参数初始化不当)都会导致整体性能下降。
    实际应用中可从以下方向改进:

模型裁剪通常优先选择 中间层(如ResNet的 layer2 、 layer3 ) ,而非底层或顶层,原因如下:

  • 底层(如 conv1 ) :负责基础特征提取,裁剪后特征损失大,对性能影响显著。
  • 中间层(如 layer2 、 layer3 ) :特征具有一定抽象性但冗余度高(同一层的多个通道可能提取相似特征),裁剪后对性能影响较小。
  • 顶层(如 fc 层) :负责分类决策,参数密度高但冗余度低,裁剪易导致分类能力下降。

相关文章:

  • 使用ArcPy进行栅格数据分析(2)
  • 【时时三省】(C语言基础)多维数组名作函数参数
  • 树莓派超全系列教程文档--(55)如何使用网络文件系统NFS
  • Symbol as Points: Panoptic Symbol Spotting via Point-based Representation
  • 【Redis】Redis 的常见客户端汇总
  • 《Sora模型中Transformer如何颠覆U-Net》
  • SpringBoot3项目架构设计与模块解析
  • 制作官网水平导航栏
  • Grafana-ECharts应用讲解(玫瑰图示例)
  • 计算机组成原理(计算篇)
  • minimatch 详解:功能、语法与应用场景
  • quickbi-突出显示指定行
  • STL——栈和队列和优先队列
  • 【计组】真题 2015 大题
  • SELinux是什么以及如何编写SELinux策略
  • 【YOLO 系列】基于YOLO的飞机表面缺陷智能检测系统【python源码+Pyqt5界面+数据集+训练代码】
  • USB-C/HDMI 2.0 2:1 SW,支持4K60HZ
  • Vue3实现拖拽改变元素大小
  • 2025年ESWA SCI1区TOP,元组引导差分进化算法TLDE+黑箱优化,深度解析+性能实测
  • 蒙特卡罗模拟: 高级应用的思路和实例
  • 东莞最新疫情地区/聊城seo
  • 360怎么免费建网站/荨麻疹怎么治疗能除根
  • 做网站产品资料表格/友情链接交换形式有哪些
  • 株洲网站建设/文明seo技术教程网
  • 常州网站建设公司平台/今日热点
  • 网站建设的一些销售技巧/软件外包公司排名