当前位置: 首页 > news >正文

【免费可用】【提供源代码】对YOLOV11模型进行剪枝和蒸馏

yolov11_prune_distillation

该项目可以用于YOLOv11网络的训练,静态剪枝和知识蒸馏。可以在减少模型参数量的同时,尽量保证模型的推理精度。

Github链接:https://github.com/zhahoi/yolov11_prune_distillation.git

🤗Current Ultralytics version: 8.3.160

🔧 Install Dependencies

pip install torch-pruning 
pip install -r requirements.txt

🚂 Training & Pruning & Knowledge Distillation

📊 YOLO11 Training Example

### train.py
from ultralytics import YOLOif __name__ == "__main__":model = YOLO('yolo11.yaml')results = model.train(data='uno.yaml', epochs=100, imgsz=640, batch=8, device="0", name='yolo11', workers=0, prune=False)

✂️ YOLO11 Pruning Example

### prune.py
from ultralytics import YOLO# model = YOLO('yolo11.yaml')
model = YOLO('runs/detect/yolo11/weights/best.pt')def prunetrain(train_epochs, prune_epochs=0, quick_pruning=True, prune_ratio=0.5, prune_iterative_steps=1, data='coco.yaml', name='yolo11', imgsz=640, batch=8, device=[0], sparse_training=False):if not quick_pruning:assert train_epochs > 0 and prune_epochs > 0, "Quick Pruning is not set. prune epochs must > 0."print("Phase 1: Normal training...")model.train(data=data, epochs=train_epochs, imgsz=imgsz, batch=batch, device=device, name=f"{name}_phase1", prune=False,sparse_training=sparse_training)print("Phase 2: Pruning training...")best_weights = f"runs/detect/{name}_phase1/weights/best.pt"pruned_model = YOLO(best_weights)return pruned_model.train(data=data, epochs=prune_epochs, imgsz=imgsz, batch=batch, device=device, name=f"{name}_pruned", prune=True,prune_ratio=prune_ratio, prune_iterative_steps=prune_iterative_steps)else:return model.train(data=data, epochs=train_epochs, imgsz=imgsz, batch=batch, device=device, name=name, prune=True, prune_ratio=prune_ratio, prune_iterative_steps=prune_iterative_steps)if __name__ == '__main__':# Normal Pruningprunetrain(quick_pruning=False,       # Quick Pruning or notdata='uno.yaml',          # Dataset configtrain_epochs=10,           # Epochs before pruningprune_epochs=20,           # Epochs after pruning imgsz=640,                 # Input sizebatch=8,                   # Batch sizedevice=[0],                # GPU devicesname='yolo11_prune',             # Save nameprune_ratio=0.5,           # Pruning Ratio (50%)prune_iterative_steps=1,   # Pruning Interative Stepssparse_training=True      # Experimental, Allow Sparse Training Before Pruning)# Quick Pruning (prune_epochs no need)# prunetrain(quick_pruning=True, data='coco.yaml', train_epochs=10, imgsz=640, batch=8, device=[0], name='yolo11', #            prune_ratio=0.5, prune_iterative_steps=1)

🔎 YOLO11 Knowledge Distillation Example

### knowledge_distillation.py
from ultralytics import YOLO
from ultralytics.nn.attention.attention import ParallelPolarizedSelfAttention
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils.torch_utils import model_infodef add_attention(model):at0 = model.model.model[4]n0 = at0.cv2.conv.out_channelsat0.attention = ParallelPolarizedSelfAttention(n0)at1 = model.model.model[6]n1 = at1.cv2.conv.out_channelsat1.attention = ParallelPolarizedSelfAttention(n1)at2 = model.model.model[8]n2 = at2.cv2.conv.out_channelsat2.attention = ParallelPolarizedSelfAttention(n2)return modelif __name__ == "__main__":# layers = ["6", "8", "13", "16", "19", "22"]layers = ["4", "6", "10", "16", "19", "22"]model_t = YOLO('runs/detect/yolo11/weights/best.pt')  # the teacher modelmodel_s = YOLO("runs/detect/yolo11_prune_pruned/weights/best.pt")  # the student modelmodel_s = add_attention(model_s) # Add attention to the student model# configure overridesoverrides = {"model": "runs/detect/yolo11_prune_pruned/weights/best.pt","Distillation": model_t.model,"loss_type": "mgd","layers": layers,"epochs": 50,"imgsz": 640,"batch": 8,"device": 0,"lr0": 0.001,"amp": False,"sparse_training": False,"prune": False,"prune_load": False,"workers": 0,"data": "data.yaml","name": "yolo11_distill"}trainer = DetectionTrainer(overrides=overrides)trainer.model = model_s.model model_info(trainer.model, verbose=True)trainer.train()

