
The Three Pillars of Model Compression in Deep Learning: Hands-On Model Pruning with ResNet18 (3)

Introduction

The previous two installments, "The Three Pillars of Model Compression in Deep Learning: Hands-On Model Pruning with ResNet18 (1)" and "(2)", explained the pruning approach and walked through a first implementation, but only a single layer was pruned there. This article records the process of pruning all of the residual layers in ResNet18. As noted before, the first layer is left unpruned; the reasons were covered in the earlier articles and come up again below.


1. ResNet18 Model Structure at a Glance

ResNet18 is a classic lightweight residual network whose core design uses residual blocks (BasicBlock) to mitigate the vanishing-gradient problem in deep networks. The full structure (as adjusted for CIFAR-10) is:

| Layer | Type | Input size | Output size | Key parameters | Role |
|---|---|---|---|---|---|
| conv1 | Conv | 3×32×32 | 64×32×32 | kernel=3, stride=1, pad=1 | initial feature extraction |
| bn1 | BatchNorm | 64×32×32 | 64×32×32 | num_features=64 | normalization, speeds up training |
| relu | Activation | 64×32×32 | 64×32×32 | - | introduces non-linearity |
| maxpool | MaxPool | 64×32×32 | 64×16×16 | kernel=3, stride=2, pad=1 | reduces spatial size |
| layer1 | Residual group (2 BasicBlocks) | 64×16×16 | 64×16×16 | two 3×3 convs per block | shallow feature refinement |
| layer2 | Residual group (2 BasicBlocks) | 64×16×16 | 128×8×8 | first block downsamples with stride=2 | wider features + downsampling |
| layer3 | Residual group (2 BasicBlocks) | 128×8×8 | 256×4×4 | first block downsamples with stride=2 | deeper feature abstraction |
| layer4 | Residual group (2 BasicBlocks) | 256×4×4 | 512×2×2 | first block downsamples with stride=2 | high-level semantic features |
| avgpool | Global average pooling | 512×2×2 | 512×1×1 | - | collapses spatial dims to 1×1 |
| fc | Fully connected | 512 | 10 | in_features=512, out=10 | classification output |

Note: the pruning targets in this article are the residual blocks in layer1-layer4 (8 BasicBlocks in total); the top-level conv1 layer is skipped.
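As a quick cross-check, the torchinfo package (used later in this article) can print a similar per-layer table. A minimal sketch (note: a stock torchvision resnet18 has a 7×7, stride-2 conv1, while the table above describes the CIFAR-10-adjusted variant):

from torchvision.models import resnet18
from torchinfo import summary

model = resnet18()
summary(model, input_size=(1, 3, 32, 32),
        col_names=("input_size", "output_size", "num_params"))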


2. Pruning Strategy: Skip the First Layer, Prune the Residual Blocks

2.1 Why skip the first layer?

ResNet's first convolution (conv1) receives the raw input directly (a 3×32×32 image), and its weights extract basic features such as edges and textures. Pruning this layer risks breaking the feature alignment between the input and the subsequent layers, which can cause a large accuracy drop. The strategy here is therefore: keep the top-level conv1, and prune only the convolution layers inside the later residual blocks.

2.2 Residual-block pruning logic

Each residual block (BasicBlock) contains two 3×3 convolution layers (conv1 and conv2) plus the corresponding bn1 layer. The pruning targets are (see the tensor-shape sketch after this list):

  • prune the output channels of the block's first convolution (conv1) by L1 norm;
  • update the input channels of the second convolution (conv2) in sync so they match conv1's pruned output;
  • adjust bn1's num_features and its running statistics (running_mean / running_var) to the new channel count.
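A minimal pure-tensor sketch of this shape coupling, using hypothetical random weights for a layer1 block (the real implementation follows in Section 3):

import torch

# conv weights are laid out as (out_channels, in_channels, kH, kW)
w1 = torch.randn(64, 64, 3, 3)  # the block's conv1
w2 = torch.randn(64, 64, 3, 3)  # the block's conv2

# keep the conv1 output channels with the largest L1 norms (here: keep 51 of 64)
l1 = w1.abs().sum(dim=(1, 2, 3))
mask = torch.zeros(64, dtype=torch.bool)
mask[l1.topk(51).indices] = True

w1_pruned = w1[mask]     # shape (51, 64, 3, 3): output channels pruned
w2_pruned = w2[:, mask]  # shape (64, 51, 3, 3): conv2's input dim follows
print(w1_pruned.shape, w2_pruned.shape)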

3. Implementation Walkthrough

3.1 Core pruning function: prune_resnet_block

This function prunes a single residual block. The key steps are as follows (code excerpt; the complete version is in the full listing in Section 5):

def prune_resnet_block(block, percent_to_prune):
    # prune the block's first conv layer (block.conv1)
    conv1 = block.conv1
    mask1 = prune_conv_layer(conv1, percent_to_prune)  # boolean mask of channels to keep
    if mask1 is not None:
        # 1. rebuild conv1, keeping only the output channels selected by the mask
        new_conv1 = nn.Conv2d(
            in_channels=conv1.in_channels,
            out_channels=sum(mask1),  # channel count after pruning
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias is not None)
        new_conv1.weight.data = conv1.weight.data[mask1, :, :, :]  # slice weights by mask

        # 2. rebuild conv2, matching its input channels to conv1's pruned output
        conv2 = block.conv2
        new_conv2 = nn.Conv2d(
            in_channels=sum(mask1),  # key: input channels pruned in sync
            out_channels=conv2.out_channels,
            kernel_size=conv2.kernel_size,
            stride=conv2.stride,
            padding=conv2.padding,
            bias=conv2.bias is not None)
        new_conv2.weight.data = conv2.weight.data[:, mask1, :, :]  # slice the input-channel dim

        # 3. rebuild bn1 so num_features matches the pruned channel count
        if hasattr(block, 'bn1'):
            bn1 = block.bn1
            new_bn1 = nn.BatchNorm2d(sum(mask1))
            new_bn1.weight.data = bn1.weight.data[mask1]    # slice the affine weight
            new_bn1.bias.data = bn1.bias.data[mask1]        # the full listing also copies the bias
            new_bn1.running_mean = bn1.running_mean[mask1]  # keep the statistics in sync
            new_bn1.running_var = bn1.running_var[mask1]
            block.bn1 = new_bn1

        # swap the new layers into the block
        block.conv1, block.conv2 = new_conv1, new_conv2
    return mask1
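A quick single-block check might look like this (a sketch; it assumes make_resnet18_cifar10 and the pruning helpers from the full listing in Section 5 are in scope):

model = make_resnet18_cifar10()
block = model.layer1[0]
prune_resnet_block(block, 0.2)
print(block.conv1.out_channels)  # pruned to roughly 80% of 64
print(block.bn1.num_features)    # matches the pruned conv1
print(block.conv2.in_channels)   # matches as well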

Key logic:

  • prune_conv_layer scores each output channel by its L1 norm (np.sum(np.abs(weights), axis=(1, 2, 3))) and keeps the top (1 - percent) fraction of channels (the helper is reproduced below);
  • mask1 is a boolean mask (True = keep), so sum(mask1) is the channel count after pruning;
  • conv2's weights are sliced with [:, mask1, :, :] so its input channels match conv1's pruned output (conv2's own output channels are left untouched, since they must still match the shortcut branch at the residual addition);
  • bn1's num_features, weight, running_mean, etc. are all truncated by mask1 to avoid dimension-mismatch errors (such as the running_mean length mismatch hit in the previous article).
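For reference, here is the prune_conv_layer helper referenced above, as it appears in the full listing in Section 5:

def prune_conv_layer(conv, percent_to_prune):
    weights = conv.weight.data.cpu().numpy()
    # L1 norm per output channel as the importance score
    l1_norm = np.sum(np.abs(weights), axis=(1, 2, 3))
    num_prune = int(percent_to_prune * len(l1_norm))
    if num_prune > 0:
        # keep the channels with the largest L1 norms
        keep_indices = np.argsort(l1_norm)[num_prune:]
        mask = np.zeros(len(l1_norm), dtype=bool)
        mask[keep_indices] = True  # True = keep
        return mask
    return None  # nothing to prune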

3.2 Global pruning control: the prune_model function

This function iterates over all of ResNet18's residual blocks, skipping the top-level conv1 and handling only the BasicBlocks in layer1-layer4:

def prune_model(model, pruning_percent):
    # collect all residual blocks (the top-level conv1 is not a BasicBlock, so it is skipped)
    blocks = []
    for name, module in model.named_modules():
        if isinstance(module, torchvision.models.resnet.BasicBlock):
            blocks.append((name, module))
    # prune each residual block
    for name, block in blocks:
        print(f"Pruning {name}...")
        mask = prune_resnet_block(block, pruning_percent)
    return model

Key point: filtering with isinstance(module, BasicBlock) selects exactly the residual blocks, ensuring that only the target layers are pruned.
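Usage is then a one-liner on a copy of the trained model (a sketch, with model and the helpers above in scope):

import copy

pruned_model = prune_model(copy.deepcopy(model), 0.2)
print_model_shapes(pruned_model)  # top-level conv1/bn1 stay at 64; block convs shrink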


4. Experimental Validation and Results

4.1 Model structure before vs. after pruning

The print_model_shapes function prints the key layer parameters before and after pruning (taking the layer1.0 block as an example):

| Layer | Before pruning | After pruning (20%) | Change |
|---|---|---|---|
| layer1.0.conv1 | in=64, out=64 | in=64, out=51 (64×0.8) | 13 fewer output channels |
| layer1.0.bn1 | num_features=64 | num_features=51 | follows conv1's pruned output |
| layer1.0.conv2 | in=64, out=64 | in=51, out=64 | input channels match conv1's output |
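A quick sanity check (sketch) that the pruned network still runs end to end with consistent shapes:

import torch

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    out = pruned_model.cpu()(x)
print(out.shape)  # expected: torch.Size([1, 10])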

4.2 Parameter count and accuracy

  • Parameter count: the original model has 11,181,642 parameters; pruning brings this down to 8,996,114 (about a 19.5% reduction, per the torchinfo summaries below).
Original model summary:
==========================================================================================
Total params: 11,181,642
Trainable params: 11,181,642
Non-trainable params: 0
Total mult-adds (M): 37.03
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.81
Params size (MB): 44.73
Estimated Total Size (MB): 45.55
==========================================================================================
Pruned model summary:
==========================================================================================
Total params: 8,996,114
Trainable params: 8,996,114
Non-trainable params: 0
Total mult-adds (M): 30.35
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.76
Params size (MB): 35.98
Estimated Total Size (MB): 36.76
==========================================================================================
  • Accuracy: the original model reached 71.92%; after pruning and fine-tuning, the pruned model reached 82.05% (the original was fine-tuned for 20 epochs, the pruned one for 15).
  • Something feels off here: the pruned model ends up more accurate than the original. Is it because of the changed fine-tuning settings in the second run? A plausible factor is that the pruned run simply gives the model 15 extra epochs at a lower learning rate, so the two numbers are not directly comparable. If anyone knows for sure, please let me know!
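The reductions can be verified directly from the two summaries above:

orig_params, pruned_params = 11_181_642, 8_996_114
print(f"params: {100 * (orig_params - pruned_params) / orig_params:.1f}% fewer")   # 19.5%
orig_madds, pruned_madds = 37.03, 30.35
print(f"mult-adds: {100 * (orig_madds - pruned_madds) / orig_madds:.1f}% fewer")   # 18.0%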

5. Summary and Outlook

No summary this time; instead, here is the complete code.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np
from collections import OrderedDict
import copy
from torchinfo import summary


def make_resnet18_cifar10():
    model = resnet18(pretrained=True)
    # adapt the first conv to CIFAR-10's 32x32 images (left disabled here)
    #model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # replace the final fully connected layer with a 10-class head for CIFAR-10
    num_ftrs = model.fc.in_features
    #model.fc = nn.Linear(num_ftrs, 10)
    model.fc = nn.Linear(512, 10)
    return model


def train(model, trainloader, criterion, optimizer, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    train_loss = running_loss / len(trainloader)
    train_acc = 100. * correct / total
    print(f'Train Epoch: {epoch} | Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%')
    return train_loss, train_acc


def test(model, testloader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    test_loss /= len(testloader)
    test_acc = 100. * correct / total
    print(f'Test set: Average loss: {test_loss:.4f} | Acc: {test_acc:.2f}%\n')
    return test_loss, test_acc


def print_model_size(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")


def prune_conv_layer(conv, percent_to_prune):
    weights = conv.weight.data.cpu().numpy()
    # L1 norm per output channel as the importance score (sum over axes (1, 2, 3))
    l1_norm = np.sum(np.abs(weights), axis=(1, 2, 3))
    # number of channels to prune
    num_prune = int(percent_to_prune * len(l1_norm))
    if num_prune > 0:
        print(f"🔍 Pruning {conv} output channels from {conv.out_channels} -> {conv.out_channels - num_prune}")
        # keep the channels with the largest L1 norms, i.e. the top (1 - percent) fraction
        keep_indices = np.argsort(l1_norm)[num_prune:]
        mask = np.zeros(len(l1_norm), dtype=bool)
        mask[keep_indices] = True  # True = keep
        return mask
    return None


def prune_resnet_block(block, percent_to_prune):
    # prune the block's first conv layer
    conv1 = block.conv1
    print(f"Before pruning, conv1 out_channels: {conv1.out_channels}")
    mask1 = prune_conv_layer(conv1, percent_to_prune)
    if mask1 is not None:
        # (print moved inside the guard: sum(None) would raise when nothing is pruned)
        print(f"After pruning, mask1 sum: {sum(mask1)}")
        # rebuild conv1 with the pruned output channels
        new_conv1 = nn.Conv2d(
            in_channels=conv1.in_channels,
            out_channels=sum(mask1),
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias is not None)
        # copy the surviving weights
        with torch.no_grad():
            new_conv1.weight.data = conv1.weight.data[mask1, :, :, :]
            if conv1.bias is not None:
                new_conv1.bias.data = conv1.bias.data[mask1]
        # rebuild conv2 with its input channels pruned in sync
        conv2 = block.conv2
        new_conv2 = nn.Conv2d(
            in_channels=sum(mask1),  # pruned channel count as the new input width
            out_channels=conv2.out_channels,
            kernel_size=conv2.kernel_size,
            stride=conv2.stride,
            padding=conv2.padding,
            bias=conv2.bias is not None)
        with torch.no_grad():
            new_conv2.weight.data = conv2.weight.data[:, mask1, :, :]  # slice the input-channel dim
            if conv2.bias is not None:
                new_conv2.bias.data = conv2.bias.data
        # swap the new layers into the block
        block.conv1 = new_conv1
        block.conv2 = new_conv2
        # rebuild the BatchNorm layer between conv1 and conv2
        if hasattr(block, 'bn1'):
            bn1 = block.bn1
            new_bn1 = nn.BatchNorm2d(sum(mask1))
            with torch.no_grad():
                new_bn1.weight.data = bn1.weight.data[mask1]
                new_bn1.bias.data = bn1.bias.data[mask1]
                new_bn1.running_mean = bn1.running_mean[mask1]
                new_bn1.running_var = bn1.running_var[mask1]
            block.bn1 = new_bn1
        # print the updated channel counts
        print(f"After pruning, new_conv1 out_channels: {new_conv1.out_channels}")
        print(f"After pruning, new_conv2 in_channels: {new_conv2.in_channels}")
        return mask1
    return None


def prune_model(model, pruning_percent):
    # collect all residual blocks (the top-level conv1 is not a BasicBlock)
    blocks = []
    for name, module in model.named_modules():
        if isinstance(module, torchvision.models.resnet.BasicBlock):
            blocks.append((name, module))
    # prune each residual block
    for name, block in blocks:
        print(f"Pruning {name}...")
        mask = prune_resnet_block(block, pruning_percent)
    return model


def fine_tune_model(model, trainloader, testloader, criterion, optimizer, scheduler, epochs):
    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, epoch)
        test_loss, test_acc = test(model, testloader, criterion)
        if test_acc > best_acc:
            best_acc = test_acc
            # note: the same checkpoint path is reused for both training runs
            torch.save(model.state_dict(), 'best_model.pth')
        scheduler.step()
    print(f'Best test accuracy: {best_acc:.2f}%')
    return best_acc


def print_model_shapes(model):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            print(f"{name}: in_channels={module.in_channels}, out_channels={module.out_channels}")
        elif isinstance(module, nn.BatchNorm2d):
            print(f"{name}: num_features={module.num_features}")


if __name__ == "__main__":
    # set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # data preprocessing
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # load the datasets
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # initialize the model
    model = make_resnet18_cifar10()
    model = model.to(device)

    # initial training (fine-tuning the pretrained weights)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    print("Starting initial training (fine-tuning)...")
    best_acc = fine_tune_model(model, trainloader, testloader, criterion, optimizer, scheduler, 20)

    # load the best checkpoint
    model.load_state_dict(torch.load('best_model.pth'))

    # print the original model size
    print("\nOriginal model size:")
    print_model_size(model)
    print("\nOriginal model structure:")
    summary(model, input_size=(1, 3, 32, 32))

    # make a copy of the model for pruning
    pruned_model = copy.deepcopy(model)

    # run pruning
    pruning_percent = 0.2  # uniform pruning ratio
    pruned_model = prune_model(pruned_model, pruning_percent)
    summary(pruned_model, input_size=(1, 3, 32, 32))

    # inspect shapes after pruning
    print("\nPruned model shapes:")
    print_model_shapes(pruned_model)

    # print the pruned model size
    print("\nPruned model size:")
    print_model_size(pruned_model)

    # define a new optimizer (a smaller learning rate is usually needed after pruning)
    optimizer_pruned = optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler_pruned = optim.lr_scheduler.CosineAnnealingLR(optimizer_pruned, T_max=100)
    print("Starting fine-tuning after pruning...")
    best_pruned_acc = fine_tune_model(pruned_model, trainloader, testloader, criterion, optimizer_pruned, scheduler_pruned, 15)

    # compare the original and pruned models
    print("\nResults Comparison:")
    print(f"Original model accuracy: {best_acc:.2f}%")
    print(f"Pruned model accuracy: {best_pruned_acc:.2f}%")
    print(f"Accuracy drop: {best_acc - best_pruned_acc:.2f}%")
