2017-08-10 122 views

回答

1

您可以使用內置的tensorflow功能precisionrecall

recall = tf.metrics.recall(labels, predictions, **kwargs) 
precision = tf.metrics.precision(labels, predcitions, **kwargs) 
+0

正如問題所要求的那樣,這些函數不會爲每個類分別計算度量標準。 – Avi

1

我相信TF不提供這樣的功能呢。根據文檔(https://www.tensorflow.org/api_docs/python/tf/metrics/precision),它表示標籤和預測都將轉換爲bool,因此它只涉及二進制分類。也許有可能對這些例子進行熱門編碼,它會起作用嗎?但不確定這一點。

+0

同樣,這些函數不會爲每個類別單獨計算度量,正如問題所要求的那樣。如果某些類比其他類更頻繁地出現在數據中,則這些度量將由那些頻繁的類所支配。通常需要的是爲每個類別計算單獨的召回率和精度,然後通過類別平均它們以獲得總體值(類似於tf.metrics.mean_per_class_accuracy)。這些值可能與使用tf.metrics.recall和tf.metrics.precision獲得的不平衡數據不同。 – Avi

+0

其實我錯了; 'tf.metrics.mean_per_class_accuracy'做了一些不同的事情,對這個問題不是很好的參考。 – Avi

2

我相信你不能用tf.metrics.precision/recall函數做多類精度,回憶,f1。您可以使用sklearn像這樣一類三方案:

from sklearn.metrics import precision_recall_fscore_support as score 

prediction = [1,2,3,2] 
y_original = [1,2,3,3] 

precision, recall, f1 = score(y_original, prediction) 

print('precision: {}'.format(precision)) 
print('recall: {}'.format(recall)) 
print('fscore: {}'.format(fscore)) 

這將打印的精度數組,召回值,但只要你喜歡格式化。

1

下面是一個解決方案,爲我工作的n = 6類的問題。如果你有更多的類,這個解決方案可能很慢,你應該使用某種映射而不是循環。

假設您在張量labels和張量labels中的logits(或posters)行中有一個熱編碼類標籤。然後,如果n是類的數量,嘗試:

y_true = tf.argmax(labels, 1) 
y_pred = tf.argmax(logits, 1) 

recall = [0] * n 
update_op_rec = [[]] * n 

for k in range(n): 
    recall[k], update_op_rec[k] = tf.metrics.recall(
     labels=tf.equal(y_true, k), 
     predictions=tf.equal(y_pred, k) 
    ) 

注意,內部tf.metrics.recall,變量labelspredictions被設定爲布爾矢量像在2變量的情況下,其允許使用的功能。