我想修改以下keras均方誤差損失(MSE),以便只計算稀疏損失。如何在Keras中實現稀疏均方誤差損失
def mean_squared_error(y_true, y_pred): return K.mean(K.square(y_pred - y_true), axis=-1)
我的輸出y
是一個3通道圖像,其中,所述第三信道是非零在只有那些損失要計算的像素。任何想法如何修改上述計算稀疏損失?
我想修改以下keras均方誤差損失(MSE),以便只計算稀疏損失。如何在Keras中實現稀疏均方誤差損失
def mean_squared_error(y_true, y_pred): return K.mean(K.square(y_pred - y_true), axis=-1)
我的輸出y
是一個3通道圖像,其中,所述第三信道是非零在只有那些損失要計算的像素。任何想法如何修改上述計算稀疏損失?
這不是你正在尋找確切的損失,但我希望它會給你一個提示,寫你的函數:
def masked_mse(mask_value):
def f(y_true, y_pred):
mask_true = K.cast(K.not_equal(y_true, mask_value), K.floatx())
masked_squared_error = K.square(mask_true * (y_true - y_pred))
masked_mse = (K.sum(masked_squared_error, axis=-1)/
K.sum(mask_true, axis=-1))
return masked_mse
f.__name__ = 'Masked MSE (mask_value={})'.format(mask_value)
return f
的函數計算在預測輸出的所有值的MSE損失,除了那些在真實輸出中的相應值等於掩蔽值(例如-1)的元素。
有兩點需要注意: - 計算平均值的分母必須是非屏蔽值的數量,而不是陣列的 尺寸時,這就是爲什麼我不使用K.mean(masked_squared_error, axis=1)
,我 而不是手動平均。 - 掩碼值必須是有效的數字(即np.nan
或np.inf
不會執行此作業),這意味着您必須調整數據以使其不包含mask_value
。
在此示例中,目標輸出始終爲[1, 1, 1, 1]
,但某些預測值會逐漸被屏蔽。
y_pred = K.constant([[ 1, 1, 1, 1],
[ 1, 1, 1, 3],
[ 1, 1, 1, 3],
[ 1, 1, 1, 3],
[ 1, 1, 1, 3],
[ 1, 1, 1, 3]])
y_true = K.constant([[ 1, 1, 1, 1],
[ 1, 1, 1, 1],
[-1, 1, 1, 1],
[-1,-1, 1, 1],
[-1,-1,-1, 1],
[-1,-1,-1,-1]])
true = K.eval(y_true)
pred = K.eval(y_pred)
loss = K.eval(masked_mse(-1)(y_true, y_pred))
for i in range(true.shape[0]):
print(true[i], pred[i], loss[i], sep='\t')
預期輸出是:
[ 1. 1. 1. 1.] [ 1. 1. 1. 1.] 0.0
[ 1. 1. 1. 1.] [ 1. 1. 1. 3.] 1.0
[-1. 1. 1. 1.] [ 1. 1. 1. 3.] 1.33333
[-1. -1. 1. 1.] [ 1. 1. 1. 3.] 2.0
[-1. -1. -1. 1.] [ 1. 1. 1. 3.] 4.0
[-1. -1. -1. -1.] [ 1. 1. 1. 3.] nan