深入理解 TensorFlow 的模型保存与加载机制(SavedModel vs H5)
深入理解 TensorFlow 的模型保存与加载机制(SavedModel vs H5)
在使用 TensorFlow 进行模型训练后,模型的保存与加载是部署、复用和迁移学习的重要环节。TensorFlow 提供了两种主要的保存格式:
SavedModel
和HDF5 (.h5)
。本篇文章将详细对比它们的异同,并通过代码实战帮你掌握使用方法。
📦 一、为什么需要保存模型?
在训练完一个神经网络模型后,通常需要将模型持久化用于:
- 模型部署(线上服务)
- 迁移学习
- 断点训练(Resume Training)
- 团队共享模型
TensorFlow 支持以下两种主流保存方式:
格式 | 文件扩展名 | 支持特性 |
---|---|---|
SavedModel | 无扩展名(文件夹) | ✅ 推荐格式,包含完整计算图,支持多语言部署(TF Serving、TensorFlow Lite 等) |
HDF5 | .h5 | ✅ Keras 风格保存,适合快速保存和加载模型 |
📂 二、SavedModel 格式详解
✅ 特点:
- 官方推荐格式。
- 保存了计算图、变量值、优化器状态等全部信息。
- 适用于 TensorFlow Serving、TensorFlow Lite、TF.js 等部署场景。
- 支持自定义对象(如自定义层、自定义训练逻辑)。
🛠 保存模型:
model.save("my_model") # 保存为SavedModel格式(默认)
会生成一个目录:
my_model/
├── assets/
├── variables/
│ ├── variables.data-00000-of-00001
│ └── variables.index
└── saved_model.pb
📥 加载模型:
loaded_model = tf.keras.models.load_model("my_model")
可以继续训练或直接用于预测。
💾 三、HDF5(.h5)格式详解
✅ 特点:
- 更接近早期 Keras 用户的使用习惯。
- 使用一个单一的
.h5
文件保存全部信息(结构、权重、优化器状态)。 - 不兼容 TensorFlow Serving。
🛠 保存模型:
model.save("my_model.h5") # 显式指定保存为HDF5格式
📥 加载模型:
loaded_model = tf.keras.models.load_model("my_model.h5")
⚠ 注意:如使用自定义层或自定义训练函数,加载时需使用
custom_objects
参数指定。
🔄 四、对比:SavedModel vs H5
对比项 | SavedModel | HDF5 (.h5) |
---|---|---|
文件形式 | 文件夹 | 单一文件 |
保存信息 | 结构 + 权重 + 优化器状态 + 计算图 | 同上(不含完整计算图) |
多语言部署 | ✅ 支持 | ❌ 不支持 |
TensorFlow Serving | ✅ 支持 | ❌ 不支持 |
TensorFlow Lite 支持 | ✅ 支持 | ❌ 不支持 |
自定义训练逻辑支持 | ✅ 更好 | ✅ 有限支持 |
文件大小 | 稍大 | 相对较小 |
🧪 五、实战代码对比
以下是一个完整的模型保存与加载实战代码:
import tensorflow as tf
from tensorflow.keras import layers, models# 构建简单模型
model = models.Sequential([layers.Dense(64, activation='relu', input_shape=(100,)),layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')# 保存为 SavedModel
model.save("model_savedmodel")# 保存为 HDF5 格式
model.save("model.h5")# 加载 SavedModel
model1 = tf.keras.models.load_model("model_savedmodel")# 加载 HDF5
model2 = tf.keras.models.load_model("model.h5")
🔐 六、进阶话题:只保存权重 vs 保存结构
1. 只保存权重
model.save_weights("weights.h5")
加载:
model = create_model() # 需先定义好模型结构
model.load_weights("weights.h5")
2. 保存结构(不含权重)
# 保存JSON格式结构
json_str = model.to_json()
加载结构:
model = tf.keras.models.model_from_json(json_str)
✅ 七、结语:选择哪种格式?
- 如果你是部署服务或计划使用 TensorFlow Serving、TensorFlow Lite:推荐 SavedModel。
- 如果你是快速实验、迁移学习或保存简单模型:HDF5 更方便。
- 如果只是保存参数,用于 Resume Training:
save_weights()
即可。
📌 小贴士
-
model.save()
不指定扩展名时默认保存为 SavedModel。 -
加载模型时也可以查看其结构和权重是否正确:
model.summary()