当前位置: 首页 > news >正文

PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

# 保存
torch.save(model.state_dict(), "model_params.pth")  # 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
# 保存
torch.save(model, "full_model.pth")  # 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
# 保存检查点
checkpoint = {'epoch': current_epoch,'model_state': model.state_dict(),'optimizer_state': optimizer.state_dict(),'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关文章:

  • 【软考-架构】15、软件架构的演化和维护
  • 力扣热题100之删除链表的倒数第N个节点
  • 鸿蒙 Location Kit(位置服务)
  • 双周报Vol.72:字段级文档注释支持、视图类型现为值类型,减少内存分配
  • Python网络爬虫入门指南
  • 【CodeBuddy 】从0到1,让网页导航栏变为摸鱼神器
  • 视图+触发器+临时表+派生表
  • 用于判断主子关系的方法的实现(orm是efcore)
  • [特殊字符] Word2Vec:将词映射到高维空间,它到底能解决什么问题?
  • 深入解析OkHttp与Retrofit:Android网络请求的黄金组合
  • 蓝桥杯1447 砝码称重
  • Python 实例传递的艺术:四大方法解析与最佳实践
  • 用 RefCounted + WeakPtr 构建线程安全的异步模块
  • 【OpenCV基础2】图像运算、水印、加密、摄像头
  • 如何在 Windows 11 或 10 上安装 FlutterFire CLI
  • CSS提高性能的方法有哪些
  • C++面试4-sizeof解析
  • RabbitMQ的简介
  • C 语言学习笔记(函数2)
  • AI在网络安全中的应用之钓鱼邮件检测
  • 不止是生态优势,“浙江绿谷”丽水有活力
  • 张永宁任福建宁德市委书记
  • 印尼总统20年来首次访泰:建立战略伙伴关系,加强打击网络诈骗等合作
  • 贵州茅台股东大会回应八大热点:确保茅台酒价格体系稳固,相信自我调节能力
  • 世卫大会再次拒绝涉台提案,国台办:民进党当局再遭挫败理所当然
  • 陈龙带你观察上海生物多样性,纪录片《我的城市邻居》明播出