文章目录
- 1. 为什么要有 Hook?
- 2. Hook 有什么用?
- 3. TF 内置了哪些 Hook?
- 4. TF 怎么自定义 Hook?
- 5. 怎么使用 Hook?
- 5.1 怎么在
MonitoredTrainingSession
中使用 Hook - 5.2 怎么在
Estimator
中使用 Hook - 5.3 怎么在
slim
中使用 Hook
- 6. Hook 是怎么运作的?
- 7. 内置 Hook 的研究
- 8. 参考文献
1. 为什么要有 Hook?
SessionRunHook
用来扩展那些将 session
封装起来的高级 API 的 session.run
的行为。
2. Hook 有什么用?
SessionRunHook
对于追踪训练过程、报告进度、实现提前停止等非常有用。
SessionRunHook
以观察者模式运行。SessionRunHook
的设计中有几个非常重要的时间点:
session
使用前session.run()
调用之前session.run()
调用之后session
关闭前
SessionRunHook
封装了一些可重用、可组合的计算,并且可以顺便完成 session.run()
的调用。利用 Hook,我们可以为 run()
调用添加任何的 ops或tensor/feeds;并且在 run()
调用完成后获得请求的输出。Hook 可以利用 hook.begin()
方法向图中添加 ops,但请注意:在 begin()
方法被调用后,计算图就 finalized 了。
3. TF 内置了哪些 Hook?
TensorFlow 中已经内置了一些 Hook:
StopAtStepHook
:根据 global_step 来停止训练。CheckpointSaverHook
:保存 checkpoint。LoggingTensorHook
:以日志的形式输出一个或多个 tensor 的值。NanTensorHook
:如果给定的 Tensor
包含 Nan,就停止训练。SummarySaverHook
:保存 summaries 到一个 summary writer。
4. TF 怎么自定义 Hook?
上节,我们已经介绍了预制 Hook,使用其可以实现一些常见功能。如果这些 Hook 不能满足你的需求,那么自定义 Hook 是比较好的选择。
下面是自定义 Hook 的编写模板:
class ExampleHook(tf.train.SessionRunHook):def begin(self):print('Starting the session.')self.your_tensor = ...def after_create_session(self, session, coord):print('Session created.')def before_run(self, run_context):print('Before calling session.run().')return SessionRunArgs(self.your_tensor)def after_run(self, run_context, run_values): print('Done running one step. The value of my tensor: %s',run_values.results)if you-need-to-stop-loop:run_context.request_stop()def end(self, session):print('Done with the session.')
上面是官方给的解释,下面是我设计的一个设置学习速率的Hook:
class _LearningRateSetterHook(tf.train.SessionRunHook):"""Sets learning_rate based on global step."""def begin(self):self._global_step_tensor &#61; tf.train.get_or_create_global_step()self._lrn_rate_tensor &#61; tf.get_default_graph().get_tensor_by_name(&#39;learning_rate:0&#39;) self._lrn_rate &#61; 0.1 def before_run(self, run_context):return tf.train.SessionRunArgs(self._global_step_tensor, feed_dict&#61;{self._lrn_rate_tensor: self._lrn_rate}) def after_run(self, run_context, run_values):train_step &#61; run_values.resultsif train_step < 10000:passelif train_step < 20000:self._lrn_rate &#61; 0.01 elif train_step < 30000:self._lrn_rate &#61; 0.001 else:self._lrn_rate &#61; 0.0001
5. 怎么使用 Hook&#xff1f;
在那些将 session
封装起来的高阶 API 中&#xff0c;我们可以使用 Hook 来扩展这些这些 API 的 session.run()
的行为。
首先&#xff0c;我们梳理一下将 session
封装起来的高阶 API 有哪些&#xff1f;这些 API 包括&#xff0c;但不限于&#xff1a;
tf.train.MonitoredTrainingSession
&#xff1a;tf.estimator.Estimator
&#xff1a;tf.contrib.slim
&#xff1a;
5.1 怎么在 MonitoredTrainingSession
中使用 Hook
with tf.train.MonitoredTrainingSession(hooks&#61;your_hooks, ...) as mon_sess:while not mon_sess.should_stop():mon_sess.run(your_fetches)
5.2 怎么在 Estimator
中使用 Hook
在 tf.estimator.Estimator
的 train
、evaluate
、predict
方法中都可以使用 Hook。
下面是这些方法的 API&#xff1a;
est.train(input_fn, hooks&#61;None, steps&#61;None, max_steps&#61;None, saving_listeners&#61;None)
est.evaluate(input_fn, steps&#61;None, hooks&#61;None, checkpoint_path&#61;None, name&#61;None)
est.predict(input_fn, predict_keys&#61;None, hooks&#61;None, checkpoint_path&#61;None, yield_single_examples&#61;True)
5.3 怎么在 slim
中使用 Hook
Slim 是 TensorFlow 中一个非常优秀的高阶 API&#xff0c;其可以极大地简化模型的构建、训练、评估。
未完待续。。。。
6. Hook 是怎么运作的&#xff1f;
通过自定义 Hook 的过程&#xff0c;我们了解到一个 Hook 包括 begin
、after_create_session
、before_run
、after_run
、end
五个方法。
下面的伪代码演示了 Hook 的运行过程&#xff1a;
call hooks.begin()
sess &#61; tf.Session()
call hooks.after_create_session()
while not stop is requested:call hooks.before_run()try:results &#61; sess.run(merged_fetches, feed_dict&#61;merged_feeds)except (errors.OutOfRangeError, StopIteration):breakcall hooks.after_run()
call hooks.end()
sess.close()
注意&#xff1a;如果 sess.run()
引发 OutOfRangeError
、StopIteration
或其它异常&#xff0c;那么 hooks.after_run()
和 hooks.end()
将不会被执行。
7. 内置 Hook 的研究
预制的 Hook 比较多&#xff0c;这里我们以 tf.train.StopAtStepHook
为例&#xff0c;来看看内置 Hook 是怎么编写的。
class StopAtStepHook(tf.train.SessionRunHook):"""Hook that requests stop at a specified step."""def __init__(self, num_steps&#61;None, last_step&#61;None):"""Initializes a &#96;StopAtStepHook&#96;.This hook requests stop after either a number of steps have beenexecuted or a last step has been reached. Only one of the two options can bespecified.if &#96;num_steps&#96; is specified, it indicates the number of steps to executeafter &#96;begin()&#96; is called. If instead &#96;last_step&#96; is specified, itindicates the last step we want to execute, as passed to the &#96;after_run()&#96;call.Args:num_steps: Number of steps to execute.last_step: Step after which to stop.Raises:ValueError: If one of the arguments is invalid."""if num_steps is None and last_step is None:raise ValueError("One of num_steps or last_step must be specified.")if num_steps is not None and last_step is not None:raise ValueError("Only one of num_steps or last_step can be specified.")self._num_steps &#61; num_stepsself._last_step &#61; last_stepdef begin(self):self._global_step_tensor &#61; tf.train.get_or_create_global_step()if self._global_step_tensor is None:raise RuntimeError("Global step should be created to use StopAtStepHook.")def after_create_session(self, session, coord):if self._last_step is None:global_step &#61; session.run(self._global_step_tensor)self._last_step &#61; global_step &#43; self._num_stepsdef before_run(self, run_context): return tf.train.SessionRunArgs(self._global_step_tensor)def after_run(self, run_context, run_values):global_step &#61; run_values.results &#43; 1if global_step >&#61; self._last_step:step &#61; run_context.session.run(self._global_step_tensor)if step >&#61; self._last_step:run_context.request_stop()
8. 参考文献
SessionRunHook
源码&#xff1a;linktf.train.SessionRunHook()
类详解&#xff1a;link- Hook? tf.train.SessionRunHook()介绍【精】&#xff1a;link
注意&#xff1a;欢迎大家转载&#xff0c;但需注明出处哦
\quad\quad  \;https://blog.csdn.net/u014061630/article/details/82998116