2017-03-06 69 views
0

我目前使用的是Torch 7,我需要自定義丟失函數,特別是交叉熵錯誤函數。我如何定製火炬中的損失功能?

我想添加一些交叉熵錯誤函數的參數,我找不到應修改的部分。

我看了CrossEntropyCriterion.lua,但仍然不知道方式,因爲我沒有看到這個文件中的任何方程式。

誰能告訴我方程在哪裏?或者我應該修改哪個文件?

+0

的[添加我的自定義損失函數火炬]可能的複製(http://stackoverflow.com/questions/33648796/add -my-custom-loss-function-to-torch) –

回答

0

爲了定製損失函數,您必須更改方法__init,updateOutputupdateGradInput

  • __init是類的初始化功能,當您使用:forward()方法在你的標準當您使用:backward()
  • updateGradInput會叫
  • updateOutput將被調用,它是你的標準的梯度

定製標準的結構如下所示:

local yourCriterion, parent = torch.class('nn.yourCriterion', 'nn.Criterion') 

function yourCriterion:__init(your_parameters): 
    parent.__init(self) 
    ... (you can add as many parameters as you want to your criterion 
     and give them the name your prefer) 
    self.parameters = your_parameters 


function yourCriterion:updateOutput(input) 
    ... (your criterion code here) 
    return value_of_the_criterion 

function yourCriterion:updateGradInput(input): 
    ... (your criterion gradient code here) 
    return gradient 

[編輯]:你可以在這裏找到交叉熵準則的代碼https://github.com/torch/nn/blob/master/CrossEntropyCriterion.lua

+0

對不起,我忘了把鏈接放到交叉熵標準中。我編輯了我的答案來添加它。 –

+0

感謝您的回答:)我會盡力的 –