当前位置: 首页 > news >正文

yolov8通道级剪枝讲解(超详细思考版)

为了提升推理速度并降低部署成本,模型剪枝已成为关键技术。本文将结合实践操作,讲解YOLOv8模型剪枝的方法原理、实施步骤及注意事项。

虽然YOLOv8n版本本身参数量少、推理速度快,能满足大多数工业检测需求,但谷歌研究表明:通过对大模型进行裁剪得到的小模型往往性能更优。

本文基于其他博客的剪枝方法的代码实现,专门针对YOLOv8模型进行剪枝优化,能够理解模型剪枝的底层操作。其核心创新点在于利用BN层(Batch Normalization)的特性,实现高效的通道级剪枝操作。

一、剪枝的理论基础

  • BN参数的重要性:BN层中的缩放参数(γ)代表了卷积核的重要程度,通过裁剪γ值较小的卷积核,可以实现剪枝。
  • 剪枝流程总体架构
    1. 训练稀疏模型(引入BN正则化)
    2. 计算剪枝阈值
    3. 剪除冗余卷积核
    4. 微调模型,恢复性能

二、YOLOv8剪枝的具体步骤

1. 预备工作

  • 模型训练: 先进行完整训练,获得基准性能指标。
  • 将LL_pruning.pyLL_train.py这两个文件放在根目录下

    LL_train.py代码如下所示:
    from ultralytics import YOLO  # 导入YOLO模型库  
    import os  # 导入os模块,用于处理文件路径  root = os.getcwd()  # 获取当前工作目录  ## 配置文件路径  
    name_yaml = os.path.join(root, "ultralytics/datasets/VOC.yaml")  # 数据集配置文件路径  
    name_pretrain = os.path.join(root, r"D:\practice_demo\ultralytics\runs\detect\jueyuanzi_yolov8m\best.pt")  # 预训练模型路径  ## 原始训练路径  
    path_train = os.path.join(root, "runs/detect/VOC")  # 原始训练结果保存路径  
    name_train = os.path.join(path_train, "weights/last.pt")  # 原始训练模型文件路径  ## 约束训练路径、剪枝模型文件  
    path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint")  # 约束训练结果保存路径  
    name_prune_before = os.path.join(path_constraint_train, "weights/last.pt")  # 剪枝前模型文件路径  
    name_prune_after = os.path.join(path_constraint_train, "weights/last_prune.pt")  # 剪枝后模型文件路径  ## 微调路径  
    path_fineturn = os.path.join(root, "runs/detect/VOC_finetune")  # 微调结果保存路径  def step1_train():  model = YOLO(name_pretrain)  # 加载预训练模型  model.train(data=name_yaml, imgsz=640, epochs=300, batch=32, name=path_train)  # 训练模型  ## 一定要添加【amp=False】  
    def step2_Constraint_train():  model = YOLO(name_train)  # 加载原始训练模型  model.train(data=name_yaml, imgsz=640, epochs=50, batch=32, amp=False, save_period=1, name=path_constraint_train)  # 训练模型  def step3_pruning():  from LL_pruning import do_pruning  # 导入剪枝函数  do_pruning(name_prune_before, name_prune_after)  # 执行剪枝操作  def step4_finetune():  model = YOLO(name_prune_after)  # 加载剪枝后的模型  model.train(data=name_yaml, imgsz=640, epochs=100, batch=32, save_period=1, name=path_fineturn)  # 微调模型  # 执行训练、约束训练、剪枝和微调步骤  
    step1_train()  # 训练模型  
    # step2_Constraint_train()  # 进行稀疏训练  
    # step3_pruning()  # 执行剪枝  
    # step4_finetune()  # 微调模型

LL_pruning.py代码如下所示:

