2017-07-13 36 views
1

我想在tensorflow中創建一個變量,然後在tf.scan中進行更新。首先我試過這樣的事情:如何在tensorflow掃描中更新vaiables

import tensorflow as tf 

with tf.variable_scope('foo'): 
    tf.get_variable('bar', initializer=tf.zeros([1.0])) 

def repeat_me(last, current): 
    with tf.variable_scope('foo', reuse=True): 
     bar = tf.get_variable('bar') 
     bar.assign_add(tf.constant([1.0])) 
    return bar 
output = tf.scan(repeat_me, tf.range(5), initializer=tf.constant([1.0])) 

with tf.Session() as sess: 
    init_op = tf.global_variables_initializer() 
    sess.run(init_op) 
    out = sess.run(output) 
    print(out) 
    with tf.variable_scope('foo', reuse=True): 
     print(tf.get_variable('bar').eval()) 

這似乎沒有更新名稱爲'bar'的變量。

[[ 0.] 
[ 0.] 
[ 0.] 
[ 0.] 
[ 0.]] 
[ 0.] 

對我來說奇怪的是以下修改的「repeat_me」函數改變了行爲。

def repeat_me(last, current): 
    with tf.variable_scope('foo', reuse=True): 
     bar = tf.get_variable('bar') 
     b = bar.assign_add(tf.constant([1.0])) 
    return b 

然後,腳本會這樣:

[[ 5.] 
[ 5.] 
[ 5.] 
[ 5.] 
[ 5.]] 
[ 5.] 

任何人都可以解釋的區別?

回答

0

document說:

返回:

等同於 「參考」。作爲在變量更新之後想要使用新值的操作的便利返回。

bar.assign_add(tf.constant([1.0]))返回一個操作,並且您需要捕獲它以便掃描功能可以使用它。在您的第一個版本中,操作不會被保存或返回,因此您從repeat_me獲得的內容仍然是原始條形碼。

換句話說,這也將工作:

with tf.variable_scope('foo', reuse=True): 
    bar = tf.get_variable('bar') 
    return bar.assign_add(tf.constant([1.0]))