2017-08-06 54 views

回答

3

你可以找到一個教程here,有點長,但你能跳建設網絡的一部分。或者你可以根據我的經驗閱讀下面的小小總結。

首先,MonitoredSession應該被用來代替正常Session

SessionRunHook延伸session.run()要求MonitoredSession

然後一些常見SessionRunHook類可以發現here。一個簡單的一種是LoggingTensorHook,但您可能希望進口後添加以下行運行時看到日誌:

tf.logging.set_verbosity(tf.logging.INFO) 

或者你有選項來實現自己的SessionRunHook類。一個簡單的一個是從cifar10 tutorial

class _LoggerHook(tf.train.SessionRunHook): 
    """Logs loss and runtime.""" 

    def begin(self): 
    self._step = -1 
    self._start_time = time.time() 

    def before_run(self, run_context): 
    self._step += 1 
    return tf.train.SessionRunArgs(loss) # Asks for loss value. 

    def after_run(self, run_context, run_values): 
    if self._step % FLAGS.log_frequency == 0: 
     current_time = time.time() 
     duration = current_time - self._start_time 
     self._start_time = current_time 

     loss_value = run_values.results 
     examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size/duration 
     sec_per_batch = float(duration/FLAGS.log_frequency) 

     format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 
        'sec/batch)') 
     print (format_str % (datetime.now(), self._step, loss_value, 
          examples_per_sec, sec_per_batch)) 

loss其中被定義在類外。這_LoggerHook使用print打印信息,同時LoggingTensorHook使用tf.logging.INFO

最後,爲了更好的理解它是如何工作的,執行順序是由僞帶有MonitoredSessionhere

call hooks.begin() 
    sess = tf.Session() 
    call hooks.after_create_session() 
    while not stop is requested: # py code: while not mon_sess.should_stop(): 
    call hooks.before_run() 
    try: 
     results = sess.run(merged_fetches, feed_dict=merged_feeds) 
    except (errors.OutOfRangeError, StopIteration): 
     break 
    call hooks.after_run() 
    call hooks.end() 
    sess.close() 

希望這有助於。

+0

感謝詳細的解釋。 – gaussclb

相關問題