2017-02-11 37 views
0

我在TensorFlow中用我自己的損失函數訓練多目標神經網絡,但找不到有關批處理如何與該功能交互的文檔。配料如何與TensorFlow中的損失功能交互?

例如,我在下面片斷我的損失函數,該函數預測的張量/列表,並確保他們的絕對值的和不超過一個的:

def fitness(predictions,actual): 

    absTensor = tf.abs(predictions) 
    sumTensor = tf.reduce_sum(absTensor) 
    oneTensor = tf.constant(1.0) 

    isGTOne = tf.greater(sumTensor,oneTensor) 

    def norm(): return predictions/sumTensor 
    def unchanged(): return predictions 

    predictions = tf.cond(isGTOne,norm,unchanged) 

    etc... 

但是,當我通過一批估計,我覺得這個損失函數正在使整個輸入集合歸一化到1,而不是每個單獨的集合總和爲1。 [.8,.8],[.8,.8]] - > [[.25,.25],[.25,25]]
而不是所需的
[[.8, .8],[.8]] - > [[.5,.5],[.5,.5]]

有人可以澄清我的懷疑嗎?如果這是我的功能目前的工作方式,我該如何改變?

回答

2

您必須爲減速操作指定減速軸,否則所有軸都將減小。傳統上,這是張量的第一維。所以,第2行應該看起來像這樣:

sumTensor = tf.reduce_sum(absTensor, 0) 

做出更改後,您將遇到另一個問題。 sumTensor將不再是一個標量,因此不再是有意義的tf.cond的條件(即對於批次的每個條目分支是什麼意思?)。你真正想要的是tf.select,因爲你並不是真的想爲每個批次條目分支邏輯。就像這樣:

isGTOne = tf.greater(sumTensor,oneTensor) 

norm = predictions/sumTensor 

predictions = tf.select(isGTOne,norm,predictions) 

但是,在這個現在看,我不會甚至懶得條件正常化的條目。由於您現在正在按批處理的粒度運行,因此我認爲您無法通過一次正常化批處理的條目來獲得性能。尤其是,由於分割並不是一個昂貴的副作用。也許只是這樣做:

def fitness(predictions,actual): 

    absTensor = tf.abs(predictions) 
    sumTensor = tf.reduce_sum(absTensor, 0) 

    predictions = predictions/sumTensor 

    etc... 

希望幫助!

+0

這很完美。謝謝。文檔中是否有關於此行爲的地方?我想通讀是否確保沒有其他意外事件出現 – liqiudilk

+0

您在尋找文檔時特別是哪些行爲? [tf.select文檔](https://www.tensorflow.org/api_docs/python/control_flow_ops/comparison_operators#select)很有用。 – suharshs