2017-03-19 103 views
0

我有一個用Keras構建的神經網絡,我試圖訓練。輸出層有4個節點。對於我試圖解決的問題,我只想根據真實值在單個輸出節點上計算梯度。基本上,y_true看起來像這樣[0,0,2,0],其中零表示應該忽略的節點。然而,y_pred的形式是[1.2,3.2,4.5,6]。我想這樣做,只有第三個指標在mse中被考慮到。這將要求我在y_pred中清零索引0,1和3。我還沒有找到適當的方法來做到這一點。Keras/Tensorflow自定義丟失函數

下面是我試過的代碼,但是它從失效函數中返回NaN。

def custom_mse(y_true, y_pred): 

    return K.mean(K.square(tf.truediv(y_pred*y_true,y_true)-y_pred), axis=-1) 

有沒有辦法在這些張量對象上做這個簡單的操作?

回答

3

做這樣的:

[1.2,3.2,4.5,6]*[0,0,2,0] = [0,0,9,0] 
[0,0,9,0]/2 = [0,0,4.5,0] 

,然後繼續正常。

這是代碼做到這一點:

def custom_mse(y_true, y_pred): 
    return K.mean(K.square(tf.divide(tf.multiply(y_pred, y_true),tf.reduce_max(y_true))-y_true), axis=-1)