2016-05-06 57 views
5

中工作我正在嘗試使用tf.scatter_update()更新tf.while_loop()中的tf.Variable。但是,結果是初始值而不是更新後的值。這裏是什麼,我試圖做的示例代碼:tf.scatter_update()如何在while_loop()

from __future__ import print_function 

import tensorflow as tf 

def cond(sequence_len, step): 
    return tf.less(step,sequence_len) 

def body(sequence_len, step): 

    begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0)) 
    begin = tf.scatter_update(begin,1,step,use_locking=None) 

    tf.get_variable_scope().reuse_variables() 
    return (sequence_len, step+1) 

with tf.Graph().as_default(): 

    sess = tf.Session() 
    step = tf.constant(0) 
    sequence_len = tf.constant(10) 
    _,step, = tf.while_loop(cond, 
        body, 
        [sequence_len, step], 
        parallel_iterations=10, 
        back_prop=True, 
        swap_memory=False, 
        name=None) 

    begin = tf.get_variable("begin",[3],dtype=tf.int32) 

    init = tf.initialize_all_variables() 
    sess.run(init) 

    print(sess.run([begin,step])) 

結果是:[array([0, 0, 0], dtype=int32), 10]。但是,我認爲結果應該是[0, 0, 10]。我在這裏做錯了什麼?

回答

6

這裏的問題是循環體中沒有任何內容取決於您的tf.scatter_update() op,所以它永遠不會執行。使它工作的最簡單方法是在更新中添加控制依賴於返回值:

def body(sequence_len, step): 
    begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0)) 
    begin = tf.scatter_update(begin, 1, step, use_locking=None) 
    tf.get_variable_scope().reuse_variables() 

    with tf.control_dependencies([begin]): 
     return (sequence_len, step+1) 

請注意,這個問題是不是唯一的TensorFlow到循環。如果您剛剛定義了一個名爲begintf.scatter_update() op,但對其調用sess.run()或其他依賴於它的內容,則不會發生更新。當您使用tf.while_loop()時,無法直接運行循環體中定義的操作,因此獲得副作用的最簡單方法是添加控件依賴項。

注意,最後的結果是[0, 9, 0]:每次迭代分配當前步驟begin[1],並在最後一次迭代的當前步驟的值是9(條件爲假時step == 10)。

+1

謝謝澄清。我不知道這件事。 – akshaybetala