​
from ultralytics import YOLO  # 导入YOLO模型
import torch  # 导入PyTorch库
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect  # 导入YOLO模型中的模块
import os  # 导入os模块,用于处理文件路径# os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # 可选:指定使用的GPU设备class PRUNE():def __init__(self) -> None:self.threshold = None  # 初始化阈值def get_threshold(self, model, factor=0.8):"""计算剪枝阈值:param model: YOLO模型:param factor: 剪枝比例,默认0.8"""ws = []  # 存储权重bs = []  # 存储偏置for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):  # 仅处理BatchNorm2d层w = m.weight.abs().detach()  # 获取权重的绝对值b = m.bias.abs().detach()  # 获取偏置的绝对值ws.append(w)  # 添加权重bs.append(b)  # 添加偏置print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())  # 打印权重和偏置的最大最小值# 合并所有权重ws = torch.cat(ws)# 计算剪枝阈值self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]def prune_conv(self, conv1: Conv, conv2: Conv):"""对卷积层的“相邻”卷积做通道级剪枝。参数----:param conv1: 第一个卷积层: Conv(Ultralytics封装的Conv模块,内部含 nn.Conv2d + BN + 激活)*上游* 被剪枝的卷积。删除它的某些 输出 通道。:param conv2: 第二个卷积层: Conv 或 Conv列表 / 纯 nn.Conv2d / None*下游* 接收 conv1 输出的卷积(可能有多支分支)。需要把 输入 通道同步删除。剪枝规则--------1. 用 conv1 中 BatchNorm 的缩放系数 γ 的绝对值做“重要性”指标。2. 选出 |γ| >= 全局阈值 的通道索引 keep_idxs(若太少则降低阈值,至少保留8个,防止结构非法)。3. 在 conv1 中:删掉其它通道 → 需要同时修改 BN 的各种统计量与 nn.Conv2d 的权重/偏置/out_channels。4. 在 conv2 中:这些被删的只是“输入特征图”,因此只更新 in_channels。"""# a. 根据BN中的参数,获取需要保留的indexgamma = conv1.bn.weight.data.detach()  # 获取BN层的权重beta = conv1.bn.bias.data.detach()  # 获取BN层的偏置keep_idxs = []  # 存储需要保留的索引local_threshold = self.threshold  # 使用全局阈值while len(keep_idxs) < 8:  # 确保至少保留8个卷积核keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]  # 获取满足条件的索引local_threshold = local_threshold * 0.5  # 如果不足8个,降低阈值n = len(keep_idxs)  # 保留的卷积核数量print(n / len(gamma))  # 打印保留的比例# b. 利用index对BN进行剪枝conv1.bn.weight.data = gamma[keep_idxs]  # 更新BN权重conv1.bn.bias.data = beta[keep_idxs]  # 更新BN偏置conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]  # 更新BN的方差conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]  # 更新BN的均值conv1.bn.num_features = n  # 更新BN的特征数量conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]  # 更新卷积层的权重conv1.conv.out_channels = n  # 更新卷积层的输出通道数# c. 利用index对conv1进行剪枝if conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]  # 更新卷积层的偏置# d. 利用index对conv2进行剪枝if not isinstance(conv2, list):conv2 = [conv2]  # 确保conv2是列表for item in conv2:if item is None: continue  # 跳过Noneif isinstance(item, Conv):conv = item.conv  # 获取卷积层else:conv = itemconv.in_channels = n  # 更新输入通道数conv.weight.data = conv.weight.data[:, keep_idxs]  # 更新卷积层的权重def prune(self, m1, m2):"""对模块进行剪枝:param m1: 第一个模块:param m2: 第二个模块"""if isinstance(m1, C2f):  # 如果m1是C2f模块,获取其cv2m1 = m1.cv2if not isinstance(m2, list):  # 确保m2是列表m2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1  # 获取C2f或SPPF的cv1self.prune_conv(m1, m2)  # 对卷积层进行剪枝def do_pruning(modelpath, savepath):"""执行剪枝操作:param modelpath: 原始模型路径:param savepath: 剪枝后模型保存路径"""pruning = PRUNE()  # 创建PRUNE实例### 0. 加载模型yolo = YOLO(modelpath)  # 从指定路径加载YOLO模型pruning.get_threshold(yolo.model, 0.8)  # 获取剪枝阈值,0.8为剪枝率### 1. 剪枝c2f中的Bottleneckfor name, m in yolo.model.named_modules():if isinstance(m, Bottleneck):  # 仅处理Bottleneck模块pruning.prune_conv(m.cv1, m.cv2)  # 对Bottleneck中的卷积层进行剪枝### 2. 指定剪枝不同模块之间的卷积核seq = yolo.model.model  # 获取模型的序列for i in [3, 5, 7, 8]:  # 指定需要剪枝的模块pruning.prune(seq[i], seq[i + 1])  # 对相邻模块进行剪枝### 3. 对检测头进行剪枝detect: Detect = seq[-1]  # 获取检测头last_inputs = [seq[15], seq[18], seq[21]]  # 获取最后输入的模块colasts = [seq[16], seq[19], None]  # 获取与最后输入相连的模块for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):pruning.prune(last_input, [colast, cv2[0], cv3[0]])  # 对输入模块和检测头进行剪枝pruning.prune(cv2[0], cv2[1])  # 对检测头的卷积层进行剪枝pruning.prune(cv2[1], cv2[2])  # 对检测头的卷积层进行剪枝pruning.prune(cv3[0], cv3[1])  # 对检测头的卷积层进行剪枝pruning.prune(cv3[1], cv3[2])  # 对检测头的卷积层进行剪枝### 4. 模型梯度设置与保存for name, p in yolo.model.named_parameters():p.requires_grad = True  # 设置所有参数的梯度为可计算# yolo.val()  # 验证模型性能torch.save(yolo.ckpt, savepath)  # 保存剪枝后的模型yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))  # 更新模型路径yolo.export(format="onnx")  # 导出为ONNX格式## 重新加载模型,修改保存命名,用以比较剪枝前后的onnx的大小yolo = YOLO(modelpath)  # 从指定路径加载YOLO模型yolo.export(format="onnx")  # 导出为ONNX格式if __name__ == "__main__":modelpath = "runs/detect1/14_Constraint/weights/last.pt"  # 原始模型路径savepath = "runs/detect1/14_Constraint/weights/last_prune.pt"  # 剪枝后模型保存路径do_pruning(modelpath, savepath)  # 执行剪枝操作​

