2016-03-16 90 views
1

我想用TensorFlow使用Python線程來實現異步梯度下降。在主代碼,我定義圖表,包括訓練操作,它得到一個變量來保持global_step的計數:在Tensorflow中的線程之間共享變量

with tf.variable_scope("scope_global_step") as scope_global_step: 
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 
optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate) 
train_op = optimizer.minimize(loss, global_step=global_step) 

如果我打印的global_step的名字,我得到:

scope_global_step/global_step:0

主要的代碼也可以啓動多個線程執行training方法:

threads = [threading.Thread(target=training, args=(sess, train_op, loss, scope_global_step)) for i in xrange(NUM_TRAINING_THREADS)] 
for t in threads: t.start() 

我想每個線程如果該值停止執行global_step大於或等於FLAGS.max_steps。爲此,我建立training方法,因爲它如下:

def training(sess, train_op, loss, scope_global_step): 
    while (True): 
     _, loss_value = sess.run([train_op, loss]) 
     with tf.variable_scope(scope_global_step, reuse=True): 
      global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 
      global_step = global_step.eval(session=sess) 
      if global_step >= FLAGS.max_steps: break 

這失敗消息:

ValueError: Under-sharing: Variable scope_global_step/global_step does not exist, disallowed. Did you mean to set reuse=None in VarScope?

我可以看到,:0被加到變量的名稱是首次創建時,當我嘗試檢索它時,不使用該後綴。爲什麼是這樣? 如果我手動將後綴添加到變量的名稱,當我嘗試檢索它時,它仍然聲稱該變量不存在。爲什麼TensorFlow找不到變量?不應該自動在線程間共享變量嗎?我的意思是,所有線程都在同一個會話中運行,對吧?

而且關係到我training方法了另一個問題:global_step.eval(session=sess)再次執行圖表,或者它只是獲取train_oploss操作的執行後分配到gloabl_step價值?一般來說,從Python代碼中使用變量獲取值的推薦方法是什麼?

回答

1

TL; DR:傳遞您的第一個代碼片段的培訓螺紋參數中創建的global_steptf.Variable對象,並呼籲傳入的變量sess.run(global_step)

作爲一般規則,您的訓練循環(尤其是單獨線程中的訓練循環)不應修改圖形。 tf.variable_scope()上下文管理器和tf.get_variable()可以修改圖(即使它們不總是),所以你不應該在你的訓練循環中使用它們。最安全的做法是在創建訓練線程時將global_step對象(您首先創建的對象)作爲args元組之一。然後,你可以簡單地重寫你的訓練功能:

def training(sess, train_op, loss, global_step): 
    while (True): 
     _, loss_value = sess.run([train_op, loss]) 
     current_step = sess.run(global_step) 
     if current_step >= FLAGS.max_steps: break 

爲了回答您的其他問題,運行global_step.eval(session=sess)sess.run(global_step)只取了global_step變量的當前值,並且不會重新執行您的圖形的其餘部分。這是獲取tf.Variable值以供在Python代碼中使用的推薦方式。

+0

謝謝@mrry。你的解決方案當然更清潔。然而,我仍然想知道爲什麼'tf.get_variable()'找不到變量。你能解釋爲什麼這樣嗎?謝謝! – nicolas

+0

我認爲這實際上失敗了,因爲在另一個線程中運行'tf.variable_scope()'實際上引用了另一個'tf。Graph'實例來自您最初創建變量的實例。如果你使用了'with sess.graph.as_default(),tf.variable_scope(「scope_global_step」,reuse = True):'它會,但只有當你有一個訓練線程**時。該圖不是用於寫入的線程安全的,並且輸入變量範圍會導致一些圖形內部的數據結構被更新,因此您絕對不應該這樣做:)。 – mrry

+0

再次感謝@mrry。我想我學到了一兩件重要的事情:-) – nicolas