热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Tensorflow保存恢复模型及微调

使用tensorflow的过程中,我们常常会用到训练好的模型。我们可以直接使用训练好的模型进行测试或者对训练好的模型做进一步的微调。(微调是指初始化网络参数的时候不再是随机初始化,

使用tensorflow的过程中,我们常常会用到训练好的模型。我们可以直接使用训练好的模型进行测试或者对训练好的模型做进一步的微调。(微调是指初始化网络参数的时候不再是随机初始化,而是使用先前训练好的权重参数进行初始化,在此基础上对网络的全部或者局部参数进行重新训练的过程)。为了实现模型的复用或微调,我将从以下四个方面进行说明:

  • 模型是指什么?
  • 如何保存模型?
  • 如何恢复模型?
  • 如何进行微调?

一、模型是指什么?

tensorflow训练后需要保存的模型主要包含两部分,一是网络图,二是网络图里的参数值。保存的模型文件结构如下(假设每过1000次保存一次):

checkpoint
MyModel-1000.meta
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-2000.meta
MyModel-2000.data-00000-of-00001
MyModel-2000.index
MyModel-3000.meta
MyModel-3000.data-00000-of-00001
MyModel-3000.index
.......

1 checkpoint

checkpoint是一个文本文件,如下所示。其中有model_checkpoint_path和all_model_checkpoint_paths两个属性。model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths则有未被删除的所有tensorflow模型文件的文件名。

model_checkpoint_path: "MyModel-3000"
all_model_checkpoint_paths: "MyModel-1000"
all_model_checkpoint_paths: "MyModel-2000"
all_model_checkpoint_paths: "MyModel-3000"
......

2 .meta文件

.meta 文件用于保存网络结构,且以 protocol buffer 格式进行保存。protocol buffer是Google 公司内部使用的一种轻便高效的数据描述语言。类似于XML能够将结构化数据序列化,protocol buffer也可用于序列化结构化数据,并将其用于数据存储、通信协议等方面。相较于XML,protocol buffer更小、更快、也更简单。

3 .data-00000-of-00001 文件和 .index 文件

在tensorflow 0.11之前,保存的文件结构如下。tensorflow 0.11之后,将ckpt文件拆分为了.data-00000-of-00001 和 .index 两个文件。.ckpt是二进制文件,保存了所有变量的值及变量的名称。拆分后的.data-00000-of-00001 保存的是变量值,.index文件保存的是.data文件中数据和 .meta文件中结构图之间的对应关系(也就是变量的名称)

checkpoint
MyModel.meta
MyModel.ckpt

二、如何保存模型?

tensorflow 提供tf.train.Saver类及tf.train.Saver类下面的save方法共同保存模型。下面分别说明tf.train.Saver类及save方法:

tf.train.Saver(var_list=None, reshape=False, sharded=False, max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False,
saver_def=None, builder=None, defer_build=False, allow_empty=False,
write_version=saver_pb2.SaverDef.V2, pad_step_number=False)
就常用的参数进行说明:
var_list:如果我们不对tf.train.Saver指定任何参数,默认会保存所有变量。如果你只想保存一部分变量,
可以通过将需要保存的变量构造list或者dictionary,赋值给var_list。
max_to_keep:tensorflow默认只会保存最近的5个模型文件,如果你希望保存更多,可以通过max_to_keep来指定
keep_checkpoint_every_n_hours:设置每隔几小时保存一次模型
save(sess,save_path,global_step=None,latest_filename=None,meta_graph_suffix="meta",
write_meta_graph=True, write_state=True)
就常用的参数进行说明:
sess:在tensorflow中,只有开启session时数据才会流动,因此保存模型的时候必须传入session。
save_path: 模型保存的路径及模型名称。
global_step:定义每隔多少步保存一次模型,每次会在保存的模型名称后面加上global_step的值作为后缀
write_meta_graph:布尔值,True表示每次都保存图,False表示不保存图(由于图是不变的,没必要每次都去保存)
注意:保存变量的时候必须在session中;保存的变量必须已经初始化;

1.简单示例

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
w3 = tf.Variable(tf.random_normal(shape=[1]), name='w3')
saver = tf.train.Saver()#未指定任何参数,默认保存所有变量。等价于saver = tf.train.Saver(tf.trainable_variables())
save_path = './checkpoint_dir/MyModel'#定义模型保存的路径./checkpoint_dir/及模型名称MyModel
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, save_path)

