2016-01-20 22 views
2

我一直試圖爲TensorFlow中的GradientDescentOptimizer的每個步驟收集漸變步驟,但是當我試圖將apply_gradients()的結果傳遞給sess.run()時,我仍然遇到TypeError。我試圖運行的代碼是:無法爲TensorFlow中的GradientDescentOptimizer收集梯度

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 

x = tf.placeholder(tf.float32,[None,784]) 
W = tf.Variable(tf.zeros([784,10])) 
b = tf.Variable(tf.zeros([10])) 
y = tf.nn.softmax(tf.matmul(x,W)+b) 
y_ = tf.placeholder(tf.float32,[None,10]) 
cross_entropy = -tf.reduce_sum(y_*log(y)) 

# note that up to this point, this example is identical to the tutorial on tensorflow.org 

gradstep = tf.train.GradientDescentOptimizer(0.01).compute_gradients(cross_entropy) 

sess = tf.Session() 
sess.run(tf.initialize_all_variables()) 
batch_x,batch_y = mnist.train.next_batch(100) 
print sess.run(gradstep, feed_dict={x:batch_x,y_:batch_y}) 

需要注意的是,如果我有print sess.run(train_step,feed_dict={x:batch_x,y_:batch_y}),其中train_step = tf.GradientDescentOptimizer(0.01).minimize(cross_entropy),錯誤是沒有提出替代的最後一行。我的困惑源於minimize調用compute_gradients與第一步完全相同的論點。有人可以解釋爲什麼發生這種行爲?

回答

11

Optimizer.compute_gradients()方法返回的(TensorVariable)對,其中,每個張量是相對於相應的變量的梯度的列表。

Session.run()需要一個Tensor對象(或可轉換爲Tensor的對象)的列表作爲其第一個參數。它不知道如何處理對列表,因此您得到一個TypeError,您嘗試運行sess.run(gradstep, ...)

正確的解決方案取決於您正在嘗試執行的操作。如果你想獲取所有的梯度值,你可以做到以下幾點:

grad_vals = sess.run([grad for grad, _ in gradstep], feed_dict={x: batch_x, y: batch_y}) 

# Then, e.g., nuild a variable name-to-gradient dictionary. 
var_to_grad = {} 
for grad_val, (_, var) in zip(grad_vals, gradstep): 
    var_to_grad[var.name] = grad_val 

如果你也想獲取的變量,可以分別執行以下語句:

sess.run([var for _, var in gradstep]) 

.. 。雖然注意—沒有進一步修改您的程序—這隻會返回每個變量的初始值。 您將不得不運行優化程序的培訓步驟(或調用Optimizer.apply_gradients())來更新變量。

1

最小化調用compute_gradients後跟apply_gradients:可能你錯過了第二步。

compute_gradients只是返回grads /變量,但不會將更新規則應用於它們。

下面是一個例子:https://github.com/tensorflow/tensorflow/blob/f2bd0fc399606d14b55f3f7d732d013f32b33dd5/tensorflow/python/training/optimizer.py#L69

+0

嗨,謝謝你的回覆。我並不想盡量減少優化器,我只是試圖在每一步打印出漸變。之所以我提出最小化計數器例子,是因爲它涉及調用compute_gradients,所以人們會認爲用相同的參數調用該函數也會產生錯誤。 –

相關問題