热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

TensorFlow之SessionRunHook

文章目录1.为什么要有Hook?2.Hook有什么用?3.TF内置了哪些Hook?4.TF怎么自定义Hook?5.怎么使用H

文章目录

    • 1. 为什么要有 Hook?
    • 2. Hook 有什么用?
    • 3. TF 内置了哪些 Hook?
    • 4. TF 怎么自定义 Hook?
    • 5. 怎么使用 Hook?
      • 5.1 怎么在 MonitoredTrainingSession 中使用 Hook
      • 5.2 怎么在 Estimator 中使用 Hook
      • 5.3 怎么在 slim 中使用 Hook
    • 6. Hook 是怎么运作的?
    • 7. 内置 Hook 的研究
    • 8. 参考文献


1. 为什么要有 Hook?

SessionRunHook 用来扩展那些将 session 封装起来的高级 API 的 session.run 的行为。

2. Hook 有什么用?

SessionRunHook 对于追踪训练过程、报告进度、实现提前停止等非常有用。

SessionRunHook 以观察者模式运行。SessionRunHook 的设计中有几个非常重要的时间点:

  • session 使用前
  • session.run() 调用之前
  • session.run() 调用之后
  • session 关闭前

SessionRunHook 封装了一些可重用、可组合的计算,并且可以顺便完成 session.run() 的调用。利用 Hook,我们可以为 run() 调用添加任何的 ops或tensor/feeds;并且在 run() 调用完成后获得请求的输出。Hook 可以利用 hook.begin() 方法向图中添加 ops,但请注意:在 begin() 方法被调用后,计算图就 finalized 了。

3. TF 内置了哪些 Hook?

TensorFlow 中已经内置了一些 Hook:

  • StopAtStepHook:根据 global_step 来停止训练。
  • CheckpointSaverHook:保存 checkpoint。
  • LoggingTensorHook:以日志的形式输出一个或多个 tensor 的值。
  • NanTensorHook:如果给定的 Tensor 包含 Nan,就停止训练。
  • SummarySaverHook:保存 summaries 到一个 summary writer。

4. TF 怎么自定义 Hook?

上节,我们已经介绍了预制 Hook,使用其可以实现一些常见功能。如果这些 Hook 不能满足你的需求,那么自定义 Hook 是比较好的选择。

下面是自定义 Hook 的编写模板:

class ExampleHook(tf.train.SessionRunHook):def begin(self):# You can add ops to the graph here.print('Starting the session.')self.your_tensor = ...def after_create_session(self, session, coord):# When this is called, the graph is finalized and# ops can no longer be added to the graph.print('Session created.')def before_run(self, run_context):print('Before calling session.run().')return SessionRunArgs(self.your_tensor)def after_run(self, run_context, run_values): # run_values 为 sess.run 的结果print('Done running one step. The value of my tensor: %s',run_values.results)if you-need-to-stop-loop:run_context.request_stop()def end(self, session):print('Done with the session.')

上面是官方给的解释,下面是我设计的一个设置学习速率的Hook:

class _LearningRateSetterHook(tf.train.SessionRunHook):"""Sets learning_rate based on global step."""def begin(self):self._global_step_tensor &#61; tf.train.get_or_create_global_step()self._lrn_rate_tensor &#61; tf.get_default_graph().get_tensor_by_name(&#39;learning_rate:0&#39;) # 注意&#xff0c;这里根据name来索引tensor&#xff0c;所以请在定义学习速率的时候&#xff0c;为op添加名字self._lrn_rate &#61; 0.1 # 第一阶段的学习速率def before_run(self, run_context):return tf.train.SessionRunArgs(self._global_step_tensor, # Asks for global step value.feed_dict&#61;{self._lrn_rate_tensor: self._lrn_rate}) # Sets learning ratedef after_run(self, run_context, run_values):train_step &#61; run_values.resultsif train_step < 10000:passelif train_step < 20000:self._lrn_rate &#61; 0.01 # 第二阶段的学习速率elif train_step < 30000:self._lrn_rate &#61; 0.001 # 第三阶段的学习速率else:self._lrn_rate &#61; 0.0001 # 第四阶段的学习速率

5. 怎么使用 Hook&#xff1f;

在那些将 session 封装起来的高阶 API 中&#xff0c;我们可以使用 Hook 来扩展这些这些 API 的 session.run() 的行为。

首先&#xff0c;我们梳理一下将 session 封装起来的高阶 API 有哪些&#xff1f;这些 API 包括&#xff0c;但不限于&#xff1a;

  • tf.train.MonitoredTrainingSession&#xff1a;
  • tf.estimator.Estimator&#xff1a;
  • tf.contrib.slim&#xff1a;