执行后,在checkpoint_dir目录下创建模型文件如下:

checkpoint
MyModel.data-00000-of-00001
MyModel.index
MyModel.meta

2.经典示例

import tensorflow as tf
from six.moves import xrange
import os
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w11')#变量w1在内存中的名字是w11;恢复变量时应该与name的名字保持一致
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w22')
w3 = tf.Variable(tf.random_normal(shape=[5]), name='w33')
#保存一部分变量[w1,w2];只保存最近的5个模型文件;每2小时保存一次模型
saver = tf.train.Saver([w1, w2],max_to_keep=5, keep_checkpoint_every_n_hours=2)
save_path = './checkpoint_dir/MyModel'#定义模型保存的路径./checkpoint_dir/及模型名称MyModel
# Launch the graph and train, saving the model every 1,000 steps.
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in xrange(100):
if step % 10 == 0:
# 每隔step=10步保存一次模型( keep_checkpoint_every_n_hours与global_step可同时使用,表示'与',通常任选一个就够了);
#每次会在保存的模型名称后面加上global_step的值作为后缀
# write_meta_graph=False表示不保存图
saver.save(sess, save_path, global_step=step, write_meta_graph=False)
# 如果模型文件中没有保存网络图,则使用如下语句保存一张网络图(由于网络图不变,只保存一次就行)
if not os.path.exists('./checkpoint_dir/MyModel.meta'):
# saver.export_meta_graph(filename=None, collection_list=None,as_text=False,export_scope=None,clear_devices=False)
# saver.export_meta_graph()仅仅保存网络图;参数filename表示网络图保存的路径即网络图名称
saver.export_meta_graph('./checkpoint_dir/MyModel.meta')#定义网络图保存的路径./checkpoint_dir/及网络图名称MyModel.meta
#注意:tf.train.export_meta_graph()等价于tf.train.Saver.export_meta_graph()

执行后,在checkpoint_dir目录下创建模型文件如下:

checkpoint
MyModel.meta
MyModel-50.data-00000-of-00001
MyModel-50.index
MyModel-60.data-00000-of-00001
MyModel-60.index
MyModel-70.data-00000-of-00001
MyModel-70.index
MyModel-80.data-00000-of-00001
MyModel-80.index
MyModel-90.data-00000-of-00001
MyModel-90.index

三、如何恢复模型?

tensorflow保存模型时将网络图和网络图里的参数值分开保存。因此,在恢复模型时,也要分为2步:构造网络图和加载参数。

1 构造网络图

构造网络图可以手动创建(需要创建一个跟保存的模型一模一样的网络图)

也可以从meta文件里加载graph进行创建,如下:

#首先恢复graph
saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel.meta')

2 恢复参数有两种方式,如下:

with tf.Session() as sess:
#恢复最新保存的权重
saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))
#指定一个权重恢复
saver.restore(sess, './checkpoint_dir/MyModel-50')#注意不要加文件后缀名。若权重保存为.ckpt则需要加上后缀

四、如何进行微调?

上面叙述了如何恢复模型,那么,对于恢复出来的模型应该如何使用呢?这里以tensorflow官网给出的vgg为例进行说明。下载地址

恢复出来的模型有四种用途:

  • 查看模型参数
  • 直接使用原始模型进行测试
  • 扩展原始模型(直接使用扩展后的网络进行测试,扩展后需要重新训练的情况见微调部分)
  • 微调:使用训练好的权重参数进行初始化,在此基础上对网络的全部或局部参数进行重新训练

