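You can simply create a custom hook and pass it to MonitoredTrainingSession. There is no need to pass your own tf.RunMetadata() instance to the run call.
Below is an example hook that stores the run metadata in ckptdir every N steps: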
import tensorflow as tf


class TraceHook(tf.train.SessionRunHook):
    """Hook to perform Traces every N steps."""

    def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE):
        # Trace the very first run only when every step is traced.
        self._trace = every_step == 1
        self.writer = tf.summary.FileWriter(ckptdir)
        self.trace_level = trace_level
        self.every_step = every_step

    def begin(self):
        self._global_step_tensor = tf.train.get_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError("Global step should be created to use TraceHook.")

    def before_run(self, run_context):
        # Attach RunOptions only when this run is flagged for tracing.
        if self._trace:
            options = tf.RunOptions(trace_level=self.trace_level)
        else:
            options = None
        return tf.train.SessionRunArgs(fetches=self._global_step_tensor,
                                       options=options)

    def after_run(self, run_context, run_values):
        global_step = run_values.results - 1
        if self._trace:
            # This run was traced: clear the flag and store its metadata.
            self._trace = False
            self.writer.add_run_metadata(run_values.run_metadata,
                                         f'{global_step}', global_step)
        # Flag the next run for tracing once every `every_step` steps.
        if not (global_step + 1) % self.every_step:
            self._trace = True
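In before_run, the hook checks whether the current run should be traced and, if so, adds the RunOptions. In after_run, it checks whether the next run call needs to be traced and, if so, sets _trace to True again. It also stores the metadata whenever it is available.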
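To see which steps end up traced, the flag logic can be replayed in plain Python without TensorFlow. This is only a rough sketch: `traced_steps` is a hypothetical helper, not part of the hook, and it assumes the fetched global-step value is 1, 2, 3, ... on successive run calls.

```python
def traced_steps(num_steps, every_step=50):
    """Replay the hook's _trace flag and return the step labels that get recorded."""
    trace = every_step == 1  # same initial flag as in __init__
    recorded = []
    # Assumption: successive run calls fetch global-step values 1, 2, 3, ...
    for fetched in range(1, num_steps + 1):
        # before_run: RunOptions would be attached only if the flag is set.
        traced_this_run = trace
        # after_run: same arithmetic as in the hook.
        global_step = fetched - 1
        if traced_this_run:
            trace = False
            recorded.append(global_step)
        if not (global_step + 1) % every_step:
            trace = True
    return recorded


if __name__ == "__main__":
    print(traced_steps(10, every_step=3))
    print(traced_steps(3, every_step=1))
```

Under that assumption, the run that sets the flag is never the run that gets traced; the trace is always taken on the following run call, which is why the flag is set in after_run and consumed in the next before_run.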