5.1 怎么在 MonitoredTrainingSession 中使用 Hook

with tf.train.MonitoredTrainingSession(hooks&#61;your_hooks, ...) as mon_sess:while not mon_sess.should_stop():mon_sess.run(your_fetches)

5.2 怎么在 Estimator 中使用 Hook

tf.estimator.Estimatortrainevaluatepredict 方法中都可以使用 Hook。

下面是这些方法的 API&#xff1a;

# 训练
# 这里的 est 是一个 Estimator 实例
est.train(input_fn, hooks&#61;None, steps&#61;None, max_steps&#61;None, saving_listeners&#61;None)

# 评估
est.evaluate(input_fn, steps&#61;None, hooks&#61;None, checkpoint_path&#61;None, name&#61;None)

# 预测
est.predict(input_fn, predict_keys&#61;None, hooks&#61;None, checkpoint_path&#61;None, yield_single_examples&#61;True)

5.3 怎么在 slim 中使用 Hook

Slim 是 TensorFlow 中一个非常优秀的高阶 API&#xff0c;其可以极大地简化模型的构建、训练、评估。

未完待续。。。。

6. Hook 是怎么运作的&#xff1f;

通过自定义 Hook 的过程&#xff0c;我们了解到一个 Hook 包括 beginafter_create_sessionbefore_runafter_runend 五个方法。

下面的伪代码演示了 Hook 的运行过程&#xff1a;

# 伪代码
call hooks.begin()
sess &#61; tf.Session()
call hooks.after_create_session()
while not stop is requested:call hooks.before_run()try:results &#61; sess.run(merged_fetches, feed_dict&#61;merged_feeds)except (errors.OutOfRangeError, StopIteration):breakcall hooks.after_run()
call hooks.end()
sess.close()

注意&#xff1a;如果 sess.run() 引发 OutOfRangeErrorStopIteration 或其它异常&#xff0c;那么 hooks.after_run()hooks.end() 将不会被执行。

7. 内置 Hook 的研究

预制的 Hook 比较多&#xff0c;这里我们以 tf.train.StopAtStepHook 为例&#xff0c;来看看内置 Hook 是怎么编写的。

# tf.train.StopAtStepHook 的定义
class StopAtStepHook(tf.train.SessionRunHook):"""Hook that requests stop at a specified step."""def __init__(self, num_steps&#61;None, last_step&#61;None):"""Initializes a &#96;StopAtStepHook&#96;.This hook requests stop after either a number of steps have beenexecuted or a last step has been reached. Only one of the two options can bespecified.if &#96;num_steps&#96; is specified, it indicates the number of steps to executeafter &#96;begin()&#96; is called. If instead &#96;last_step&#96; is specified, itindicates the last step we want to execute, as passed to the &#96;after_run()&#96;call.Args:num_steps: Number of steps to execute.last_step: Step after which to stop.Raises:ValueError: If one of the arguments is invalid."""if num_steps is None and last_step is None:raise ValueError("One of num_steps or last_step must be specified.")if num_steps is not None and last_step is not None:raise ValueError("Only one of num_steps or last_step can be specified.")self._num_steps &#61; num_stepsself._last_step &#61; last_stepdef begin(self):self._global_step_tensor &#61; tf.train.get_or_create_global_step()if self._global_step_tensor is None:raise RuntimeError("Global step should be created to use StopAtStepHook.")def after_create_session(self, session, coord):if self._last_step is None:global_step &#61; session.run(self._global_step_tensor)self._last_step &#61; global_step &#43; self._num_stepsdef before_run(self, run_context): # pylint: disable&#61;unused-argumentreturn tf.train.SessionRunArgs(self._global_step_tensor)def after_run(self, run_context, run_values):global_step &#61; run_values.results &#43; 1if global_step >&#61; self._last_step:# Check latest global step to ensure that the targeted last step is# reached. global_step read tensor is the value of global step# before running the operation. We&#39;re not sure whether current session.run# incremented the global_step or not. Here we&#39;re checking it.step &#61; run_context.session.run(self._global_step_tensor)if step >&#61; self._last_step:run_context.request_stop()

8. 参考文献


  1. SessionRunHook 源码&#xff1a;link
  2. tf.train.SessionRunHook() 类详解&#xff1a;link
  3. Hook? tf.train.SessionRunHook()介绍【精】&#xff1a;link

注意&#xff1a;欢迎大家转载&#xff0c;但需注明出处哦
\quad\quad  \;https://blog.csdn.net/u014061630/article/details/82998116


推荐阅读
author-avatar
qaoxiuzcwhyx
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有