1.查看模型参数

import tensorflow as tf
import vgg
# build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_classes=1000)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './vgg_16.ckpt') # 权重保存为.ckpt则需要加上后缀
""" 查看恢复的模型参数 tf.trainable_variables()查看的是所有可训练的变量; tf.global_variables()获得的与tf.trainable_variables()类似,只是多了一些非trainable的变量,比如定义时指定为trainable=False的变量; sess.graph.get_operations()则可以获得几乎所有的operations相关的tensor """
tvs = [v for v in tf.trainable_variables()]
print('获得所有可训练变量的权重:')
for v in tvs:
print(v.name)
print(sess.run(v))

gv = [v for v in tf.global_variables()]
print('获得所有变量:')
for v in gv:
print(v.name, '\n')

# sess.graph.get_operations()可以换为tf.get_default_graph().get_operations()
ops = [o for o in sess.graph.get_operations()]
print('获得所有operations相关的tensor:')
for o in ops:
print(o.name, '\n')

2.直接使用原始模型进行测试

import tensorflow as tf
import vgg
import numpy as np
import cv2
image = cv2.imread('./cat.18.jpg')
print(image.shape)
res = cv2.resize(image, (224,224))
res_image = np.expand_dims(res, 0)
print(res_image.shape, type(res_image))
#build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_classes=1000)
print(end_points)
saver = tf.train.Saver()
with tf.Session() as sess:
#恢复权重
saver.restore(sess, './vgg_16.ckpt')#权重保存为.ckpt则需要加上后缀

# Get input and output tensors
# 需要特别注意,get_tensor_by_name后面传入的参数,如果没有重复,需要在后面加上“:0”
# sess.graph等价于tf.get_default_graph()
input = sess.graph.get_tensor_by_name('inputs:0')
output = sess.graph.get_tensor_by_name('vgg_16/fc8/squeezed:0')

# Run forward pass to calculate pred
#使用不同的数据运行相同的网络,只需将新数据通过feed_dict传递到网络即可。
pred = sess.run(output, feed_dict={input:res_image})
#得到使用vgg网络对输入图片的分类结果
print(np.argmax(pred, 1))

3.扩展原始模型

import tensorflow as tf
import vgg
import numpy as np
import cv2
image = cv2.imread('./cat.18.jpg')
print(image.shape)
res = cv2.resize(image, (224, 224))
res_image = np.expand_dims(res, 0)
print(res_image.shape, type(res_image))
# build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_classes=1000)
print(end_points)
saver = tf.train.Saver()
with tf.Session() as sess:
# 恢复权重
saver.restore(sess, './vgg_16.ckpt') # 权重保存为.ckpt则需要加上后缀

# 明确的网络的输入输出,通过get_tensor_by_name()获取变量
input = sess.graph.get_tensor_by_name('inputs:0')
output = sess.graph.get_tensor_by_name('vgg_16/fc8/squeezed:0')

# add more operations to the graph
# 这里只是简单示例,也可以加上新的网络层。
pred = tf.argmax(output, 1)

# 使用不同的数据运行扩展后的网络(这里扩展后的网络不涉及变量,可以直接使用扩展后的网络进行测试)
pred = sess.run(pred, feed_dict={input: res_image})
print(pred)

4.微调

