当前位置: 首页 > news >正文

PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

# 保存
torch.save(model.state_dict(), "model_params.pth")  # 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
# 保存
torch.save(model, "full_model.pth")  # 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
# 保存检查点
checkpoint = {'epoch': current_epoch,'model_state': model.state_dict(),'optimizer_state': optimizer.state_dict(),'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

http://www.dtcms.com/a/200675.html

相关文章:

  • 【软考-架构】15、软件架构的演化和维护
  • 力扣热题100之删除链表的倒数第N个节点
  • 鸿蒙 Location Kit(位置服务)
  • 双周报Vol.72:字段级文档注释支持、视图类型现为值类型,减少内存分配
  • Python网络爬虫入门指南
  • 【CodeBuddy 】从0到1,让网页导航栏变为摸鱼神器
  • 视图+触发器+临时表+派生表
  • 用于判断主子关系的方法的实现(orm是efcore)
  • [特殊字符] Word2Vec:将词映射到高维空间,它到底能解决什么问题?
  • 深入解析OkHttp与Retrofit:Android网络请求的黄金组合
  • 蓝桥杯1447 砝码称重
  • Python 实例传递的艺术:四大方法解析与最佳实践
  • 用 RefCounted + WeakPtr 构建线程安全的异步模块
  • 【OpenCV基础2】图像运算、水印、加密、摄像头
  • 如何在 Windows 11 或 10 上安装 FlutterFire CLI
  • CSS提高性能的方法有哪些
  • C++面试4-sizeof解析
  • RabbitMQ的简介
  • C 语言学习笔记(函数2)
  • AI在网络安全中的应用之钓鱼邮件检测
  • Python列表 vs 元组:全面对比解析(新手友好版)
  • MYSQL8.0常用窗口函数
  • input组件使用type=“number“的时候,光标自动跳到首位
  • 【Tools】VMware Workstation 17.6 Pro安装教程
  • 在 CentOS 7.9 上部署 node_exporter 并接入 Prometheus + Grafana 实现主机监控
  • PyMOL命令行和脚本
  • 精益数据分析(70/126):MVP迭代中的数据驱动决策与功能取舍
  • AI神经网络降噪 vs 传统单/双麦克风降噪的核心优势对比
  • 公网ip是固定的吗?动态ip如何做端口映射?内网ip怎么让外网远程访问?
  • 组态王通过开疆智能profinet转ModbusTCP网关连接西门子PLC配置案例