Keras模型保存、加载介绍
目录
- 前言:
- 保存格式
- 保存模型、参数和加载模型
- 示例
- 总结
前言:
在TensorFlow中,保存和加载模型是机器学习工作流程中的重要步骤。这不仅有助于持久化训练好的模型以便后续使用,还可以实现模型的版本控制、部署和服务。
保存格式
TensorFlow 提供了多种方式来保存和读取模型,主要分为两种格式:SavedModel和Keras的HDF5格式。
使用SavedModel格式
SavedModel是TensorFlow推荐的保存模型的方式。它保存整个模型,包括权重、架构、优化器状态等,并且支持 TensorFlow Serving 等工具。SavedModel 可以在不同平台上使用,并且可以恢复到任何语言的 TensorFlow API 中。
使用 HDF5 格式(Keras 模型)
HDF5 是一种二进制文件格式,适合保存大型数组数据,如神经网络的权重。Keras提供了简单的方法来保存和加载HDF5格式的模型。注意,HDF5文件只保存模型的架构和权重,不保存优化器的状态和其他配置。
保存模型、参数和加载模型
keras 保存成hdf5文件, 1,保存模型和参数;2, 只保存参数
保存模型和参数
##save
callback ModelCheckpoint
只保存参数
##save_weights
callback ModelCheckpoint save_weights_only=True
示例
##导包
from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt## 加载数据
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]##print(x_valid.shape, y_valid.shape)
##print(x_train.shape, y_train.shape)
##print(x_test.shape, y_test.shape)## 标准化
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(55000, -1))
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(5000, -1))x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(10000, -1))## 创建模型model = keras.models.Sequential([keras.layers.Dense(512, activation='relu', input_shape=(784,)),keras.layers.Dense(256, activation='relu'),keras.layers.Dense(128, activation='relu'),keras.layers.Dense(10, activation='softmax'),
])model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])## 训练import oslogdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
callbacks = [keras.callbacks.TensorBoard(logdir),keras.callbacks.ModelCheckpoint(output_model_file, save_best_only=True,save_weights_only=True),keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3)
]
history = model.fit(x_train_scaled, y_train, epochs=10, validation_data=(x_valid_scaled, y_valid),callbacks=callbacks)##保存模型
output_model_file2 = os.path.join(logdir, 'fashion_mnist_model.h5')
# 另一种保存参数的方法
##model.save_weights(os.path.join(logdir, 'fashin_mnist_weights_2.h5'))
model.save(output_model_file2)print(model.evaluate(x_valid_scaled, y_valid))##加载模型
model2 = keras.models.load_model(output_model_file2)print(model2.evaluate(x_valid_scaled, y_valid))
结果如下:
[0.3386252820491791, 0.8938000202178955]
[0.3386252820491791, 0.8938000202178955]
总结
SavedModel:推荐用于生产环境,因为它保存了完整的模型信息,并且具有良好的跨平台兼容性。
HDF5:适用于简单的模型保存和加载需求,特别是当你需要与旧版本的 TensorFlow 或其他库兼容时。