当前位置: 首页 > news >正文

【Mac-ML-DL】深度学习使用MPS出现内存泄露(leaked semaphore)以及张量转换错误

MPS加速修改总结

先说设备:MacBook Pro M4 24GB
事情的起因是我在进行深度学习的时候想尝试用苹果自带的MPS进行训练加速,修改设备后准备开始训练,但是出现如下报错:

UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdownwarnings.warn('resource_tracker: There appear to be %d '

我尝试在stackoverflow上面寻找答案,但是只有人提问,没有人回答,于是我进入PyTorch的社区进行查找,终于有人也提了这个问题。
修改后没有出现内存泄露的问题,但是有新的问题:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

幸好有人直接给出了GitHub的issue链接,根据官方的建议修改后成功运行了,接下来分享我的修改全流程,帮助大家避坑。

1. 设备检测与切换

  • train.py中添加了对MPS设备的检测和使用:

    if torch.backends.mps.is_available() and torch.backends.mps.is_built():device = torch.device("mps")
    
  • 添加命令行参数支持直接指定设备类型:--device mps

2. 数据类型修复

  • 创建DoubleToFloatTransform转换器确保所有张量为float32类型,因为MPS不支持float64

  • 在数据转换pipeline中添加此转换器:

    transforms.Compose([# 其他转换...transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),DoubleToFloatTransform()  # 确保张量为float32类型
    ])
    
  • 将代码中所有.double()替换为.float(),避免精度转换错误

3. 内存管理优化

  • 使用multiprocessing.set_start_method('spawn')解决MPS设备上的bus error问题

  • 定期调用torch.mps.empty_cache()释放MPS设备上的缓存

  • 减少DataLoader的worker数量,避免内存压力:

    if device.type == 'mps' and args.num_workers > 2:args.num_workers = 2
    

4. 模型初始化修复

  • 修复ResNet和EfficientNet模型中的权重类型问题,使用正确的权重枚举类型:

    weights_enum = {'resnet18': models.ResNet18_Weights.DEFAULT,# 其他模型...
    }
    

5. 工作进程优化

  • 创建专用的worker_init函数处理MPS设备上的数据加载

  • 在工作进程中强制使用float32数据类型:

    torch.set_default_dtype(torch.float32)
    

6. 性能测量适配

  • 修改measure_inference_time函数,为MPS设备添加专门的同步和计时方法:

    if device.type == 'mps':torch.mps.synchronize()
    

总结

如果这个教程对你有帮助不妨点赞、收藏、关注,你的支持就是我更新的最大动力,后续我还会更新更多有用的内容!

如果还有问题可以私信我,信得过我的话,免费帮你看看代码,但是本人实力有限,不一定能解决,但是尽量帮助,大家一起进步!

相关文章:

  • 算法——希尔排序
  • 【软考】论devops在企业信息系统开发中的应用
  • Vue基础(4)_事件处理
  • nvme nvme0: controller is down; will reset: CSTS=0x3, PCI_STATUS=0x10
  • Java Collection(7)——Iterable接口
  • 基于YOLOV11的道路坑洼分析系统
  • 解锁 QuickAPI 数据 API 的多元应用:高效数据交互之道
  • Go语言入门到入土——一、安装和Hello World
  • python celery 和 rabbitmq结合
  • 嵌入式Linux驱动——6 Pinctrl和GPIO子系统
  • 多角度分析Vue3 nextTick() 函数
  • C++类型系统深度解析:int vs int32_t的底层差异
  • Elasticsearch 查询排序报错总结
  • 【含文档+PPT+源码】基于微信小程序的旅游论坛系统的设计与实现
  • Oracle19C低版本一天遭遇两BUG(ORA-04031/ORA-600)
  • 元数据知识点
  • SM4密码算法的CPA攻击技术
  • helm账号密码加密
  • 通过检索增强生成(RAG)和重排序提升大语言模型(LLM)的准确性
  • ReportLab 导出 PDF(图文表格)
  • 邯郸媒体网络营销诚信合作/西安seo网络优化公司
  • 北京官网建设多少钱/无锡网站优化公司
  • 仓库管理系统erp/seo全称是什么
  • 苏州调查公司哪家好/广州中小企业seo推广运营
  • 一个企业网站做几个关键词/佛山做网站推广的公司
  • 微商产品做网站/互联网营销推广服务商