深度学习-读写模型网络文件
模型网络文件是深度学习模型的存储形式,保存了模型的架构、参数等信息。
读写模型网络文件是深度学习流程中的关键环节,方便模型的训练、测试、部署与共享。
1. 主流框架读写方法
(一)TensorFlow
保存模型
可以使用
tf.saved_model.save
方法保存整个模型,包括架构、参数、编译信息等。例如:model.save('model_dir', save_format='tf')
,将模型保存在文件夹 'model_dir' 中。
加载模型
使用
tf.keras.models.load_model
加载保存的模型。如:loaded_model = tf.keras.models.load_model('model_dir')
,即可加载之前保存的模型进行预测、继续训练等操作。
(二)PyTorch
使用 torch.save 和 torch.load 来保存和加载 张量。
保存模型
通常有两种方式:一种是保存整个模型对象,使用
torch.save(model, 'model.pth')
,将模型结构和参数都保存下来。另一种是仅保存模型的参数状态字典,即torch.save(model.state_dict(), 'model_state_dict.pth')
,这种方式更常见,因为当模型架构修改时,只要能正确加载参数,就无需重新训练整个模型。
加载模型
对于保存整个模型的情况,直接使用
model = torch.load('model.pth')
。对于仅保存参数的情况,先定义好模型架构,再用model.load_state_dict(torch.load('model_state_dict.pth'))
加载参数,使模型具备相应的能力。
对于深度学习模型而言,通常只需保存其权重参数即可满足需求。在 PyTorch 框架中,可以使用 torch.save() 函数来保存网络的 state_dict 参数,这是保存模型权重的一种高效方式。
而在加载模型权重时,可以借助网络的 load_state_dict() 方法,搭配 torch.load() 函数来实现对网络参数的读取,从而恢复模型的训练状态和性能表现。
2. 模型保存示例
torch.save(model.state_dict(), path)
只保存“参数”(一个纯字典),文件小、加载灵活。
torch.save(model.state_dict(), "best_model.pt")
1. 加载时必须先重新建网络,再把参数填进去:
new_model = MyNet() # 重新建图
new_model.load_state_dict(torch.load("best_model.pt"))
new_model.eval() # 记得切到推理模式
2. 优点
- 文件 ≈ 仅参数大小,磁盘占用小
- 不关心原始类定义,跨代码版本更稳
3. 缺点
需要手动重建网络结构才能用
torch.save(model, path)
:把整个模型(结构+参数)序列化为一个 Pickle 对象,一步到位。
torch.save(model, "full_model.pt")
1. 加载极其简单:
model = torch.load("full_model.pt") # 结构+参数全回来
model.eval()
2. 优点
一行代码即可复现模型,适合快速分享、断点继续训练
3. 缺点
Pickle 会硬编码类定义路径,代码位置/类名一变就加载失败
文件更大(含结构+参数)
选用建议
生产/长期维护 → 用 state_dict(稳妥、小、可迁移)。
临时 checkpoint / 本地快速实验 → 用 完整模型(省事)。