2017-06-07

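I am going to create a TensorFlow graph for a single-node run. Later, if I want to train the same model graph in a distributed environment (partitioning the variables across multiple parameter servers and replicating the graph across n workers), how should I do that? How can a TensorFlow graph created in a non-distributed environment be loaded into a distributed one?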

I found tf.Graph.as_graph_def(), which exports the graph as a GraphDef proto, and tf.import_graph_def(), which imports it again later. But this does not work.

Code:

import tensorflow as tf 

graph = tf.Graph() 

with graph.as_default(): 
    x_place_holder = tf.placeholder(dtype=tf.float32, shape=[], name="xin") 
    y_place_holder = tf.placeholder(dtype=tf.float32, shape=[], name="yin") 

    m = tf.Variable(10.0, name="varm") 
    l = tf.Variable(20.0, name="varl") 

    Y = tf.multiply(m, x_place_holder, name="mulop") 
    X = tf.add(l, x_place_holder, name="addop") 
    cost = tf.abs(Y - X, name="cost") 

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5, name="optimizer").minimize(cost) 

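tf.reset_default_graph()

# FLAGS, cluster and server are not shown in the question;
# they are presumably defined elsewhere in the script.
if FLAGS.job_name == "ps":
    server.join()

elif FLAGS.job_name == "worker":
    print(FLAGS.task_index, "task index")

    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        # Import the single-node GraphDef under the device setter.
        # This is the step that does not work.
        tf.import_graph_def(
            graph.as_graph_def(),
            return_elements=["xin", "yin", "varm", "varl",
                             "mulop", "addop", "cost", "optimizer"])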

Stack trace:

Traceback (most recent call last): 
    File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1039, in _do_call 
    return fn(*args) 
    File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1017, in _run_fn 
    self._extend_graph() 
    File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1066, in _extend_graph 
    self._session, graph_def.SerializeToString(), status) 
    File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/contextlib.py", line 66, in __exit__ 
    next(self.gen) 
    File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status 
    pywrap_tensorflow.TF_GetCode(status)) 
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot colocate nodes 'import/varl/read' and 'import/varl: Cannot merge devices with incompatible jobs: '/job:ps/task:1' and '/job:worker/task:1' 
    [[Node: import/varl/read = Identity[T=DT_FLOAT, _class=["loc:@import/varl"], _device="/job:worker/task:1"](import/varl)]] 

Or is there any other way that TensorFlow allows this?

Answer

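As of June 2017 this is not supported. To train the model in a distributed environment, you can reuse the Python code that builds the graph (by wrapping it in replica_device_setter), rather than the generated graph itself.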
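
A minimal sketch of that suggestion, assuming the model from the question is moved into a function. The names build_model, build_single_node and build_for_worker are hypothetical helpers, the localhost addresses are placeholders, and the `tensorflow.compat.v1` import assumes a TensorFlow release that ships that module. The same function is called once without a device scope and once inside replica_device_setter, so placement is decided while the ops are created rather than after the fact:

```python
import tensorflow.compat.v1 as tf  # TF1-style graph API (assumed available)


def build_model():
    """Create the ops from the question in whatever graph/device scope is active."""
    x = tf.placeholder(dtype=tf.float32, shape=[], name="xin")
    y = tf.placeholder(dtype=tf.float32, shape=[], name="yin")
    m = tf.Variable(10.0, name="varm")
    l = tf.Variable(20.0, name="varl")
    Y = tf.multiply(m, x, name="mulop")
    X = tf.add(l, x, name="addop")
    cost = tf.abs(Y - X, name="cost")
    train_op = tf.train.GradientDescentOptimizer(
        learning_rate=0.5, name="optimizer").minimize(cost)
    return {"xin": x, "yin": y, "varm": m, "varl": l,
            "cost": cost, "train_op": train_op}


def build_single_node():
    """Non-distributed case: build the model with no device function."""
    graph = tf.Graph()
    with graph.as_default():
        model = build_model()
    return graph, model


def build_for_worker(cluster, task_index):
    """Distributed case: run the same code under replica_device_setter."""
    graph = tf.Graph()
    with graph.as_default():
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):
            model = build_model()
    return graph, model


# Hypothetical cluster layout; the addresses are placeholders, and
# constructing a ClusterSpec does not open any connections.
cluster = tf.train.ClusterSpec({
    "ps": ["localhost:2222", "localhost:2223"],
    "worker": ["localhost:2224", "localhost:2225"],
})

graph, model = build_for_worker(cluster, task_index=1)
for name in ("varm", "varl", "cost"):
    # Print where each piece of the model was placed.
    print(name, model[name].device)
```

Actually training would still need a tf.train.Server per task and a session connected to it, which this sketch leaves out, just as the question's code does.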