2017-08-16 17 views
1

我已經定義了使用sequence_loss添加到eval_metric_ops定製損失

loss = tf.contrib.legacy_seq2seq.sequence_loss(logits, labels, weights) 

我希望把它添加到eval_metric_ops讓我ML引擎包我可以連續顯示tensorboard評估損失(默認我自己的損失函數只是精度)。我嘗試添加這是一個自定義eval_metric_ops

eval_metric_ops = { 
    'loss': loss # this has already been coputed for Modes.EVAL 
} 

但是,我得到的錯誤 「類型錯誤:eval_metric_ops的值必須是(metric_value,update_op)元組,給出:張量(」 sequence_loss/truediv: 0「,shape =(),dt ype = float32)for key:loss「 - 我需要做什麼才能將損失作爲eval_metric_op傳遞?我猜我目前的損失應該是度量值,但我不確定update_op應該是什麼?

回答

3

你的情況的度量函數可以使用tf.metrics使用損耗的移動平均線來實現:

def metric_fn(labels, predict): 
    loss = tf.contrib.legacy_seq2seq.sequence_loss(logits, labels, weights) 
    mean, op = tf.metrics.mean(loss) 
    return mean, op 
+0

我相信['tf.contrib.metrics.streaming_mean'(https://開頭WWW。 tensorflow.org/versions/r1.3/api_docs/python/tf/contrib/metrics/streaming_mean)函數會自動執行此操作。儘管如此,它還是有助於用戶注意的。 –

+0

@EliBixby與'tf.metrics.mean'不一樣嗎?我沒有檢查源代碼,但它具有相同的參數和返回值。 – marko