2. 稀疏正则训练

  • 使用带有 BN正则的训练方式,促进BN参数稀疏化。

首先加载一个正常训练的yolov8模型权重(.pt文件),ultralytics/engine/trainer.py中添加如下代码,使得bn参数在训练时变得稀疏。

代码中对所有 BatchNorm 层加了 L1 正则,以便自动把不重要的通道“压”成零,后面再统一按阈值剪枝。关键代码如下:

...## add start=============================## add l1 regulation for step2_Constraint_trainl1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))## add end ==============================...
  • 为什么只对 BN 做正则?
    BatchNorm 的 γ(scale)系数直接影响通道输出强度:γ ≈ 0 时,该通道几乎不参与后续计算,用它来衡量“重要性”最直观。

  • L1 正则如何“稀疏”?
    在反向传播时,为每个 γ/β 的梯度额外加上 ±λ,这会让本就小的 γ 更快被拉向 0,从而在训练中自然分化出大 γ(保留通道)和小 γ(待剪通道)。

  • λ 为何随 epoch 递减?
    训练初期靠强正则快速分离;后期减弱正则,避免过度压榨保留通道,给微调留下空间。

  • bias 也正则吗?
    虽然偏置对通道筛选作用不如 γ 强,但适度收敛 β 能进一步去除边缘特征,提高稀疏度。

之后在LL_pruning.py中运行方框中的代码

注意事项:

