2017-06-02 58 views
0

我想在Torch中創建一個自定義丟失函數,這是對ClassNLLCriterion的修改。具體而言,ClassNLLCriterion損耗是:修改火炬標準

loss(x, class) = -x[class] 

我想修改這是:

loss(x, class) = -x[class]*K 

其中K是網絡輸入的功能,而不是網絡權重或網絡輸出。因此K可以被視爲一個常數。

什麼是實現此自定義條件的最簡單的方法? updateOutput()函數看起來很簡單,但我該如何修改updateGradInput()函數?

回答

1

基本上你的損失函數L是輸入和目標的功能。所以你有

loss(input, target) = ClassNLLCriterion(input, target) * K 

如果我正確理解你的新損失。那麼要實現updateGradInput這對於返回你的損失函數的導數的輸入,這是

updateGradInput[ClassNLLCriterion](input, target) * K + ClassNLLCriterion(input, target) * dK/dinput 

因此,你只需要計算ķWRT的衍生損失函數的輸入(你沒有給我們計算K)的公式並將其插入前一行。由於您的新損失函數依賴於ClassNLLCriterion,因此您可以使用此損失函數的updateGradInputupdateOutput來計算您的損失函數。

+0

所以基本上我不必編寫自定義標準。在我的訓練碼中,我可以簡單地做: 'loss = ClassNLLCriterion:forward()* K'然後 'grad = ClassNLLCriterion:backward()* K + loss *(dK/dinput)' 這是正確的嗎? – braindead

+0

是的,這也是可能的 – fonfonx

+0

太棒了。謝謝!還有一個問題,如果K只是一個常量(不依賴於網絡參數或輸入或輸出),那麼在這種情況下你的答案會如何變化? – braindead