当前位置: 首页 > news >正文

深度学习-读写模型网络文件

模型网络文件是深度学习模型的存储形式,保存了模型的架构、参数等信息。

读写模型网络文件是深度学习流程中的关键环节,方便模型的训练、测试、部署与共享。

1. 主流框架读写方法

(一)TensorFlow

  • 保存模型

    • 可以使用 tf.saved_model.save 方法保存整个模型,包括架构、参数、编译信息等。例如: model.save('model_dir', save_format='tf'),将模型保存在文件夹 'model_dir' 中。

  • 加载模型

    • 使用 tf.keras.models.load_model 加载保存的模型。如:loaded_model = tf.keras.models.load_model('model_dir'),即可加载之前保存的模型进行预测、继续训练等操作。

(二)PyTorch

使用 torch.save 和 torch.load 来保存和加载 张量

  • 保存模型

    • 通常有两种方式:一种是保存整个模型对象,使用 torch.save(model, 'model.pth'),将模型结构和参数都保存下来。另一种是仅保存模型的参数状态字典,即 torch.save(model.state_dict(), 'model_state_dict.pth'),这种方式更常见,因为当模型架构修改时,只要能正确加载参数,就无需重新训练整个模型。

  • 加载模型

    • 对于保存整个模型的情况,直接使用 model = torch.load('model.pth')。对于仅保存参数的情况,先定义好模型架构,再用 model.load_state_dict(torch.load('model_state_dict.pth')) 加载参数,使模型具备相应的能力。

对于深度学习模型而言,通常只需保存其权重参数即可满足需求。在 PyTorch 框架中,可以使用 torch.save() 函数来保存网络的 state_dict 参数,这是保存模型权重的一种高效方式。

而在加载模型权重时,可以借助网络的 load_state_dict() 方法,搭配 torch.load() 函数来实现对网络参数的读取,从而恢复模型的训练状态和性能表现。

2. 模型保存示例

torch.save(model.state_dict(), path)只保存“参数”(一个纯字典),文件小、加载灵活。

torch.save(model.state_dict(), "best_model.pt")

1. 加载时必须先重新建网络,再把参数填进去:

new_model = MyNet()                      # 重新建图
new_model.load_state_dict(torch.load("best_model.pt"))
new_model.eval()                         # 记得切到推理模式

2. 优点

  • 文件 ≈ 仅参数大小,磁盘占用小
  • 不关心原始类定义,跨代码版本更稳

3. 缺点

        需要手动重建网络结构才能用

torch.save(model, path):把整个模型(结构+参数)序列化为一个 Pickle 对象,一步到位。

torch.save(model, "full_model.pt")

1. 加载极其简单:

model = torch.load("full_model.pt")      # 结构+参数全回来
model.eval()

2. 优点

        一行代码即可复现模型,适合快速分享、断点继续训练

3. 缺点

  • Pickle 会硬编码类定义路径,代码位置/类名一变就加载失败

  • 文件更大(含结构+参数)

选用建议

  • 生产/长期维护 → 用 state_dict(稳妥、小、可迁移)。

  • 临时 checkpoint / 本地快速实验 → 用 完整模型(省事)。

http://www.dtcms.com/a/313624.html

相关文章:

  • 大模型设计
  • 学习方法论
  • 智能化设备维护:开启高效运维新时代
  • 前端异步任务处理总结
  • Maven - 依赖的生命周期详解
  • 服务端技术栈分类总结
  • 模型预估打分对运筹跟踪的影响
  • 数据结构:单向链表的函数创建
  • [硬件电路-141]:模拟电路 - 源电路,信号源与电源,能自己产生确定性波形的电路。
  • 高质量数据集|大模型技术正从根本上改变传统数据工程的工作模式
  • RapidIO/SRIO 入门之什么是SRIO
  • 环绕字符串中的唯一子字符串-动态规划
  • [2025ICCV-目标检测方向]DuET:通过无示例任务算术进行双增量对象检测
  • 1.内核模块
  • C语言基础03——数组——习题
  • 工作笔记-----IAP的相关内容
  • 8大图床高速稳定网站,值得长期选用
  • 【最长公共前缀】
  • DMDRS产品概述和安装部署
  • Kaggle 竞赛入门指南
  • Pygame如何制作小游戏
  • vllm0.8.5:自定义聊天模板qwen_nonthinking.jinja,从根本上避免模型输出<think>标签
  • Docker环境离线安装指南
  • C++与Go的匿名函数编程区别对比
  • SPI入门(基于ESP-IDF-v5.4.1)
  • accept4系统调用及示例
  • ELECTRICAL靶场
  • 检索召回率优化探究三:基于LangChain0.3集成Milvu2.5向量数据库构建的智能问答系统
  • 思途JSP学习 0802(项目完整流程)
  • Fay数字人如何使用GPT-SOVITS进行TTS转换以及遇到的一些问题