热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

我是如何利用tensorflowmodel_pruning模块对模型算法进行剪枝的

什么是剪枝剪枝就是利用某一个准则对某一组或某一个权值置0从而达到将网络神经元置0以达到稀疏化网络连接从而加快整个推理过程及缩小模型大小的迭代过程,这个准则有暴力穷尽组合排忧、使用对

什么是剪枝

剪枝就是利用某一个准则对某一组或某一个权值置0从而达到将网络神经元置0以达到稀疏化网络连接从而加快整个推理过程及缩小模型大小的迭代过程,这个准则有暴力穷尽组合排忧、使用对角 Hessian 逼近计算每个权值的重要性、基于一阶泰勒展开的模型代价函数来对权值排序、基于L1绝对值的权值参数大小进行排序、基于在小验证集上的影响进行分值分配排序等方法,而某一组或某一个网络权值则可以是整个卷积核、全连接层、卷积核或全连接层上的某个权重参数,剪枝的目的是将冗余的神经元参数置0减小模型大小(需要特殊的模型存储方式)减少计算参数(需要某种特殊的硬件计算方式)稀疏化网络连接加快推理速度,剪枝前后的网络连接对比图如下:

《我是如何利用tensorflow model_pruning模块对模型算法进行剪枝的》
《我是如何利用tensorflow model_pruning模块对模型算法进行剪枝的》

tensorflow model_pruning简介

其源代码位于tensorflow/tensorflow/contrib/model_pruning

tensorflow model_pruning使用的剪枝准则是根据每个神经元的L1绝对值的权重参数大小进行排序之后将低于某一个阈值threshold的权重参数全部直接置0,而其实现方法则是在图结构里添加剪枝Ops,设置有一个与权重形状一致的二进制掩模(mask)变量,在前向传播时该掩模的对应位与选中权重进行相与输出feature map,如果该掩模对应位为0则对应的权重相与后则为0,在反向传播时掩模对应位为0的权重参数则不参与更新,在保存模型时则可以通过去掉剪枝Ops的方式直接稀疏化权重,这样就起到了稀疏连接的作用,添加了pruning ops的图结构(注意mask和threshold这两个差异)如下:

《我是如何利用tensorflow model_pruning模块对模型算法进行剪枝的》
《我是如何利用tensorflow model_pruning模块对模型算法进行剪枝的》

model_pruning的超参数如下:

《我是如何利用tensorflow model_pruning模块对模型算法进行剪枝的》
《我是如何利用tensorflow model_pruning模块对模型算法进行剪枝的》

稀疏度计算公式如下:

St = Sf + (Si – Sf) * (1 – (T – t0) / nt)^e,(T等于t0,t0+t,t0+2t,t0+3t,…)

其中,每个符号对应的超参数如下:

Sf :target_sparsity

Si:initial_sparsity

t0:sparsity_function_begin_step

n :sparsity_function_end_step – sparsity_function_end_begin

t:pruning_frequency

e:sparsity_function_exponent

官方提供的使用model_pruning例子如下:

tf.app.flags.DEFINE_string(
'pruning_hparams', '',
"""Comma separated list of pruning-related hyperparameters""")
with tf.graph.as_default():
# Create global step variable
global_step = tf.train.get_or_create_global_step()
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
# Create a pruning object using the pruning specification
p = pruning.Pruning(pruning_hparams, global_step=global_step)
# Add conditional mask update op. Executing this op will update all
# the masks in the graph if the current global step is in the range
# [begin_pruning_step, end_pruning_step] as specified by the pruning spec
mask_update_op = p.conditional_mask_update_op()
# Add summaries to keep track of the sparsity in different layers during training
p.add_pruning_summaries()
with tf.train.MonitoredTrainingSession(...) as mon_sess:
# Run the usual training op in the tf session
mon_sess.run(train_op)
# Update the masks by running the mask_update_op
mon_sess.run(mask_update_op)

特别需要注意的是:

一定要保证传给pruning的global_step是随着训练迭代保持增长的,否则不会产生剪枝效果!

去除剪枝Ops的方法如下:

$ bazel build -c opt contrib/model_pruning:strip_pruning_vars
$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/path/to/checkpoints/ --output_node_names=graph_node1,graph_node2 --output_dir=/tmp --filename=pruning_stripped.pb

我是如何利用model_pruning模块进行模型剪枝的:

其实,添加剪枝Ops的方式基本如官方说明,主要区别在于用了tf.train.MonitoredTrainingSession这个可monitor会话的接口,反而没有之前那么好操作了,但是该接口实际上比tf.Session则有不少好处,它会负责初始化变量、启动队列及建立summary文件操作,但是使用该接口的例子实在太少了,基本上找不到相关的说明。

我通过研究源码,基本上满足了整个剪枝图结构的构造:

1、从tfrecord文件里获取batch size数据

tfrecords_f = os.path.join(args.tfrecords_file_path, 'tran.tfrecords')
dataset = tf.data.TFRecordDataset(tfrecords_f)
dataset = dataset.map(parse_function)
dataset = dataset.shuffle(buffer_size=args.buffer_size)
dataset = dataset.batch(args.batch_size)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

由于用了dataset.make_initializable_iterator()迭代器,则只能通过抓取tf.errors.OutOfRangeError错误来判断每一轮迭代的完成。

2、建立monitor会话

init_op = tf.global_variables_initializer()
local_init_op = tf.local_variables_initializer()
scaffold = tf.train.Scaffold(init_op=init_op, local_init_op=local_init_op)
with tf.train.MonitoredTrainingSession(
checkpoint_dir=None,
hooks=[tf.train.StopAtStepHook(last_step=args.epoch), tf.train.NanTensorHook(total_loss), __SessionRunHook()], #
cOnfig=config,
scaffold=scaffold,
save_checkpoint_secs=None,
save_summaries_steps=None,
save_summaries_secs=None) as mon_sess:
while not mon_sess.should_stop():

我们主要关注这个hooks钩子函数,tf.train.StopAtStepHook用于结束训练,tf.train.NanTensorHook用于抓取Nan错误并报错停止,而__SessionRunHook()则是我建立的一个继承自tf.train.SessionRunHook的类。

3、__SessionRunHook()根据自身项目情况来设置

3-1)begin()用于在整个会话启动循环前的配置

def begin(self):
print("begin")
# 4 begin iteration
if not os.path.exists(args.log_file_path):
os.makedirs(args.log_file_path)
log_file_path = args.log_file_path + '/train' + time.strftime('_%Y-%m-%d-%H-%M',
time.localtime(time.time())) + '.log'
self.log_file = open(log_file_path, 'w')
self.total_accuracy = {}
self._step = -1
# 3.11 summary writer
summaries = []
# # 3.11.1 add grad histogram op
#for grad, var in grads:
# if grad is not None:
# summaries.append(tf.summary.histogram(var.op.name + '/gradients', grad))
# 3.11.2 add trainabel variable gradients
for var in tf.trainable_variables():
summaries.append(tf.summary.histogram(var.op.name, var))
# 3.11.3 add loss summary
summaries.append(tf.summary.scalar('inference_loss', inference_loss))
summaries.append(tf.summary.scalar('wd_loss', wd_loss))
summaries.append(tf.summary.scalar('total_loss', total_loss))
# 3.11.4 add learning rate
summaries.append(tf.summary.scalar('leraning_rate', lr))
self.summary_op = tf.summary.merge(summaries)

主要用于建立log文件句柄,一些参数的初始化,summaries的收集定义等。

3-2)after_create_session则是建立会话后调用的钩子

def after_create_session(self, session, coord):
print("after_create_session")
self.summary = tf.summary.FileWriter(args.summary_path, session.graph)
self.sess = session

用于获取summary及session。

3-3)mon_sess.run启动前做一些辅助工作

def before_run(self, run_context):
#print("before_run")
self._step += 1
self._start_time = time.time()
global init_iter
if init_iter == True:
print("init datasets iterator!")
self.sess.run(iterator.initializer)
init_iter = False
self.images_train, self.labels_train = self.sess.run(next_element)
feed_dict = {images: self.images_train, labels: self.labels_train}
fetches = [total_loss, inference_loss, wd_loss, inc_op, acc]
return tf.train.SessionRunArgs(fetches=fetches, feed_dict=feed_dict)

是否要初始化数据迭代器,喂给会话的字典建立等,该接口会通过传入形参的形式返回会话运行结果到after_run。

3-4)after_run则是会话运行完后需要处理的钩子

