2016-01-13 70 views
4

我試圖使用tf.Print調試語句來更好地理解來自compute_gradients()的報告梯度和變量的格式,但遇到意外問題。此次培訓程序和調試程序(gvdebug)如下:tf.Print()導致梯度錯誤

def gvdebug(g, v): 
    #g = tf.Print(g,[g],'G: ') 
    #v = tf.Print(v,[v],'V: ') 
    g2 = tf.zeros_like(g, dtype=tf.float32) 
    v2 = tf.zeros_like(v, dtype=tf.float32) 
    g2 = g 
    v2 = v 
    return g2,v2 

# Define training operation 
def training(loss, global_step, learning_rate=0.1): 
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 
    grads_and_vars = optimizer.compute_gradients(loss) 
    gv2 = [gvdebug(gv[0], gv[1]) for gv in grads_and_vars] 
    train_op = optimizer.apply_gradients(gv2, global_step=global_step) 
    return train_op 

此代碼工作正常(但不打印),但如果我去掉兩個tf.Print線gvdebug()我得到一個來自apply_gradients的錯誤消息:'TypeError:Variable must be a tf.Variable'。我認爲tf.Print只是通過張量傳遞 - 我做錯了什麼?

回答

2

TL; DR

不要試圖tf.Printgv[1]因爲它是一個tf.Variable。它就像一個指向創建gv[0]中的gradient的變量的指針。

更多信息

當你運行compute_gradients它返回的gradients列表及其相應tf.Variables

grads_and_vars的每個元素是Tensortf.Variable。重要的是要注意,它是而不是變量的值。

您的代碼在刪除後適用於我v = tf.Print(v,[v],'V: ')