深度学习篇---模型参数保存
在深度学习模型训练和部署过程中,模型保存是一个关键环节。不同框架在模型保存的实现上既有相似之处,也有各自的特点。下面详细介绍 PyTorch、TensorFlow 和 PaddlePaddle 中模型保存的代码及保存内容:
1. PyTorch
PyTorch 提供了灵活的模型保存方式,主要通过torch.save()
函数实现,可保存模型结构、参数或训练状态。
(1)保存模型参数(推荐)
仅保存模型的参数(权重和偏置),不包含模型结构,文件体积较小。
import torch
import torch.nn as nn# 定义示例模型
class SimpleModel(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)model = SimpleModel()# 保存模型参数(状态字典,state_dict)
torch.save(model.state_dict(), "model_params.pth")
- 保存内容:模型的
state_dict
,是一个字典,键为层名称,值为对应参数的张量。 - 用途:适用于训练中断后恢复训练,或在已知模型结构的情况下加载参数。
(2)保存完整模型
保存整个模型(包括结构和参数),但可能存在兼容性问题(如不同 PyTorch 版本或 Python 环境)。
# 保存完整模型
torch.save(model, "full_model.pth")
- 保存内容:模型的类结构、参数及其他属性(如训练配置)。
- 注意:不推荐用于跨环境部署,可能因类定义变化导致加载失败。
(3)保存训练过程状态(断点续训)
保存模型参数、优化器状态、epoch 等信息,用于中断后继续训练。
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epoch = 10
loss = 0.123# 保存训练状态
checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch,"loss": loss
}
torch.save(checkpoint, "checkpoint.pth")
- 保存内容:模型参数、优化器参数(如动量、学习率)、当前训练轮次、损失值等。
2. TensorFlow(Keras)
TensorFlow(尤其是 Keras 接口)提供了多种模型保存方式,支持 SavedModel 格式(推荐)和 HDF5 格式。
(1)保存完整模型(SavedModel 格式,推荐)
SavedModel 是 TensorFlow 的标准格式,包含模型结构、参数、计算图等,兼容性强。
- 保存内容:
- 模型结构(网络层、输入输出形状);
- 所有参数(权重和偏置);
- 训练配置(优化器、损失函数、 metrics);
- 计算图(用于部署到 TensorFlow Serving、移动端等)。
- 用途:模型部署、跨平台使用(如 TensorFlow Lite、TensorRT)。
(2)保存为 HDF5 格式
保存模型结构和参数到单一文件,适用于简单场景。
# 保存为HDF5格式
model.save("model.h5")
- 保存内容:模型结构(JSON 格式)和参数(二进制),但不包含计算图细节。
- 注意:对复杂模型(如自定义层、控制流)的兼容性较差。
(3)保存权重(仅参数)
仅保存模型参数,需已知模型结构才能加载。
# 保存权重
model.save_weights("model_weights.h5")
- 保存内容:各层的权重张量,不包含模型结构。
(4)训练过程保存(Checkpoint)
通过ModelCheckpoint
回调保存训练过程中的模型状态。
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath="training_checkpoint",save_weights_only=False, # 是否仅保存权重save_best_only=True, # 仅保存性能最好的模型monitor="val_loss" # 监控指标
)# 训练时使用回调
model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint_callback])
- 保存内容:根据配置,可保存完整模型或仅权重,支持按指标(如验证集损失)保存最优模型。
3. PaddlePaddle
PaddlePaddle 的模型保存逻辑与 PyTorch 类似,主要通过paddle.save()
和Model.save()
实现。
(1)保存模型参数(推荐)
仅保存模型参数,需结合模型结构加载。
import paddle
from paddle.nn import Linear# 定义示例模型
class SimpleModel(paddle.nn.Layer):def __init__(self):super().__init__()self.fc = Linear(in_features=10, out_features=2)def forward(self, x):return self.fc(x)model = SimpleModel()# 保存模型参数
paddle.save(model.state_dict(), "model_params.pdparams")
- 保存内容:模型的
state_dict
,键为层名称,值为参数张量。
(2)保存完整模型
保存模型结构和参数,方便直接加载使用。
# 保存完整模型
paddle.Model(model).save("full_model")
- 保存内容:模型结构(
__model__
文件)和参数(*.pdparams
),支持跨环境加载。
(3)保存训练过程状态(断点续训)
保存模型参数、优化器状态、训练轮次等。
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.001)
epoch = 10
loss = 0.123# 保存训练状态
checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch,"loss": loss
}
paddle.save(checkpoint, "checkpoint.pdparams")
- 保存内容:模型参数、优化器参数(如学习率、动量)、训练进度等。
总结
框架 | 保存类型 | 核心函数 / 方法 | 主要保存内容 |
---|---|---|---|
PyTorch | 仅参数 | torch.save(model.state_dict(), ...) | 模型参数(state_dict) |
完整模型 | torch.save(model, ...) | 模型结构 + 参数 | |
训练状态(断点续训) | torch.save(checkpoint_dict, ...) | 模型参数 + 优化器状态 + 训练进度 | |
TensorFlow | 完整模型(推荐) | model.save("saved_model") | 结构 + 参数 + 计算图 + 训练配置 |
HDF5 格式 | model.save("model.h5") | 结构 + 参数(兼容性有限) | |
仅参数 | model.save_weights(...) | 各层权重 | |
训练过程检查点 | ModelCheckpoint 回调 | 按配置保存模型或权重(支持最优模型选择) | |
PaddlePaddle | 仅参数 | paddle.save(model.state_dict(), ...) | 模型参数(state_dict) |
完整模型 | paddle.Model(model).save(...) | 结构 + 参数 | |
训练状态(断点续训) | paddle.save(checkpoint_dict, ...) | 模型参数 + 优化器状态 + 训练进度 |
实际应用中,仅保存参数通常是最灵活和高效的方式(需配合模型结构加载);完整模型适合快速部署但需注意兼容性;训练状态保存则用于中断后恢复训练。