我在TensorFlow中用我自己的損失函數訓練多目標神經網絡,但找不到有關批處理如何與該功能交互的文檔。配料如何與TensorFlow中的損失功能交互?
例如,我在下面片斷我的損失函數,該函數預測的張量/列表,並確保他們的絕對值的和不超過一個的:
def fitness(predictions,actual):
absTensor = tf.abs(predictions)
sumTensor = tf.reduce_sum(absTensor)
oneTensor = tf.constant(1.0)
isGTOne = tf.greater(sumTensor,oneTensor)
def norm(): return predictions/sumTensor
def unchanged(): return predictions
predictions = tf.cond(isGTOne,norm,unchanged)
etc...
但是,當我通過一批估計,我覺得這個損失函數正在使整個輸入集合歸一化到1,而不是每個單獨的集合總和爲1。 [.8,.8],[.8,.8]] - > [[.25,.25],[.25,25]]
而不是所需的
[[.8, .8],[.8]] - > [[.5,.5],[.5,.5]]
有人可以澄清我的懷疑嗎?如果這是我的功能目前的工作方式,我該如何改變?
這很完美。謝謝。文檔中是否有關於此行爲的地方?我想通讀是否確保沒有其他意外事件出現 – liqiudilk
您在尋找文檔時特別是哪些行爲? [tf.select文檔](https://www.tensorflow.org/api_docs/python/control_flow_ops/comparison_operators#select)很有用。 – suharshs