热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

tensorflow2.0保存和恢复模型

方法1:只保存模型的权重和偏置这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创

 

方法1:只保存模型的权重和偏置

这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。

tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。

save_weights(filepath,overwrite=True,save_format=None
)

 Arguments:

  • filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.
  • overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
  • save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.

load_weights(filepath,by_name=False
)

 实例1:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# step2 创建模型
def create_model():return tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(512, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10, activation='softmax')])
model = create_model()# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# step4 模型训练 训练一个epochs
model.fit(x=x_train,y=y_train,epochs=1,)# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')# step7 删除模型
del model# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.55%

Restored model, accuracy:96.55%

可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。

 

方法2:直接保存整个模型

这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。

tf.keras.model类中的save方法和load_model方法

save(filepath,overwrite=True,include_optimizer=True,save_format=None
)

Arguments:

  • filepath: String, path to SavedModel or H5 file to save the model.
  • overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
  • include_optimizer: If True, save optimizer's state together.
  • save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).

 

实例2:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# step2 创建模型
def create_model():return tf.keras.models.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(512, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10, activation='softmax')])
model = create_model()# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# step4 模型训练 训练一个epochs
model.fit(x=x_train,y=y_train,epochs=1,)# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'# step7 删除模型
del model # deletes the existing model# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))

train model, accuracy:96.94%

Restored model, accuracy:96.94%

方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型

该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用

 

ref:

https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/Model#save

https://blog.csdn.net/pan15125284/article/details/93605559

https://blog.csdn.net/qq_31456593/article/details/88829202

https://tensorflow.google.cn/beta/guide/keras/saving_and_serializing

https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/keras/callbacks/ModelCheckpoint


推荐阅读
  • 时域|波形_语音处理基于matlab GUI音频数据处理含Matlab源码 1734期
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了语音处理基于matlabGUI音频数据处理含Matlab源码1734期相关的知识,希望对你有一定的参考价值。 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • 开发笔记:图像识别基于主成分分析算法实现人脸二维码识别
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了图像识别基于主成分分析算法实现人脸二维码识别相关的知识,希望对你有一定的参考价值。 ... [详细]
  • keras归一化激活函数dropout
    激活函数:1.softmax函数在多分类中常用的激活函数,是基于逻辑回归的,常用在输出一层,将输出压缩在0~1之间,且保证所有元素和为1,表示输入值属于每个输出值的概率大小2、Si ... [详细]
  • 生成式对抗网络模型综述摘要生成式对抗网络模型(GAN)是基于深度学习的一种强大的生成模型,可以应用于计算机视觉、自然语言处理、半监督学习等重要领域。生成式对抗网络 ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 本文介绍了brain的意思、读音、翻译、用法、发音、词组、同反义词等内容,以及脑新东方在线英语词典的相关信息。还包括了brain的词汇搭配、形容词和名词的用法,以及与brain相关的短语和词组。此外,还介绍了与brain相关的医学术语和智囊团等相关内容。 ... [详细]
  • 本文介绍了设计师伊振华受邀参与沈阳市智慧城市运行管理中心项目的整体设计,并以数字赋能和创新驱动高质量发展的理念,建设了集成、智慧、高效的一体化城市综合管理平台,促进了城市的数字化转型。该中心被称为当代城市的智能心脏,为沈阳市的智慧城市建设做出了重要贡献。 ... [详细]
  • 本文讨论了在Windows 8上安装gvim中插件时出现的错误加载问题。作者将EasyMotion插件放在了正确的位置,但加载时却出现了错误。作者提供了下载链接和之前放置插件的位置,并列出了出现的错误信息。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • Spring特性实现接口多类的动态调用详解
    本文详细介绍了如何使用Spring特性实现接口多类的动态调用。通过对Spring IoC容器的基础类BeanFactory和ApplicationContext的介绍,以及getBeansOfType方法的应用,解决了在实际工作中遇到的接口及多个实现类的问题。同时,文章还提到了SPI使用的不便之处,并介绍了借助ApplicationContext实现需求的方法。阅读本文,你将了解到Spring特性的实现原理和实际应用方式。 ... [详细]
  • ZSI.generate.Wsdl2PythonError: unsupported local simpleType restriction ... [详细]
  • 推荐系统遇上深度学习(十七)详解推荐系统中的常用评测指标
    原创:石晓文小小挖掘机2018-06-18笔者是一个痴迷于挖掘数据中的价值的学习人,希望在平日的工作学习中,挖掘数据的价值, ... [详细]
  • Matlab 中的一些小技巧(2)
    1.Ctrl+D打开子程序  在MATLAB的Editor中,将输入光标放到一个子程序名称中间,然后按Ctrl+D可以打开该子函数的m文件。当然这个子程序要在路径列表中(或在当前工作路径中)。实际上 ... [详细]
author-avatar
吴冲幸福
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有