变量ensorflow as tf
import vgg
import numpy as np
import cv2
from skimage import io
import os
# -----------------------------------------准备数据--------------------------------------
#这里以单张图片作为示例,简单说明原理
image = cv2.imread('./cat.18.jpg')
print(image.shape)
res_image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_CUBIC)#vgg_16有全连接层,需要固定输入尺寸
print(res_image.shape)
res_image = np.expand_dims(res_image, axis=0)#网络输入为四维[batch_size, height, width, channels]
print(res_image.shape)
labels = [[1,0]]#标签
# -----------------------------------------恢复图------------------------------------------
#恢复图的方式有很多,这里采用手动构造一个跟保存权重时一样的graph
graph = tf.get_default_graph()
input = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
y_ = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='labels')
# net=[batch, 2]其中2表示二分类,注意官网给出的vgg_16最终的输出没有经过softmax层
net, end_points = vgg.vgg_16(input, num_classes=2) # 保存的权重模型针对的num_classes=1000,这里改为num_classes=2,因此最后一层需要重新训练
print(net, end_points) # net是网络的输出;end_points是所有变量的集合
#add more operations to the graph
y = tf.nn.softmax(net) # 输出0-1之间的概率
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
output_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='vgg_16/fc8') # 注意这里的scope是定义graph时 name_scope的名字,不要加:0
print(output_vars)
# loss只作用在var_list列表中的变量,也就是说只训练var_list中的变量,其余变量保持不变。若不指定var_list,则默认重新训练所有变量
train_op = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy,var_list=output_vars)
# ----------------------------------------恢复权重------------------------------------------
var = tf.global_variables() # 获取所有变量
print(var)
# var_to_restore = [val for val in var if 'conv1' in val.name or 'conv2' in val.name]#保留变量中含有conv1、conv2的变量
var_to_restore = [val for val in var if 'fc8' not in val.name] # 保留变量名中不含有fc8的变量
print(var_to_restore)
saver = tf.train.Saver(var_to_restore) # 恢复var_to_restore列表中的变量(最后一层变量fc8不恢复)
with tf.Session() as sess:
# restore恢复变量值也是变量初始化的一种方式,对于没有restore的变量需要单独初始化
# 注意如果使用全局初始化,则应在全局初始化后再调用saver.restore()。相当于先通过全局初始化赋值,再通过restore重新赋值。
saver.restore(sess, './vgg_16.ckpt') # 权重保存为.ckpt则需要加上后缀
var_to_init = [val for val in var if 'fc8' in val.name] # 保留变量名中含有fc8的变量
# tf.variable_initializers(tf.global_variables())等价于tf.global_variables_initializer()
sess.run(tf.variables_initializer(var_to_init)) # 没有restore的变量需要单独初始化
# sess.run(tf.global_variables_initializer())
# 用w1,w8测试权重恢复成功没有.正确的情况应该是:w1的值不变,w8的值随机
w1 = sess.graph.get_tensor_by_name('vgg_16/conv1/conv1_1/weights:0')
print(sess.run(w1, feed_dict={input: res_image}))
w8 = sess.graph.get_tensor_by_name('vgg_16/fc8/weights:0')
print('w8', sess.run(w8, feed_dict={input: res_image}))

sess.run(train_op, feed_dict={input:res_image, y_:labels})

五、补充

1 .pb格式的文件

上面提到对于恢复的模型可以直接用来进行测试。对于不再需要改动的模型,我们可以将其保存为.pb格式的文件。

为什么要生成pb文件呢?简单来说就是直接通过tf.saver保存的模型文件其参数和图是分开的。这种形式方便对程序进行微小的改动。但是对于训练好,以后不再需要改动的模型这种形式就不是很必要了。

pb文件就是将变量的值固定下来,直接“烧”到图里面。这个时候只需用户提供一个输入,我们就可以通过模型得到一个输出给用户。pb文件一方面可提供给用户做离线的预测;另一方面,对于线上的模型,一般是通过C++或者C语言编写的程序进行调用。所以模型最终都是写成pb格式的文件。

2 .npy格式的文件

tensorflow保存的模型文件只能在tensorflow框架下使用,不利于将模型权重导入到其他框架使用,同时保存的模型文件无法直接查看。因此经常会考虑转换为.npy格式。.npy文件里的权重值是以数组的形式保存着的,方便查看。

参考:

A quick complete tutorial to save and restore Tensorflow models – CV-Tricks.com

月夜 – 分享网络知识 · 享受快乐生活


推荐阅读
author-avatar
赖雨蓉744_128
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有