The model does not appear to train because the input reading, the gradient calculation, and the minimize() call are all defined outside (and hence, in dataflow terms, before) the body of the tf.while_loop(). This means that these parts of the model run only once, before the loop executes, and the loop itself has no effect.
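For reference, here is a minimal sketch of the broken structure, with names assumed from the question's code (q_x, q_y, w, b, gs, and optimizer defined as in the full program below):

# BROKEN SKETCH: all training ops are created outside the loop body.
x = q_x.dequeue()
y = q_y.dequeue()
loss = (tf.add(tf.multiply(x, w), b) - y) ** 2
train_op = optimizer.minimize(loss, global_step=gs)

def body(i):
    # The body only increments the counter; nothing in here depends on
    # train_op, so no training step ever runs inside the loop.
    return tf.add(i, 1)

loop = tf.while_loop(lambda i: i < 10, body, [i])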
A slight refactoring, moving the dequeue() operation, the gradient calculation, and the minimize() call inside the loop, fixes the problem and allows the program to train:
optimizer = tf.train.GradientDescentOptimizer(0.05)

def cond(i):
    return i < 10

def body(i):
    # Dequeue a new example each iteration.
    x = q_x.dequeue()
    y = q_y.dequeue()
    # Compute the loss and gradient update based on the current example.
    loss = (tf.add(tf.multiply(x, w), b) - y) ** 2
    train_op = optimizer.minimize(loss, global_step=gs)
    # Ensure that the update is applied before continuing.
    return tf.tuple([tf.add(i, 1)], control_inputs=[train_op])

loop = tf.while_loop(cond, body, [i])
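Here tf.tuple() with control_inputs=[train_op] makes the counter increment wait until the training op has run; the complete program below achieves the same effect with a tf.control_dependencies() block.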
UPDATE: Here is a complete program that executes the while loop, based on the code in your question:
import tensorflow as tf

# Define a single queue with two components to store the input data.
q_data = tf.FIFOQueue(100000, [tf.float32, tf.float32])

# We will use these placeholders to enqueue input data.
placeholder_x = tf.placeholder(tf.float32, shape=[None])
placeholder_y = tf.placeholder(tf.float32, shape=[None])
enqueue_data_op = q_data.enqueue_many([placeholder_x, placeholder_y])

gs = tf.Variable(0)
w = tf.Variable(0.)
b = tf.Variable(0.)
optimizer = tf.train.GradientDescentOptimizer(0.05)

# Construct the while loop.
def cond(i):
    return i < 10

def body(i):
    # Dequeue a single new example each iteration.
    x, y = q_data.dequeue()
    # Compute the loss and gradient update based on the current example.
    loss = (tf.add(tf.multiply(x, w), b) - y) ** 2
    train_op = optimizer.minimize(loss, global_step=gs)
    # Ensure that the update is applied before continuing.
    with tf.control_dependencies([train_op]):
        return i + 1

loop = tf.while_loop(cond, body, [tf.constant(0)])

data = [k * 1. for k in range(10)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(1):
        # NOTE: Constructing the enqueue op ahead of time avoids adding
        # (potentially many) copies of `data` to the graph.
        sess.run(enqueue_data_op,
                 feed_dict={placeholder_x: data, placeholder_y: data})
    print(sess.run([gs, w, b]))  # Prints before-loop values.
    sess.run(loop)
    print(sess.run([gs, w, b]))  # Prints after-loop values.
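If this runs as intended, the second print should show gs equal to 10 (one minimize() step per loop iteration), with w and b nudged toward fitting y = x on the enqueued data.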
Should I define **w** and **b** outside? That is what I was trying (and now I have tried what you provided), but I get the error *All inputs to node while/GradientDescent/update_while/w/ApplyGradientDescent must be from the same frame.* –
I have added the complete program that I ran with TensorFlow 0.10rc0. (You may need to upgrade; there were various bugs in the tf.while_loop() implementation that were fixed in the last few versions.) – mrry
Yes, I started out on 0.9; thanks, after updating it works! One more question about your solution: it looks like new optimizer ops are created at each step. What if I want to use the Ftrl optimizer (which has some update slots)? Will it behave like a single optimizer over the course of training? –