当前位置: 首页 > news >正文

深度学习篇---模型参数保存

在深度学习模型训练和部署过程中,模型保存是一个关键环节。不同框架在模型保存的实现上既有相似之处,也有各自的特点。下面详细介绍 PyTorch、TensorFlow 和 PaddlePaddle 中模型保存的代码及保存内容:

1. PyTorch

PyTorch 提供了灵活的模型保存方式,主要通过torch.save()函数实现,可保存模型结构、参数或训练状态。

(1)保存模型参数(推荐)

仅保存模型的参数(权重和偏置),不包含模型结构,文件体积较小。

import torch
import torch.nn as nn# 定义示例模型
class SimpleModel(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)model = SimpleModel()# 保存模型参数(状态字典,state_dict)
torch.save(model.state_dict(), "model_params.pth")
  • 保存内容:模型的state_dict,是一个字典,层名称对应参数的张量
  • 用途:适用于训练中断后恢复训练,或在已知模型结构的情况下加载参数。
(2)保存完整模型

保存整个模型(包括结构和参数),但可能存在兼容性问题(如不同 PyTorch 版本或 Python 环境)。

# 保存完整模型
torch.save(model, "full_model.pth")
  • 保存内容:模型的类结构、参数及其他属性(如训练配置)。
  • 注意:不推荐用于跨环境部署,可能因类定义变化导致加载失败。
(3)保存训练过程状态(断点续训)

保存模型参数、优化器状态、epoch 等信息,用于中断后继续训练。

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epoch = 10
loss = 0.123# 保存训练状态
checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch,"loss": loss
}
torch.save(checkpoint, "checkpoint.pth")
  • 保存内容:模型参数、优化器参数(如动量、学习率)、当前训练轮次、损失值等。

2. TensorFlow(Keras)

TensorFlow(尤其是 Keras 接口)提供了多种模型保存方式,支持 SavedModel 格式(推荐)和 HDF5 格式。

(1)保存完整模型(SavedModel 格式,推荐)

SavedModel 是 TensorFlow 的标准格式,包含模型结构、参数、计算图等,兼容性强。

  • 保存内容
    • 模型结构(网络层、输入输出形状);
    • 所有参数(权重和偏置);
    • 训练配置(优化器、损失函数、 metrics);
    • 计算图(用于部署到 TensorFlow Serving、移动端等)。
  • 用途:模型部署、跨平台使用(如 TensorFlow Lite、TensorRT)。
(2)保存为 HDF5 格式

保存模型结构和参数到单一文件,适用于简单场景。

# 保存为HDF5格式
model.save("model.h5")
  • 保存内容:模型结构(JSON 格式)和参数(二进制),但不包含计算图细节。
  • 注意:对复杂模型(如自定义层、控制流)的兼容性较差。
(3)保存权重(仅参数)

仅保存模型参数,需已知模型结构才能加载。

# 保存权重
model.save_weights("model_weights.h5")
  • 保存内容:各层的权重张量,不包含模型结构。
(4)训练过程保存(Checkpoint)

通过ModelCheckpoint回调保存训练过程中的模型状态。

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath="training_checkpoint",save_weights_only=False,  # 是否仅保存权重save_best_only=True,      # 仅保存性能最好的模型monitor="val_loss"        # 监控指标
)# 训练时使用回调
model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint_callback])
  • 保存内容:根据配置,可保存完整模型或仅权重,支持按指标(如验证集损失)保存最优模型。

3. PaddlePaddle

PaddlePaddle 的模型保存逻辑与 PyTorch 类似,主要通过paddle.save()Model.save()实现。

(1)保存模型参数(推荐)

仅保存模型参数,需结合模型结构加载。

import paddle
from paddle.nn import Linear# 定义示例模型
class SimpleModel(paddle.nn.Layer):def __init__(self):super().__init__()self.fc = Linear(in_features=10, out_features=2)def forward(self, x):return self.fc(x)model = SimpleModel()# 保存模型参数
paddle.save(model.state_dict(), "model_params.pdparams")
  • 保存内容:模型的state_dict,键为层名称,值为参数张量。
