1
我想在tensorflow中創建一個變量,然後在tf.scan中進行更新。首先我試過這樣的事情:如何在tensorflow掃描中更新vaiables
import tensorflow as tf
with tf.variable_scope('foo'):
tf.get_variable('bar', initializer=tf.zeros([1.0]))
def repeat_me(last, current):
with tf.variable_scope('foo', reuse=True):
bar = tf.get_variable('bar')
bar.assign_add(tf.constant([1.0]))
return bar
output = tf.scan(repeat_me, tf.range(5), initializer=tf.constant([1.0]))
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
out = sess.run(output)
print(out)
with tf.variable_scope('foo', reuse=True):
print(tf.get_variable('bar').eval())
這似乎沒有更新名稱爲'bar'的變量。
[[ 0.]
[ 0.]
[ 0.]
[ 0.]
[ 0.]]
[ 0.]
對我來說奇怪的是以下修改的「repeat_me」函數改變了行爲。
def repeat_me(last, current):
with tf.variable_scope('foo', reuse=True):
bar = tf.get_variable('bar')
b = bar.assign_add(tf.constant([1.0]))
return b
然後,腳本會這樣:
[[ 5.]
[ 5.]
[ 5.]
[ 5.]
[ 5.]]
[ 5.]
任何人都可以解釋的區別?