当前位置: 首页 > news >正文

Tensorflow2保存和加载模型

1、model.save() and model.load()

此种方法可保存模型的结构、参数等内容。加载模型后无需设置即可使用!

保存模型:

model.save('my_model.h5')

加载模型:

# 加载整个模型
loaded_model = tf.keras.models.load_model('my_model.h5')

注意,创建的模型不能使用自定义的loss函数等方法,否则导入时会出错!

示例:

model_file = "data/model/multi_labels_model.h5"    # 模型文件路径
def model_handle(x_train, y_train):if os.path.exists(model_file):print("---load the model---")model = tf.keras.models.load_model(model_file) # 导入已存在的模型else:# 模型构建model = tf.keras.Sequential([tf.keras.layers.LSTM(128),tf.keras.layers.Dense(class_num, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())])# 编译模型,不能使用自定义函数方法,否则导入模型会有问题model.compile(loss="BinaryCrossentropy", optimizer='adam', metrics=['accuracy'])history = model.fit(x_train, y_train, epochs=epoch_num, batch_size=1, verbose=1, callbacks=[PrintPredictionsCallback(x_train, y_train)])model.summary()model.save(model_file)return model

2、model.save_weight() and model.load_weight()

此方法只保存和加载模型的权重。

保存权重:

# 只保存权重
model.save_weights('my_model_weights.h5')

加载权重:

# 创建一个新的模型实例(确保架构与原始模型相同)
new_model = tf.keras.models.Sequential([tf.keras.layers.Dense(10, activation='relu', input_shape=(32,)),tf.keras.layers.Dense(1)
])
# new_model.build(input_shape=x_train.shape) # 如果模型创建时没有规定input_shape,需要创建
# 加载权重到新模型
new_model.load_weights('my_model_weights.h5')

此方法的模型可以使用自定义的函数方法。

注意:以H5格式加载子类模型的参数时,需要提前建立模型,规定输入网络的shape,否则会报错!

ValueError: Unable to load weights saved in HDF5 format into a subclassed Model which has not created its variables yet. Call the Model first, then load the weights.

示例:

def model_handle(x_train, y_train):# 模型构建,多分类的激活函数使用sigmoid 或 softmaxmodel = tf.keras.Sequential([tf.keras.layers.LSTM(128),tf.keras.layers.Dense(class_num, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())])if os.path.exists(model_file):print("-----load model weights-----")model.build(input_shape=x_train.shape)  # 以H5格式加载子类模型的参数时,需要提前建立模型,规定输入网络的shape,否则会报错model.load_weights(model_file)else:# 编译模型,使用自定义loss函数model.compile(loss=custom_loss, optimizer='adam', metrics=['accuracy'])# model.compile(loss="BinaryCrossentropy", optimizer='adam', metrics=['accuracy'])history = model.fit(x_train, y_train, epochs=epoch_num, batch_size=1, verbose=1, callbacks=[PrintPredictionsCallback(x_train, y_train)])model.summary()model.save_weights(model_file)return model

3、model.checkpoint

主要是用于模型的断点续训。用法参考如下:

checkpoint_save_path = "./checkpoint/my_checkpoint.ckpt"if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True,monitor='val_loss')history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])model.summary()

http://www.dtcms.com/a/190994.html

相关文章:

  • 【Redis】缓存穿透、缓存雪崩、缓存击穿
  • Java 异常处理之 BufferUnderflowException(BufferUnderflowException 概述、常见发生场景、避免策略)
  • C 语言学习笔记(8)
  • 因果推断 | 用SHAP分值等价因果效应值进行反事实推理
  • 【Linux】掌握 setsid:让进程脱离终端独立运行
  • 东芝新四款产品“TB67Z830SFTG、TB67Z830HFTG、TB67Z850SFTG、TB67Z850HFTG系列三相栅极驱动器ic三相栅极驱动器IC
  • 软件测试--入门
  • 【Linux】Ext系列文件系统
  • 鸿蒙-5.1.0-release构建编译环境
  • Oracle中的select1条、几条、指定范围的语句
  • 每日算法-250514
  • 【golang】网络数据包捕获库 gopacket
  • 嵌入式系统中WAV音频文件格式详解与处理实践
  • 【CustomPagination:基于Vue 3与Element Plus的高效二次封装分页器】
  • Lightpanda开源浏览器:专为 AI 和自动化而设计的无界面浏览器
  • 安卓基础(Bitmap)
  • scons user 3.1.2
  • C#强类型枚举的入门理解
  • C++【STL】(2)string
  • 4级流程控制
  • 复现:DemoGen 用于数据高效视觉运动策略学习的 合成演示生成 (RSS) 2025
  • Docker 常见问题及其解决方案
  • nginx报错-[emerg] getpwnam(“nginx“) failed in /etc/nginx/nginx.conf:2
  • FastAPI + OpenAI 模型 的 GitHub 项目结构模板
  • 未来软件开发趋势与挑战
  • Python+Selenium爬虫:豆瓣登录反反爬策略解析
  • C#调用C++dll 过程记录
  • 【VS】VS2019中使用rdlc报表,生成之前修改XML
  • 【每天一个知识点】模型轻量化(Model Compression and Acceleration)技术
  • 解释 RESTful API