(2)保存完整模型

保存模型结构和参数,方便直接加载使用。

# 保存完整模型
paddle.Model(model).save("full_model")
  • 保存内容:模型结构(__model__文件)和参数(*.pdparams),支持跨环境加载。
(3)保存训练过程状态(断点续训)

保存模型参数、优化器状态、训练轮次等。

optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.001)
epoch = 10
loss = 0.123# 保存训练状态
checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch,"loss": loss
}
paddle.save(checkpoint, "checkpoint.pdparams")
  • 保存内容:模型参数、优化器参数(如学习率、动量)、训练进度等。

总结

框架保存类型核心函数 / 方法主要保存内容
PyTorch仅参数torch.save(model.state_dict(), ...)模型参数(state_dict)
完整模型torch.save(model, ...)模型结构 + 参数
训练状态(断点续训)torch.save(checkpoint_dict, ...)模型参数 + 优化器状态 + 训练进度
TensorFlow完整模型(推荐)model.save("saved_model")结构 + 参数 + 计算图 + 训练配置
HDF5 格式model.save("model.h5")结构 + 参数(兼容性有限)
仅参数model.save_weights(...)各层权重
训练过程检查点ModelCheckpoint回调按配置保存模型或权重(支持最优模型选择)
PaddlePaddle仅参数paddle.save(model.state_dict(), ...)模型参数(state_dict)
完整模型paddle.Model(model).save(...)结构 + 参数
训练状态(断点续训)paddle.save(checkpoint_dict, ...)模型参数 + 优化器状态 + 训练进度

实际应用中,仅保存参数通常是最灵活和高效的方式(需配合模型结构加载);完整模型适合快速部署但需注意兼容性;训练状态保存则用于中断后恢复训练。

http://www.dtcms.com/a/354987.html

相关文章:

  • 卷积神经网络实现mnist手写数字集识别案例
  • Apollo-PETRv1演示DEMO操作指南
  • 【Qt】QCryptographicHash 设置密钥(Key)
  • Deeplizard 深度学习课程(四)—— 模型构建
  • jwt原理及Java中实现
  • 海盗王64位dx9客户端修改篇之二
  • 学习Java29天(tcp多发多收)但是无解决客户端启动多个问题
  • ProfiNet 转 Ethernet/IP 柔性产线构建方案:网关技术保护新能源企业现有设备投资
  • LeetCode Hot 100 第7天
  • 第三十天:世界杯队伍团结力问题
  • EF Core 编译模型 / 模型裁剪:冷启动与查询优化
  • QT之双缓冲 (QMutex/QWaitCondition)——读写分离
  • 企业如何管理跨多个系统的主数据?
  • MaxCompute MaxFrame | 分布式Python计算服务MaxFrame(完整操作版)
  • 【Lua】题目小练12
  • 如何实现HTML动态爱心表白效果?
  • 多版本并发控制MVCC
  • 黑马点评|项目日记(day02)
  • C#和Lua相互访问
  • 基于金庸武侠小说人物关系设计的完整 SQL 语句,包括数据库创建、表结构定义和示例数据插入
  • Docker 详解+示例
  • map底层的数据结构是什么,为什么不用AVL树
  • 机器学习回顾(一)
  • 陪诊小程序系统开发:搭建医患之间的温暖桥梁
  • Scrapy 基础介绍
  • 安全运维——系统上线前安全检测:漏洞扫描、系统基线与应用基线的全面解析
  • lwIP MQTT 心跳 Bug 分析与修复
  • 边缘计算(Edge Computing)+ AI:未来智能世界的核心引擎
  • HarmonyOS 组件与页面生命周期:全面解析与实践
  • Paimon——官网阅读:Flink 引擎