热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TensorflowFinetune方法

最近搞毕业论文,想直接在pretrain的模型上进行finetune,使用的框架是tensorflow和keras。所以搜索了下,发现keras的finetune方法很简单(后面简

最近搞毕业论文,想直接在pretrain的模型上进行finetune,使用的框架是tensorflow和keras。所以搜索了下,发现keras的finetune方法很简单(后面简单介绍),然而tensorflow的官网也是看得我乱糟糟,google出来的方法也没有合适的,一是找不到model_weights二是没有说明后续添加和训练新层的方法。这里自己折腾了一天,含泪总结下,留下个读书笔记,也希望能帮助到有同样需求的人。

————————————————华丽分割线——————————————————–

预训练权重的作用:

  • 预测
  • 特征提取
  • 微调

Finetune过程:

  • 构建图结构,截取目标张量,添加新层
  • 加载目标张量权重
  • 训练新层
  • 全局微调

1. Keras finetune

Keras的applications模块中就提供了带有预训练权重的深度学习模型。

该模块会根据参数设置,自动检查本地的~/.keras/models/目录下是否含有所需要的权重,没有时会自动下载,在notebook上下载会占用notebook线程资源,不太方便,因此也可以手动wget。

keras应用模块applications,以MobileNet为例说明:

#-----------------------------------构建模型------------------------------------
from keras.applications.mobilenet import MobileNet
from keras.layers import Input, Reshape, AvgPool2D,\
Dropout, Conv2D, Softmax, BatchNormalization, Activation
from keras import Model
## 加载预训练权重,输入大小可以设定,include_top表示是否包括顶层的全连接层
base_model = MobileNet(input_shape=(128,128,3), include_top=False)
## 添加新层,get_layer方法可以根据层名返回该层,output用于返回该层的输出张量tensor
with tf.name_scope("output"):
x = base_model.get_layer("conv_dw_6_relu").output
x = Conv2D(256,kernel_size=(3,3))(x)
x = Activation("relu")(x)
x = AvgPool2D(pool_size=(5,5))(x)
x = Dropout(rate=0.5)(x)
x = Conv2D(10,kernel_size=(1,1),)(x)
predictions = Reshape((10,))(x)
## finetune模型
model = Model(inputs=base_model.input, outputs=predictions)
#-------------------------------------训练新层-----------------------------------
## 冻结原始层位,在编译后生效
for layer in base_model.layers:
layer.trainable = False
## 设定优化方法,并编译
sgd = keras.optimizers.SGD(lr=0.01)
model.compile(optimizer=sgd,loss="categorical_crossentropy")
‘’‘可选记录模型训练过程数据写入tensorboard
callback = [keras.callbacks.ModelCheckpoint(filepath="./vibration_keras/checkpoints",monitor="val_loss"),
keras.callbacks.TensorBoard(log_dir="./vibration_keras/logs",histogram_freq=1,write_grads=True)]
’‘’
## 训练
model.fit(...)
#--------------------------------------全局微调-----------------------------------
## 设定各层是否用于训练,编译后生效
for layer in model.layers[:80]:
layer.trainable = False
for layer in model.layers[80:]:
layer.trainable = True
# 设定优化方法,并编译
sgd = keras.optimizer.SGD(lr=0.0001)
model.compile(optimizer=sgd, loss="categorical_crossentropy")
## 训练
model.fit(...)

获取各层名称的方法:

《Tensorflow Finetune方法》
《Tensorflow Finetune方法》 获取各层名称

2. Tensorflow finetune

tensorflow的finetune方法有3种:

  • 利用tf-slim中构建好的网络结构和权重,手动调整
  • 利用tf-slim提供的train_image_classifier.py脚本自动化构建,具体方法这里
  • 利用tf.keras,过程与keras相同

这里主要介绍上面的第一种方法,注意事项:

  • tensorflow/models在1.0版本后从tf主框架移除,需要手动下载,位置在这里tensorflow/models,可以使用git clone下载到本地目录下,使用时使用下面命令临时添加到python搜索路径

import sys
sys.path.append("./models/research/slim")

  • tf-slim的预训练网络的checkpoint文件在tensorflow/models/research/slim里,常见网络预训练权重
  • mobilenet预训练网络的checkpoint文件在slim/nets/mobilenet里面列举得更具体,Mobilenet权重

2.1 模型构建方法

tensorflow有3种方法,从checkpoint文件中恢复模型结构和权重,这里列出的模型恢复后都可以直接进行前向推导计算,进行预测分析。

1) 直接加载图结构,再加载权重

