您可以使用tf.nn.top_k()
執行排序。這個函數返回一個元組,第二個元素是索引。由於訂單下降,其訂單必須逆轉。
def ginicTF(actual:tf.Tensor,pred:tf.Tensor):
n = int(actual.get_shape()[-1])
inds = tf.reverse(tf.nn.top_k(pred,n)[1],axis=[0]) # this is the equivalent of np.argsort
a_s = tf.gather(actual,inds) # this is the equivalent of numpy indexing
a_c = tf.cumsum(a_s)
giniSum = tf.reduce_sum(a_c)/tf.reduce_sum(a_s) - (n+1)/2.0
return giniSum/n
這裏是你可以使用驗證,這個函數返回相同的數值作爲numpy的功能ginic
代碼:
sess = tf.InteractiveSession()
ac = tf.placeholder(shape=(50,),dtype=tf.float32)
pr = tf.placeholder(shape=(50,),dtype=tf.float32)
actual = np.random.normal(size=(50,))
pred = np.random.normal(size=(50,))
print('numpy version: {:.4f}'.format(ginic(actual,pred)))
print('tensorflow version: {:.4f}'.format(ginicTF(ac,pr).eval(feed_dict={ac:actual,pr:pred})))
好吧,這看起來不錯,但傳遞給NN的損失時起作用它將返回一個錯誤 爲行:---> 14 N = INT(actual.get_shape()[ - 1]) 錯誤:類型錯誤:\ __ INT \ __返回非INT(型NoneType) 它按預期工作,如果我只是運行會話 – Ilya
我認爲這是因爲你的佔位符\ tensor對於'actual'的形狀是'(None,)',這意味着它沒有預定義的長度,因此'n'不能在圖表建設。在這種情況下,你可以做的只是傳遞'n'(數組的長度)作爲函數的附加參數,而不是計算它。 – Lior
好吧,我無法解決這個問題(嘗試給n的默認值,但這並沒有解決它)。我對這個特定的問題提出了一個新的問題https://stackoverflow.com/questions/46674293/custom-loss-function-in-keras-how-to-deal-with-placeholders 再次,感謝您寫下功能在TF! – Ilya