def after_run(self, run_context, run_values):
#print("after_run")
duration = time.time() - self._start_time
run_values = run_values.results
total_loss_val, inference_loss_val, wd_loss_val, _, acc_val = run_values
pre_sec = args.batch_size / duration
global epoch_iter
# print training information
if self._step > 0 and self._step % args.show_info_interval == 0:
print('epoch %d, total_step %d, total loss is %.2f , inference loss is %.2f, weight deacy '
'loss is %.2f, t:265raining accuracy is %.6f, time %.3f samples/sec' %
(epoch_iter, self._step, total_loss_val, inference_loss_val, wd_loss_val, acc_val, pre_sec))
# test model at first time
global test_flag
if test_flag == True:
test_flag = False
feed_dict_test = {}
# feed_dict_test.update(tl.utils.dict_to_one(net.all_drop))
results = ver_test(ver_list=ver_list, ver_name_list=ver_name_list, nbatch=self._step, sess=self.sess,
embedding_tensor=embedding_tensor, batch_size=args.batch_size,
feed_dict=feed_dict_test,
input_placeholder=images)
print('first test accuracy is: ', str(results[0]))
# save summary
if self._step > 0 and self._step % args.summary_interval == 0:
feed_dict = {images: self.images_train, labels: self.labels_train}
# feed_dict.update(net.all_drop)
summary_op_val = self.sess.run(self.summary_op, feed_dict=feed_dict)
self.summary.add_summary(summary_op_val, self._step)
# save ckpt files
if self._step > 0 and self._step % args.ckpt_interval == 0:
filename = 'InsightFace_iter_{:d}'.format(self._step) + '.ckpt'
filename = os.path.join(args.ckpt_path, filename)
saver.save(self.sess, filename)
# validate
if self._step > 0 and self._step % args.validate_interval == 0:
feed_dict_test = {}
# feed_dict_test.update(tl.utils.dict_to_one(net.all_drop))
results = ver_test(ver_list=ver_list, ver_name_list=ver_name_list, nbatch=self._step, sess=self.sess,
embedding_tensor=embedding_tensor, batch_size=args.batch_size,
feed_dict=feed_dict_test,
input_placeholder=images)
print('test accuracy is: ', str(results[0]))
self.total_accuracy[str(self._step)] = results[0]
self.log_file.write('########' * 10 + '\n')
self.log_file.write(','.join(list(self.total_accuracy.keys())) + '\n')
self.log_file.write(','.join([str(val) for val in list(self.total_accuracy.values())]) + '\n')
self.log_file.flush()
if max(results) > 0.997:
print('best accuracy is %.5f' % max(results))
filename = 'InsightFace_iter_best_{:d}'.format(self._step) + '.ckpt'
filename = os.path.join(args.ckpt_path, filename)
saver.save(self.sess, filename)
self.log_file.write('######Best Accuracy######' + '\n')
self.log_file.write(str(max(results)) + '\n')
self.log_file.write(filename + '\n')
self.log_file.flush()

主要做了获取会话运行结果并根据需要打印结果,进行测试验证,保存节点信息到tensorboard文件,按间隔保存checkpoint文件等。

3-5)end则是整个迭代运行完后需要处理的事项

def end(self, session):
print("end")
self.log_file.close()
self.log_file.write('\n')

只有在钩子函数里做好了我们想要在整个训练迭代过程中的工作后,我们才可以安然无恙的运行如下类似代码而不用再处理一些过程:

with tf.train.MonitoredTrainingSession(
checkpoint_dir=None,
hooks=[tf.train.StopAtStepHook(last_step=args.epoch), tf.train.NanTensorHook(total_loss), __SessionRunHook()], #
cOnfig=config,
scaffold=scaffold,
save_checkpoint_secs=None,
save_summaries_steps=None,
save_summaries_secs=None) as mon_sess:
while not mon_sess.should_stop():
#mon_sess.run(iterator.initializer)
epoch_iter += 1
if pretrained_model and restore_flag:
print('Restoring pretrained model: %s' % pretrained_model)
ckpt = tf.train.get_checkpoint_state(pretrained_model)
print(ckpt)
saver.restore(mon_sess, ckpt.model_checkpoint_path)
restore_flag = False
while True:
try:
mon_sess.run(train_op)
# Update the masks
mon_sess.run(mask_update_op)
except tf.errors.OutOfRangeError as e:
print("OutOfRangeError")
print(e)
init_iter = True
break

以上,就是我是如何利用tensorflow model_pruning模块对模型算法进行剪枝的一个参考例子,特别是在使用到了tf.train.MonitoredTrainingSession的项目中可作为如何配置hooks的一个参考。


推荐阅读
author-avatar
凯鹏2502896277
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有