# import_meta_graph可以直接从meta文件中加载图结构
saver = tf.train.import_meta_graph(os.path.join(model_path,r"resnet_v2/model.ckpt-258931.meta"))
# allow_soft_placement自动选择设备
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
# latest_checkpoint检查checkpoint检查点文件,查找最新的模型
# restore恢复图权重
saver.restore(sess,tf.train.latest_checkpoint(r"./model/resnet_v2"))
graph = sess.graph
# get_tensor_by_name通过张量名称获取张量
print(sess.run(graph.get_tensor_by_name("resnet_model/conv2d/kernel:0")))

2)先构建图结构,再加载权重

# 临时添加slim到python搜索路径
import sys
sys.path.append("./models/research/slim")
# 导入mobilenet_v2
from nets.mobilenet import mobilenet_v2
# 重置图
tf.reset_default_graph()
# 导入mobilenet,先构建图结构。
‘’‘加载完毕后tf.get_default_graph()中包含了mobilenet计算图结构
可以使用tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)对比reset_graph前后的差异
’‘’
images = tf.placeholder(tf.float32,(None,224,224,3))
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
logits, endpoints = mobilenet_v2.mobilenet(images,depth_multiplier=1.4)
# 定义saver类,用于恢复图权重
saver = tf.train.Saver()
with tf.Session() as sess:
# latest_checkpoint检查checkpoint检查点文件,查找最新的模型
# restore恢复图权重
saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))
# get_tensor_by_name通过张量名称获取张量
print(sess.run(tf.get_default_graph().get_tensor_by_name("MobilenetV2/Conv/weights:0")).shape)

输出计算图结构内的节点名称,张量名称后面要有:0之类的数字,表示某个计算节点的第一个输出:

《Tensorflow Finetune方法》
《Tensorflow Finetune方法》 所有可训练变量名称

3) frozen inference

在测试或离线预测时,仅需要神经网络前向推导过程的计算逻辑,而不需要变量初始化、模型保存等辅助节点信息,并且不需要Saver()函数一样将变量和计算图结构分开储存。pb文件将变量取值和计算图整个结构统一放在一个文件中,通过convert_variable_to_constants将变量及取值转化为常量保存,在模型测试的时候,输入只需要经过前向传播至输出层就可以。

# 读取保存的pb文件,并解析成对应的GraphDef Protocol Buffer
gd = tf.GraphDef.FromString(open('./model_ckpt/mobilenet_v2/mobilenet_v2_1.4_224_frozen.pb',"rb").read())
# import_graph_def将graphdef中保存的图加载到当前图中,return_elements返回指定张量
inp, predictions = \
tf.import_graph_def(gd,return_elements=["input:0","MobilenetV2/Predictions/Reshape_1:0"])
# 此时的计算图可以直接用于预测
# 拉取一张图片 !wget https://upload.wikimedia.org/wikipedia/commons/f/fe/Giant_Panda_in_Beijing_Zoo_1.JPG -O panda.jpg
from PIL import Image
img = np.array(Image.open('panda.jpg').resize((224, 224))).astype(np.float) / 128 - 1
# inp是需要feed的输入,predictions是需要输出的预测结构
with tf.Session(graph=inp.graph) as sess:
x = sess.run(predictions,feed_dict={inp:img.reshape(1,224,224,3)})

2.2 Finetune过程

  • 1) 构建图结构,截取目标张量,添加新层
  • 2) 加载目标张量权重
  • 3) 训练新层
  • 4) 全局微调

1) 构建图结构,截取目标张量,添加新层

这个步骤中的图结构,是通过“先构建图结构,再加载权重方法得到的mobilenet计算图结构。

tf.reset_default_graph()
# 构建计算图
images = tf.placeholder(tf.float32,(None,224,224,3))
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(is_training=False)):
logits, endpoints = mobilenet_v2.mobilenet(images,depth_multiplier=1.4)
# 获取目标张量,添加新层
with tf.variable_scope("finetune_layers"):
# 获取目标张量,取出mobilenet中指定层的张量
mobilenet_tensor = tf.get_default_graph().get_tensor_by_name("MobilenetV2/expanded_conv_14/output:0")
# 将张量向新层传递
x = tf.layers.Conv2D(filters=256,kernel_size=3,name="conv2d_1")(mobilenet_tensor)
# 观察新层权重是否更新 tf.summary.histogram("conv2d_1",x)
x = tf.nn.relu(x,name="relu_1")
x = tf.layers.Conv2D(filters=256,kernel_size=3,name="conv2d_2")(x)
x = tf.layers.Conv2D(10,3,name="conv2d_3")(x)
predictions = tf.reshape(x, (-1,10))

