当前位置: 首页 > news >正文

【Mac-ML-DL】深度学习使用MPS出现内存泄露(leaked semaphore)以及张量转换错误

MPS加速修改总结

先说设备:MacBook Pro M4 24GB
事情的起因是我在进行深度学习的时候想尝试用苹果自带的MPS进行训练加速,修改设备后准备开始训练,但是出现如下报错:

UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdownwarnings.warn('resource_tracker: There appear to be %d '

我尝试在stackoverflow上面寻找答案,但是只有人提问,没有人回答,于是我进入PyTorch的社区进行查找,终于有人也提了这个问题。
修改后没有出现内存泄露的问题,但是有新的问题:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

幸好有人直接给出了GitHub的issue链接,根据官方的建议修改后成功运行了,接下来分享我的修改全流程,帮助大家避坑。

1. 设备检测与切换

  • train.py中添加了对MPS设备的检测和使用:

    if torch.backends.mps.is_available() and torch.backends.mps.is_built():device = torch.device("mps")
    
  • 添加命令行参数支持直接指定设备类型:--device mps

2. 数据类型修复

  • 创建DoubleToFloatTransform转换器确保所有张量为float32类型,因为MPS不支持float64

  • 在数据转换pipeline中添加此转换器:

    transforms.Compose([# 其他转换...transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),DoubleToFloatTransform()  # 确保张量为float32类型
    ])
    
  • 将代码中所有.double()替换为.float(),避免精度转换错误

3. 内存管理优化

  • 使用multiprocessing.set_start_method('spawn')解决MPS设备上的bus error问题

  • 定期调用torch.mps.empty_cache()释放MPS设备上的缓存

  • 减少DataLoader的worker数量,避免内存压力:

    if device.type == 'mps' and args.num_workers > 2:args.num_workers = 2
    

4. 模型初始化修复

  • 修复ResNet和EfficientNet模型中的权重类型问题,使用正确的权重枚举类型:

    weights_enum = {'resnet18': models.ResNet18_Weights.DEFAULT,# 其他模型...
    }
    

5. 工作进程优化

  • 创建专用的worker_init函数处理MPS设备上的数据加载

  • 在工作进程中强制使用float32数据类型:

    torch.set_default_dtype(torch.float32)
    

6. 性能测量适配

  • 修改measure_inference_time函数,为MPS设备添加专门的同步和计时方法:

    if device.type == 'mps':torch.mps.synchronize()
    

总结

如果这个教程对你有帮助不妨点赞、收藏、关注,你的支持就是我更新的最大动力,后续我还会更新更多有用的内容!

如果还有问题可以私信我,信得过我的话,免费帮你看看代码,但是本人实力有限,不一定能解决,但是尽量帮助,大家一起进步!


文章转载自:

http://KaZiudcO.rLwgn.cn
http://xXkeOtMx.rLwgn.cn
http://SpLuqs4f.rLwgn.cn
http://MRyFiQ0a.rLwgn.cn
http://KSSyqeEF.rLwgn.cn
http://0jKcL4OB.rLwgn.cn
http://wIjB9OEr.rLwgn.cn
http://ISAi4DQP.rLwgn.cn
http://sWJA3bel.rLwgn.cn
http://ct6oCyGt.rLwgn.cn
http://lLvguQ2q.rLwgn.cn
http://JFdPCeX3.rLwgn.cn
http://uB0GAlKM.rLwgn.cn
http://VqiIodKG.rLwgn.cn
http://Cqrw70Rh.rLwgn.cn
http://eEOxNOsQ.rLwgn.cn
http://DxlF6JGe.rLwgn.cn
http://tnzEEZyx.rLwgn.cn
http://Nzt0IDqV.rLwgn.cn
http://RPZs99eo.rLwgn.cn
http://0HKPol4p.rLwgn.cn
http://Pb6SsdFO.rLwgn.cn
http://wJV2K6yq.rLwgn.cn
http://AePv6ogW.rLwgn.cn
http://XAh0R4cT.rLwgn.cn
http://agVb5VJI.rLwgn.cn
http://97H2myEF.rLwgn.cn
http://lsIMOS1B.rLwgn.cn
http://TM9vdcjm.rLwgn.cn
http://VZnTGthr.rLwgn.cn
http://www.dtcms.com/a/137221.html

相关文章:

  • 算法——希尔排序
  • 【软考】论devops在企业信息系统开发中的应用
  • Vue基础(4)_事件处理
  • nvme nvme0: controller is down; will reset: CSTS=0x3, PCI_STATUS=0x10
  • Java Collection(7)——Iterable接口
  • 基于YOLOV11的道路坑洼分析系统
  • 解锁 QuickAPI 数据 API 的多元应用:高效数据交互之道
  • Go语言入门到入土——一、安装和Hello World
  • python celery 和 rabbitmq结合
  • 嵌入式Linux驱动——6 Pinctrl和GPIO子系统
  • 多角度分析Vue3 nextTick() 函数
  • C++类型系统深度解析:int vs int32_t的底层差异
  • Elasticsearch 查询排序报错总结
  • 【含文档+PPT+源码】基于微信小程序的旅游论坛系统的设计与实现
  • Oracle19C低版本一天遭遇两BUG(ORA-04031/ORA-600)
  • 元数据知识点
  • SM4密码算法的CPA攻击技术
  • helm账号密码加密
  • 通过检索增强生成(RAG)和重排序提升大语言模型(LLM)的准确性
  • ReportLab 导出 PDF(图文表格)
  • 企业办理林业调查规划设计资质的核心是什么?
  • 英语16种时态
  • Dify智能体平台源码二次开发笔记(7) - 优化知识库pdf识别(2)
  • 小刚说C语言刷题——1020 算算和是多少
  • 半导体制造如何数字化转型
  • 158页PPT | 某大型研发制造集团信息化IT规划整体方案
  • 电脑一直不关机会怎么样?电脑长时间不关机的影响
  • 解释原型链的概念,并说明`Object.prototype.__proto__`的值是什么?
  • C#核心(24)结构体和类的区别,抽象类和接口的区别(面试常问)
  • LRU算法