热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

tensorflow2.0保存和恢复模型

方法1:只保存模型的权重和偏置这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创

 

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(filepath,overwrite=True,save_format=None
)

 Arguments:

  • filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.
  • overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
  • save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(filepath,by_name=False
)

 实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# step2 创建模型
def create_model():return tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(512, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10, activation='softmax')])
model = create_model()# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# step4 模型训练 训练一个epochs
model.fit(x=x_train,y=y_train,epochs=1,)# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')# step7 删除模型
del model# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

 

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(filepath,overwrite=True,include_optimizer=True,save_format=None
)

Arguments:

  • filepath: String, path to SavedModel or H5 file to save the model.
  • overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
  • include_optimizer: If True, save optimizer's state together.
  • save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

 

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# step2 创建模型
def create_model():return tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(512, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10, activation='softmax')])
model = create_model()# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# step4 模型训练 训练一个epochs
model.fit(x=x_train,y=y_train,epochs=1,)# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'# step7 删除模型
del model # deletes the existing model# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

 

ref:

https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/Model#save

https://blog.csdn.net/pan15125284/article/details/93605559

https://blog.csdn.net/qq_31456593/article/details/88829202

https://tensorflow.google.cn/beta/guide/keras/saving_and_serializing

https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/callbacks/ModelCheckpoint


推荐阅读
  • 探讨了生成时间敏感的一次性伪随机密码的方法,旨在通过加入时间因素防止重放攻击。 ... [详细]
  • Activity跳转动画 无缝衔接
    Activity跳转动画 无缝衔接 ... [详细]
  • ECharts图表绘制函数集
    本文档提供了使用ECharts库创建柱状图、饼图和双折线图的JavaScript函数。每个函数都详细列出了参数说明,并通过示例展示了如何调用这些函数以生成不同类型的图表。 ... [详细]
  • 本文介绍了一种算法,用于在一个给定的二叉树中找到一个节点,该节点的子树包含最大数量的值小于该节点的节点。如果存在多个符合条件的节点,可以选择任意一个。 ... [详细]
  • MVC框架下使用DataGrid实现时间筛选与枚举填充
    本文介绍如何在ASP.NET MVC项目中利用DataGrid组件增强搜索功能,具体包括使用jQuery UI的DatePicker插件添加时间筛选条件,并通过枚举数据填充下拉列表。 ... [详细]
  • 本文章利用header()函数来实现页面跳,我们介绍到404,302,301等状态跳转哦,下面有很多的状态自定的函数有需要的同学可以测试一下。heade ... [详细]
  • 字符、字符串和文本的处理之Char类型
    .NetFramework中处理字符和字符串的主要有以下这么几个类:(1)、System.Char类一基础字符串处理类(2)、System.String类一处理不可变的字符串(一经 ... [详细]
  • 本文主要解决了在编译CM10.2时出现的关于Samsung Exynos 4 HDMI HAL库中SecHdmiV4L2Utils.cpp文件的编译错误。 ... [详细]
  • 拖拉切割直线 ... [详细]
  • 本文详细介绍如何在Spring Boot项目中集成和使用JPA,涵盖JPA的基本概念、Spring Data JPA的功能以及具体的操作步骤,帮助开发者快速掌握这一强大的持久化技术。 ... [详细]
  • 本文探讨了在Qt框架下实现TCP多线程服务器端的方法,解决了一个常见的问题:服务器端仅能与最后一个连接的客户端通信。通过继承QThread类并利用socketDescriptor标识符,实现了多个客户端与服务器端的同时通信。 ... [详细]
  • 设计模式系列-原型模式
    一、上篇回顾上篇创建者模式中,我们主要讲述了创建者的几类实现方案,和创建者模式的应用的场景和特点,创建者模式适合创建复杂的对象,并且这些对象的每个组成部分的详细创建步骤可以是动态的变化的,但 ... [详细]
  • 抽象工厂模式 c++
    抽象工厂模式包含如下角色:AbstractFactory:抽象工厂ConcreteFactory:具体工厂AbstractProduct:抽象产品Product:具体产品https ... [详细]
  • 本文详细介绍了如何通过配置 Chrome 和 VS Code 来实现对 Vue 项目的高效调试。步骤包括启用 Chrome 的远程调试功能、安装 VS Code 插件以及正确配置 launch.json 文件。 ... [详细]
  • 本文探讨了Codeforces 580C题目——Kefa与公园的问题,深入分析了如何在给定条件下帮助Kefa找到合适的餐厅。 ... [详细]
author-avatar
吴冲幸福
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有