热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Tensorflow学习(一):模型保存与恢复

最近在学习Tensorflow,想将所学通过知乎文章的形式进行总结,以飨读者。首先明确一点,tensorflow保存的是什么?模型保存后产生四个文件,分别是:|--models||

最近在学习Tensorflow,想将所学通过知乎文章的形式进行总结,以飨读者。

首先明确一点,tensorflow保存的是什么?

模型保存后产生四个文件,分别是:

|--models
| |--checkpoint
| |--.meta
| |--.data
| |--.index

其中.meta保存的是图的结构,checkpoint文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表,.data和.index保存的是变量值。

即模型保存的是图的结构和变量值。

一 实例

以下是使用tensorflow实现简单的线性模型:

#生成样本数据
x = np.random.randn(10000,1)
y = 0.03*x+0.8
#定义模型参数
Weights = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='Weights')
bias = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='bias')
xx = tf.placeholder(tf.float32,shape=(None,1),name='xx')
yy = tf.placeholder(tf.float32,shape=(None,1),name='yy')
#线性模型
y_predict = tf.add(Weights*xx,bias,name='preds')
#损失函数
loss = tf.reduce_mean(tf.square(yy-y_predict))
#优化方法
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
#批训练模型
batchsize = 20
samplesize = 100
start = 0
end = 0
with tf.Session() as sess:
init_var = tf.global_variables_initializer()
sess.run(init_var)
print('before training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))
for i in (range(5000)):
# start = (i*batchsize)%100
if end == samplesize:
start = 0
end = np.minimum(start+batchsize,samplesize)
# try:
# end = np.min(start+batchsize,samplesize)
# except:
# print(end)
sess.run(optimizer,feed_dict={xx:x[start:end],yy:y[start:end]})
if (i+1)%1000 == 0:
print(sess.run(loss,feed_dict={xx:x[start:end],yy:y[start:end]}))
start += batchsize
print('after training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))

二 模型保存

通过以下程序可实现保存:

saver = tf.train.Saver()
saver.save(session,dir[,global_step])

save中第一个参数是session,第二个参数是模型保存的位置,第三个参数申明模型每迭代多少步保存一次。

保存一中的模型,并设置每1000步保存一次:

#生成样本数据
x = np.random.randn(10000,1)
y = 0.03*x+0.8
#定义模型参数
Weights = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='Weights')
bias = tf.Variable(tf.random_normal([1],seed=1,stddev=1),name='bias')
xx = tf.placeholder(tf.float32,shape=(None,1),name='xx')
yy = tf.placeholder(tf.float32,shape=(None,1),name='yy')
#线性模型
y_predict = tf.add(Weights*xx,bias,name='preds')
#损失函数
loss = tf.reduce_mean(tf.square(yy-y_predict))
#优化方法
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
#批训练模型
batchsize = 20
samplesize = 100
start = 0
end = 0
with tf.Session() as sess:
init_var = tf.global_variables_initializer()
sess.run(init_var)
saver = tf.train.Saver()
print('before training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))
for i in (range(5000)):
# start = (i*batchsize)%100
if end == samplesize:
start = 0
end = np.minimum(start+batchsize,samplesize)
# try:
# end = np.min(start+batchsize,samplesize)
# except:
# print(end)
sess.run(optimizer,feed_dict={xx:x[start:end],yy:y[start:end]})
#实现每1000步保存一次模型
if (i+1)%1000 == 0:
saver.save(sess,'models\ckp',1000)
print(sess.run(loss,feed_dict={xx:x[start:end],yy:y[start:end]}))
start += batchsize
print('after training, variable is %s,%s'%(sess.run(Weights),sess.run(bias)))

以下代码实现了每1000步保存一次模型

if (i+1)%1000 == 0:
saver.save(sess,'models\ckp',1000)

之所以这样做,是为了防止意外情况下(比如训练时突然断电)下次训练需要从头开始训练。

保存的目录结构如下

|--models
| |--checkpoint
| |--ckp-1000.meta
| |--ckp-1000.data-00000-of-00001
| |--ckp-1000.index

三 模型恢复

首先加载保存的meta文件

saver = tf.train.import_meta_graph(file_name)

恢复参数,依赖于session,dir表示模型保存的目录路径,此时所有张量的值都在session中

saver.restore(session,tf.train.latest_checkpoint(dir))

获取恢复的参数,varname表示恢复的参数名,因此建议所有的参数都加上name属性

graph = sess.graph #sess所打开的图,所有的结构都在这个图上
graph.get_tensor_by_name(varname)

以下给出回归模型的恢复,并利用训练好的模型进行预测:

with tf.Session() as sess:
saver = tf.train.import_meta_graph('models\ckp-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('models'))
graph = tf.get_default_graph()
#恢复传入值
xx = graph.get_tensor_by_name('xx:0')
#计算利用训练好的模型参数计算预测值
preds = graph.get_tensor_by_name('preds:0')
print('predict values:%s' % sess.run(preds,feed_dict={xx:x}))


推荐阅读
author-avatar
灰太狼老婆红太狼_715
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有