计算图结构:

《Tensorflow Finetune方法》
《Tensorflow Finetune方法》 finetune网络结构

红色框内的是Mobilenet网络结构,由上至下的第二个紫色节点为”MobilenetV2/expanded_conv_14/output”节点,可以看出直接与finetune_layers相接。

2) 加载目标权重,训练新层

# one-hot编码
def to_categorical(data, nums):
return np.eye(nums)[data]
# 随机生成数据
x_train = np.random.random(size=(141,224,224,3))
y_train = to_categorical(label_fake,10)
# 训练条件配置
## label占位符
y_label = tf.placeholder(tf.int32, (None,10))
## 收集变量作用域finetune_layers内的变量,仅更新添加层的权重
train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope="finetune_layers")
## 定义loss
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_label,logits=predictions)
## 定义优化方法,用var_list指定需要更新的权重,此时仅更新train_var权重
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss,var_list=train_var)
## 观察新层权重是否更新
tf.summary.histogram("mobilenet_conv8",tf.get_default_graph().get_tensor_by_name('MobilenetV2/expanded_conv_8/depthwise/depthwise_weights:0'))
tf.summary.histogram("mobilenet_conv9",tf.get_default_graph().get_tensor_by_name('MobilenetV2/expanded_conv_9/depthwise/depthwise_weights:0'))
## 合并所有summary
merge_all = tf.summary.merge_all()
## 设定迭代次数和批量大学
epochs = 10
batch_size = 16
# 获取指定变量列表var_list的函数
def get_var_list(target_tensor=None):
'''获取指定变量列表var_list的函数'''
if target_tensor==None:
target_tensor = r"MobilenetV2/expanded_conv_14/output:0"
target = target_tensor.split("/")[1]
all_list = []
all_var = []
# 遍历所有变量,node.name得到变量名称
# 不使用tf.trainable_variables(),因为batchnorm的moving_mean/variance不属于可训练变量
for var in tf.global_variables():
if var != []:
all_list.append(var.name)
all_var.append(var)
try:
all_list = list(map(lambda x:x.split("/")[1],all_list))
# 查找对应变量作用域的索引
ind = all_list[::-1].index(target)
ind = len(all_list) - ind - 1
print(ind)
del all_list
return all_var[:ind+1]
except:
print("target_tensor is not exist!")
# 目标张量名称,要获取一个需要从文件中加载权重的变量列表var_list
target_tensor = "MobilenetV2/expanded_conv_14/output:0"
var_list = get_var_list(target_tensor)
saver = tf.train.Saver(var_list=var_list)
# 加载文件内的权重,并训练新层
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
writer = tf.summary.FileWriter(r"./logs", sess.graph)
## 初始化参数:从文件加载权重 train_var使用初始化函数
sess.run(tf.variables_initializer(var_list=train_var))
saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2")) for i in range(2000):
start = (i*batch_size) % x_train.shape[0]
end = min(start+batch_size, x_train.shape[0])
_, merge, losses = sess.run([train_step,merge_all,loss],\
feed_dict={images:x_train[start:end],\
y_label:y_train[start:end]})
if i%100==0:
writer.add_summary(merge, i)

权重初始化注意事项:

1.先利用全局初始化tf.global_variables_initializer(),再利用saver.restore顺序不能错,否则加载的权重会被重新初始化 。

sess.run(tf.global_variables_initializer())
saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))

2.先利用saver.restore从模型中加载权重,再利用tf.variable_initializaer()初始化指定的var_list,顺序可以调换.

saver.restore(sess,tf.train.latest_checkpoint("./model_ckpt/mobilenet_v2"))
sess.run(tf.variables_initializer(var_list=train_var))

3.前两种方法会对无用的节点也进行变量初始化,并且需要提前进行saver.restore操作,也就是说需要两次save.restore操作,才能保证finetune过程不会报错。现在可以通过筛选出需要从文件中加载权重的所有变量组成var_list,然后定义saver=tf.train.Saver(var_list),选择性的加载变量.