稀疏训练需要关闭混合精度(amp=False
剪枝依赖于 BatchNorm 的 γ 值作为排序阈值,γ 越小越容易被剪除。若使用 FP16(混合精度),许多接近 0 的 γ 会被量化到同一值甚至下溢为 0,导致排序失真,同时 L1 正则梯度也容易消失,后续剪枝的阈值选择会变得不稳定。而使用 FP32(amp=False)能精确表示这些微小差异,确保稀疏模式可控。

稀疏训练的 batch size 不宜过大
由于关闭了混合精度,模型采用全精度计算,显存占用显著增加。若 batch size 设置过大,可能导致显存溢出(OOM),进而引发训练失败。

稀疏训练阶段要将 patience 设为 0 或较大值
稀疏训练的目标并非短期提升 mAP,而是让 BN 的 γ 在多个 epoch 内逐步被 L1 正则“压缩”。在此期间,验证集指标可能停滞甚至下降。若启用常规早停机制(默认 patience 为几十),训练可能在 γ 尚未充分分化前被提前终止,导致剪枝时阈值模糊、可剪通道不足。

3. 剪枝

执行以下代码;

剪枝中的注意点:

在 YOLOv8 中,当进行 split concat 操作时,若剪枝后的通道数不匹配会报错。LL_pruning.py 的剪枝代码怎么避免这一问题,暂时还没研究透,有大佬知道请不吝指教。

关于 do_pruning 方法启用 yolo.val() 后保存的剪枝模型缺失 BN 层的原因:
Ultralytics 的验证 / 导出流程会将 Conv + BatchNorm 静态融合到卷积权重和偏置中,从而提升推理速度和轻量化。这一过程会直接移除 BN 层,因此保存的 yolo.ckpt 是已融合的模型。

对比剪枝前后的模型文件(last.pt/last_prune.pt)及其 ONNX 转换结果:
剪枝后的 .pt 文件增大,而 ONNX 文件从 43MB 缩减至 36MB。这是因为 .pt 文件包含完整的 checkpoint 元数据,而 ONNX 仅保存精简的推理图结构,因此只需关注 ONNX 文件大小的优化即可。

4. 微调

在第二步稀疏正则训练中将BN约束注释

需要注意的是明明加载的是剪枝后的模型,但训练启动时打印的日志却显示为标准版模型的参数。并且经过验证,微调后的模型参数就是标准的yolo模型。所以需要进行一些修改,详细的讲解可以看YOLOv8 剪枝模型加载踩坑记:解决 YAML 覆盖剪枝结构的问题-CSDN博客

修改ultralytics/engine/model.py文件内容:
self.trainer.model包含从YAML文件加载的原始模型配置信息,以及从PT文件加载的剪枝后权重。只需将该变量的网络结构更新为剪枝后的网络结构就行,否则训练后的模型参数不会改变。

运行下面的代码

yolov8模型的剪枝到这就结束了。

http://www.dtcms.com/a/297676.html

相关文章:

  • 解密负载均衡:如何轻松提升业务性能
  • JS事件流
  • 疯狂星期四第19天运营日记
  • 网络资源模板--基于Android Studio 实现的天气预报App
  • LeetCode 127:单词接龙
  • 三维图像识别中OpenCV、PCL和Open3D结合的主要技术概念、部分示例
  • 水库大坝安全监测的主要内容
  • MySQL 全新安装步骤(Linux版yum源安装)
  • Lua(面向对象)
  • 深度学习水论文:特征提取
  • NBIOT模块 BC28通过MQTT协议连接到EMQX
  • 如何在 Ubuntu 24.04 或 22.04 上安装和使用 GDebi
  • 智能网关:物联网时代的核心枢纽
  • ABP VNext + Razor 邮件模板:动态、多租户隔离、可版本化的邮件与通知系统
  • 智能网关芯片:物联网连接的核心引擎
  • 酷暑来袭,科技如何让城市清凉又洁净?
  • 制造业低代码平台实战评测:简道云、钉钉宜搭、华为云Astro、金蝶云·苍穹、斑斑低代码,谁更值得选?
  • 使用 FFmpeg 实现 RTP 音频传输与播放
  • 【Redis】初识Redis(定义、特征、使用场景)
  • Spring框架
  • 认识编程(3)-语法背后的认知战争:类型声明的前世今生
  • vue3单页面连接多个websocket并实现断线重连功能
  • 机器学习笔记(三)——决策树、随机森林
  • Git指令
  • git将本地文件完和仓库文件目录完全替换-------还有将本地更新的文件放到仓库中,直接提交即可
  • C# WPF 实现读取文件夹中的PDF并显示其页数
  • STM32与ADS1220多通道采样数据
  • vscode 登录ssh记住密码直接登录设置
  • GPU 服务器ecc报错处理
  • 详谈OSI七层模型和TCP/IP四层模型以及tcp与udp为什么是4层,http与https为什么是7层