2016-12-15 118 views
1

的準確性度量我下面的TensorFlow文件,具體example 1,其中指標如下分配中的示例:實現在TensorFlow

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) 

我在想:我該如何實現不同的準確性度量。例如MAP(平均精度)。假設我有一個函數:

import numpy as np 

def accuracyMAP(y_pred, y_real): 

    def __precision_at_k(r, k): 
     r = np.asarray(r)[:k] != 0 
     return np.mean(r) 

    temp_sorted = y_real[np.argsort(-y_pred)] 
    till = np.where(temp_sorted==1) 
    r = temp_sorted[:till[0]+1] 
    return __precision_at_k(r, len(r)) 

長的路要走通過是做出預測,並通過它的mean_precision_scorce功能:

for batch_xs in x: 
    y_pred = sess.run(y, feed_dict={x: batch_xs}) 
    acc = accuracyMAP(y_true, y_pred) 

其中兩個y_true和y_predicted轉換爲NumPy的arrays.But有沒有類似於Tensorflow的例子嗎?任何提示?

回答

1

如果您的mean_precision_score函數可以在TF張量上運行,您可以簡單地執行:acc = mean_precision_score(y_true, y)及更高版本sess.run(acc, ...)

+0

你的意思是像'得分=精度(y_true,Y); mean_accuracy = tf.reduce_mean(評分)'我得到的誤差'輸入「步幅」「StridedSlice」運算具有Int32類型不匹配參數begin''的」的int64類型。試圖使用'tf.cast',但沒有找到一種方法來通過它。 –

+0

您需要顯示「精確度」函數的代碼,以便了解錯誤來自何處。 – sygi

+0

我編輯了我的問題 –