中工作我正在嘗試使用tf.scatter_update()
更新tf.while_loop()
中的tf.Variable
。但是,結果是初始值而不是更新後的值。這裏是什麼,我試圖做的示例代碼:tf.scatter_update()如何在while_loop()
from __future__ import print_function
import tensorflow as tf
def cond(sequence_len, step):
return tf.less(step,sequence_len)
def body(sequence_len, step):
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin,1,step,use_locking=None)
tf.get_variable_scope().reuse_variables()
return (sequence_len, step+1)
with tf.Graph().as_default():
sess = tf.Session()
step = tf.constant(0)
sequence_len = tf.constant(10)
_,step, = tf.while_loop(cond,
body,
[sequence_len, step],
parallel_iterations=10,
back_prop=True,
swap_memory=False,
name=None)
begin = tf.get_variable("begin",[3],dtype=tf.int32)
init = tf.initialize_all_variables()
sess.run(init)
print(sess.run([begin,step]))
結果是:[array([0, 0, 0], dtype=int32), 10]
。但是,我認爲結果應該是[0, 0, 10]
。我在這裏做錯了什麼?
謝謝澄清。我不知道這件事。 – akshaybetala