📤 Model Export

Export to ONNX Format Example

### export.py
from ultralytics import YOLOmodel = YOLO('runs/detect/yolo11_distill/weights/yolo11n.pt')
print(model.model)
model.export(format='onnx')

🌞 Model Inference

Image Inference Example

### infer.py
from ultralytics import YOLO
model = YOLO('runs/detect/yolo11/weights/best.pt') # model = YOLO('prune.pt')
model.predict('fruits.jpg', save=True, device=[0], line_width=2)

🔢 Model Analysis

Use thop to easily calculate model parameters and FLOPs:

pip install thop

You can calculate model parameters and flops by using calculate.py

🤝 Contributing & Support

Feel free to submit issues or pull requests on GitHub for questions or suggestions!

📚 Acknowledgements

  • Special thanks to @VainF for the contribution to the Torch-Pruning project! This project relies on it for model pruning.
  • Special thanks to @Ultralytics for the contribution to the ultralytics project! This project relies on it for the framework.
  • YOLO-Pruning-RKNN
  • yolov11_prune_distillation_v2
http://www.dtcms.com/a/301950.html

相关文章:

  • Excel常用函数大全,非常实用
  • 重构vite.config.json
  • Jenkins vs GitLab CI/CD vs GitHub Actions在容器化部署流水线中的对比分析与实践指南
  • 云原生MySQL Operator开发实战(三):高级特性与生产就绪功能
  • CodeBuddy的安装教程
  • 优测推出HarmonyOS全场景测试服务,解锁分布式场景应用卓越品质!
  • 表征学习:机器认知世界的核心能力与前沿突破
  • 「源力觉醒 创作者计划」_文心大模型4.5系列开源模型,意味着什么?对开发者、对行业生态有何影响?
  • 新能源行业B端极简设计:碳中和目标下的交互轻量化实践
  • C#与C++交互开发系列(二十六):构建跨语言共享缓存,实现键值对读写与数据同步(实践方案)
  • 电子电路原理学习笔记---第4章二极管电路---第3天
  • 墨者:SQL注入实战-MySQL
  • uni-datetime-picker兼容ios
  • 【iOS】类和分类的加载过程
  • MySQL有哪些“饮鸩止渴”提高性能的方法?
  • 【Linux篇章】穿越数据迷雾:HTTPS构筑网络安全的量子级护盾,重塑数字信任帝国!
  • 全面解析MySQL(4)——三大范式与联合查询实例教程
  • 【Java Web实战】从零到一打造企业级网上购书网站系统 | 完整开发实录(终)
  • Linux DNS解析2 -- 网关DNS代理的作用
  • CodeMeter授权管理方案助力 PlantStream 引领工业设计新变革
  • 接口测试怎么做?接口测试工具有哪些?
  • JavaWeb 入门:HTML 基础与实战详解(Java 开发者视角)
  • 使用JavaScript实现一个代办事项的小案例
  • 基于亮数据 MCP 的 Trae 智能体,让规模化 Google 数据实时分析触手可及
  • MCP资源管理深度实践:动态数据源集成方案
  • 剑指“CPU飙高”问题
  • 从视觉到智能:RTSP|RTMP推拉流模块如何助力“边缘AI系统”的闭环协同?
  • Entity Framework Core (EF Core) 中状态检测
  • 编程算法:技术创新的引擎与业务增长的核心驱动力
  • 【前端】Tab切换时的数据重置与加载策略技术文档