2016-12-23 53 views
1

I have the following code, which performs a simple arithmetic computation. I am trying to achieve fault tolerance by using a monitored training session.

import tensorflow as tf 

global_step_tensor = tf.Variable(10, trainable=False, name='global_step') 

cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223","localhost:2224", "localhost:2225"]}) 
x = tf.constant(2) 

with tf.device("/job:local/task:0"): 
    y1 = x + 300 

with tf.device("/job:local/task:1"): 
    y2 = x**2 

with tf.device("/job:local/task:2"): 
    y3 = 5*x 

with tf.device("/job:local/task:3"): 
    y0 = x - 66 
    y = y0 + y1 + y2 + y3 

ChiefSessionCreator = tf.train.ChiefSessionCreator(scaffold=None, master='localhost:2222', config='grpc://localhost:2222', checkpoint_dir='/home/tensorflow/codes/checkpoints') 
saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir='/home/tensorflow/codes/checkpoints', save_secs=10, save_steps=None, saver=y, checkpoint_basename='model.ckpt', scaffold=None) 
summary_hook = tf.train.SummarySaverHook(save_steps=None, save_secs=10, output_dir='/home/tensorflow/codes/savepoints', summary_writer=None, scaffold=None, summary_op=y) 

with tf.train.MonitoredTrainingSession(master='localhost:2222', is_chief=True, checkpoint_dir='/home/tensorflow/codes/checkpoints', 
    scaffold=None, hooks=[saver_hook, summary_hook], chief_only_hooks=None, save_checkpoint_secs=10, save_summaries_steps=None, config='grpc://localhost:2222') as sess: 

    while not sess.should_stop(): 
     sess.run(model) 

    while not sess.should_stop(): 
     print(sess.run(y0)) 
     print('\n') 

    while not sess.should_stop(): 
     print(sess.run(y1)) 
     print('\n') 

    while not sess.should_stop(): 
     print(sess.run(y2)) 
     print('\n') 

    while not sess.should_stop(): 
     print(sess.run(y3)) 
     print('\n') 

    while not sess.should_stop(): 
     result = sess.run(y) 
     print(result) 

But it throws the following error:

Traceback (most recent call last): 
    File "add_1.py", line 36, in <module> 
    scaffold=None, hooks=[saver_hook, summary_hook], chief_only_hooks=None, save_checkpoint_secs=10, save_summaries_steps=None, config='grpc://localhost:2222') as sess: 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 289, in MonitoredTrainingSession 
    return MonitoredSession(session_creator=session_creator, hooks=hooks) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 447, in __init__ 
    self._sess = _RecoverableSession(self._coordinated_creator) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 618, in __init__ 
    _WrappedSession.__init__(self, self._sess_creator.create_session()) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 505, in create_session 
    self.tf_sess = self._session_creator.create_session() 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 341, in create_session 
    init_fn=self._scaffold.init_fn) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/session_manager.py", line 227, in prepare_session 
    config=config) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/session_manager.py", line 153, in _restore_checkpoint 
    sess = session.Session(self._target, graph=self._graph, config=config) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1186, in __init__ 
    super(Session, self).__init__(target, graph, config=config) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 540, in __init__ 
    % type(config)) 
TypeError: config must be a tf.ConfigProto, but got <type 'str'> 
Exception AttributeError: "'Session' object has no attribute '_session'" in <bound method Session.__del__ of <tensorflow.python.client.session.Session object at 0x7fb1bac14ed0>> ignored 

It seems to me that this is caused by passing an incorrect argument for config. Are the arguments I am using correct? Please advise.

Answer

3

The first problem is in the following lines: they use a local session for distributed (device-placed) operations. Why do you need this?

sess = tf.Session() 
tf.train.global_step(sess, global_step_tensor) 

Second issue: the code uses WorkerSessionCreator, but one machine should be the chief, and in that case ChiefSessionCreator should be used. I suggest using tf.train.MonitoredTrainingSession.

Third issue: sess.should_stop() should be checked before every run.