2016-06-29 82 views
3

我想基於損失函數w.r.t的梯度值實現停止條件。重量。 例如,假設我有這樣的事情:梯度值tensorflow的停止條件

optimizer = tf.train.AdamOptimizer() 
grads_and_vars = optimizer.compute_gradients(a_loss_function) 
train_op = optimizer.apply_gradients(grads_and_vars) 

話,我想用這樣的運行圖:

for step in range(TotSteps): 
    output = sess.run([input], feed_dict=some_dict) 
    if(grad_taken_in_some_way < some_treshold): 
     print("Training finished.") 
     break 

我不知道我應該傳遞給SESS。運行()爲了得到輸出也漸變(除了我需要的所有其他東西)。我甚至不確定這是否是正確的方法,或者我應該採取不同的方式。我做了一些嘗試,但每次都失敗了。希望有人有一些提示。 提前謝謝!

編輯:英文校

EDIT2:回答Iballes正是我想做的事。不過,我不確定如何規範和總結所有的漸變。由於我在CNN中有不同的圖層,並且具有不同形狀的不同權重,如果我只是按照您的建議進行操作,則會在add_n()操作中遇到錯誤(因爲我正嘗試將不同形狀的矩陣相加)。所以可能我應該這樣做:

grad_norms = [tf.nn.l2_normalize(g[0], 0) for g in grads_and_vars]  
grad_norm = [tf.reduce_sum(grads) for grads in grad_norms] 
final_grad = tf.reduce_sum(grad_norm) 

任何人都可以證實這一點嗎?

回答

2

您的線路output = sess.run([input], feed_dict=some_dict)讓您認爲您對sess.run命令有點誤解。你所說的[input]應該是由sess.run命令提取的張量列表。因此,這是一種輸出而不是輸入。爲了解決你的問題,我們假設你正在做類似output = sess.run(loss, feed_dict=some_dict)的事情(爲了監控訓練損失)。

另外,我想你要制定你的停止標準使用標準的梯度(梯度本身是一個多維量)。因此,你想要做的是每次執行圖形時獲取漸變的規範。爲此,你必須做兩件事。 1)將梯度範數添加到計算圖。 2)在您的訓練循環中每次撥打電話sess.run取回。

廣告1)您已通過

optimizer = tf.train.AdamOptimizer() 
grads_and_vars = optimizer.compute_gradients(a_loss_function) 

添加的梯度以曲線圖,現在有保持梯度grads_and_vars(每個訓練變量圖中的張量)。讓我們把每個梯度的規範,然後總結:

grad_norms = [tf.nn.l2_loss(g) for g, v in grads_and_vars] 
grad_norm = tf.add_n(grad_norms) 

在那裏你有你的梯度規範。

廣告2)在你的循環中,通過告訴sess.run命令來獲取漸變規範以及損失。

for step in range(TotSteps): 
    l, gn = sess.run([loss, grad_norm], feed_dict=some_dict) 
    if(gn < some_treshold): 
     print("Training finished.") 
     break 
+0

對此答案的任何想法@AndreaPavone? – lballes

+0

這正是我想要做的。不過,我不確定如何規範和總結所有的漸變。由於我在CNN中有不同的圖層,並且具有不同形狀的不同權重,如果我只是按照您的建議進行操作,則會在add_n()操作中遇到錯誤(因爲我正嘗試將不同形狀的矩陣相加)。編輯問題中的更多細節。謝謝! –

+0

哦,我對張量流函數名稱感到困惑。我想要做的是''grad_norms = [tf.nn.l2_loss(g)for g,v in grads_and_vars]''。我原來的答案使用了一個函數''tf.nn.l2_norm'',它不存在,我做了**而不是**意味着''tf.nn.l2_normalize''。函數''tf.nn.l2_loss''取一個向量的平方L2範數,並因此產生一個標量,而不管該張量的尺寸如何。然後可以將這些標量加在一起來計算總體梯度的L2範數。 (更正了答案。) – lballes