上面代码使用了第三种方法,以上3种初始化方法的差异可以仔细体会下。可以仔细看看下面的截图。

《Tensorflow Finetune方法》
《Tensorflow Finetune方法》 初始化不同造成的差异

——————————————–20180817更新———————————————

1) 添加了关于变量初始化3种方法的思考

2) 调整下文章逻辑,2.1模型构建主要用于前向传播推导,2.2阐述了完整的finetune过程

欢迎转载~请注明出处~谢谢~


推荐阅读
  • C语言是计算机科学和编程领域的基石,许多初学者在学习过程中会感到困惑。本文将详细介绍C语言的基本概念、关键语法和实用示例,帮助你快速上手C语言。 ... [详细]
  • 短视频app源码,Android开发底部滑出菜单首先依赖三方库implementationandroidx.appcompat:appcompat:1.2.0im ... [详细]
  • 本文整理了一份基础的嵌入式Linux工程师笔试题,涵盖填空题、编程题和简答题,旨在帮助考生更好地准备考试。 ... [详细]
  • 本文介绍了 Go 语言中的高性能、可扩展、轻量级 Web 框架 Echo。Echo 框架简单易用,仅需几行代码即可启动一个高性能 HTTP 服务。 ... [详细]
  • python模块之正则
    re模块可以读懂你写的正则表达式根据你写的表达式去执行任务用re去操作正则正则表达式使用一些规则来检测一些字符串是否符合个人要求,从一段字符串中找到符合要求的内容。在 ... [详细]
  • 兆芯X86 CPU架构的演进与现状(国产CPU系列)
    本文详细介绍了兆芯X86 CPU架构的发展历程,从公司成立背景到关键技术授权,再到具体芯片架构的演进,全面解析了兆芯在国产CPU领域的贡献与挑战。 ... [详细]
  • 2020年9月15日,Oracle正式发布了最新的JDK 15版本。本次更新带来了许多新特性,包括隐藏类、EdDSA签名算法、模式匹配、记录类、封闭类和文本块等。 ... [详细]
  • 本文详细解析了 Python 2.x 版本中 `urllib` 模块的核心功能与应用实例,重点介绍了 `urlopen()` 和 `urlretrieve()` 方法的使用技巧。其中,`urlopen()` 方法用于发送网络请求并获取响应内容,而 `urlretrieve()` 方法则用于下载文件并保存到本地。文章通过具体示例展示了这两个方法在实际开发中的应用场景,帮助读者更好地理解和掌握 `urllib` 模块的使用。 ... [详细]
  • 在Python编程中,探讨了并发与并行的概念及其区别。并发指的是系统同时处理多个任务的能力,而并行则指在同一时间点上并行执行多个任务。文章详细解析了阻塞与非阻塞操作、同步与异步编程模型,以及IO多路复用技术的应用。通过模拟socket发送HTTP请求的过程,展示了如何创建连接、发送数据和接收响应,并强调了默认情况下socket的阻塞特性。此外,还介绍了如何利用这些技术优化网络通信性能和提高程序效率。 ... [详细]
  • ipsec 加密流程(二):ipsec初始化操作
    《openswan》专栏系列文章主要是记录openswan源码学习过程中的笔记。Author:叨陪鲤Email:vip_13031075266163.comDate:2020.1 ... [详细]
  • Android异步处理一:使用Thread+Handler实现非UI线程更新UI界面Android异步处理二:使用AsyncTask异步更新UI界面Android异步处理三:Handler+Loope ... [详细]
  • 在一个整型数组中,除了两个数字只出现一次外,其他所有数字都出现了两次。编写一个程序来找出这两个只出现一次的数字。 ... [详细]
  • 本文将详细介绍如何在Android Studio中导入和编译OSChina Android 2.4版本的源码。包括所需软件、下载地址以及一些注意事项。 ... [详细]
  • 自然语言处理(NLP)——LDA模型:对电商购物评论进行情感分析
    目录一、2020数学建模美赛C题简介需求评价内容提供数据二、解题思路三、LDA简介四、代码实现1.数据预处理1.1剔除无用信息1.1.1剔除掉不需要的列1.1.2找出无效评论并剔除 ... [详细]
  • Leetcode学习成长记:天池leetcode基础训练营Task01数组
    前言这是本人第一次参加由Datawhale举办的组队学习活动,这个活动每月一次,之前也一直关注,但未亲身参与过,这次看到活动 ... [详细]
author-avatar
姚姚姚YTLLL
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有