2016-12-18 85 views
0

我正在實施對U-net進行語義分割的修改。Keras中的多個輸出給出了值錯誤

我有從網絡兩個輸出:

model = Model(input=inputs, output= [conv10, dense3]) 
    model.compile(optimizer=Adam(lr=1e-5), loss=common_loss, metrics=[common_loss]) 

其中常見的損失被定義爲:

def common_loss(y_true, y_pred): 
    segmentation_loss = categorical_crossentropy(y_true[0], y_pred[0]) 
    classifiction_loss = categorical_crossentropy(y_true[1], y_pred[1]) 
    return segmentation_loss + alpha * classifiction_loss 

當運行此我得到一個值誤差爲:

File "y-net.py", line 138, in <module> 
    train_and_predict() 
File "y-net.py", line 133, in train_and_predict 
    callbacks=[model_checkpoint], validation_data=(X_val, [y_img_val, y_class_val])) 
    File "/home/gpu_users/meetshah/miniconda2/envs/check/lib/python2.7/site-packages/keras/engine/training.py", line 1124, in fit 
    callback_metrics=callback_metrics) 
    File "/home/gpu_users/meetshah/miniconda2/envs/check/lib/python2.7/site-packages/keras/engine/training.py", line 848, in _fit_loop 
    callbacks.on_batch_end(batch_index, batch_logs) 
    File "/home/gpu_users/meetshah/miniconda2/envs/check/lib/python2.7/site-packages/keras/callbacks.py", line 63, in on_batch_end 
    callback.on_batch_end(batch, logs) 
    File "/home/gpu_users/meetshah/miniconda2/envs/check/lib/python2.7/site-packages/keras/callbacks.py", line 191, in on_batch_end 
    self.progbar.update(self.seen, self.log_values) 
    File "/home/gpu_users/meetshah/miniconda2/envs/check/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 147, in update 
    if abs(avg) > 1e-3: 
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() 

我的實施和整個跟蹤可以在這裏找到:

https://gist.github.com/meetshah1995/19d54270e8d1b20f814e6c1495facc6a

+0

我從'model.compile'中刪除了度量標準,它工作。顯然keras不支持多輸入指標。 –

回答

1

你可以看到如何實現多個指標與多個輸出這裏:https://github.com/EdwardTyantov/ultrasound-nerve-segmentation/blob/master/u_model.py

model.compile(optimizer=optimizer, 
       loss={'main_output': dice_coef_loss, 'aux_output': 'binary_crossentropy'}, 
       metrics={'main_output': dice_coef, 'aux_output': 'acc'}, 
       loss_weights={'main_output': 1., 'aux_output': 0.5}) 

我不確定是否支持組合輸出度量標準。

+0

多個指標在keras github回購中仍然是一個開放性問題。 –