2016-10-02 41 views
0

所以我有以下numpy數組。For loop評估精度不執行

  • X驗證集,X_val:(47151,32,32,1)
  • Ý驗證集(標籤),y_val_dummy:(47151,5,10)
  • ý驗證預測套組,y_pred: (47151,5,10)

當我運行代碼時,它似乎需要永遠。有人可以建議爲什麼?我相信這是一個代碼效率問題。我似乎無法完成這個過程。

y_pred_list = model.predict(X_val) 
correct_preds = 0 
# Iterate over sample dimension 
for i in range(X_val.shape[0]):   
    pred_list_i = [y_pred_array[i] for y_pred in y_pred_array] 
    val_list_i = [y_val_dummy[i] for y_val in y_val_dummy] 
    matching_preds = [pred.argmax(-1) == val.argmax(-1) for pred, val in zip(pred_list_i, val_list_i)] 
    correct_preds = int(np.all(matching_preds)) 

total_acc = correct_preds/float(x_val.shape[0]) 
+0

不應該是'[y_pred [i] for y_pred in y_pred_array]'而不是類似的下一步? – Divakar

+0

@Divakar謝謝是的。哈哈。 – Ritchie

回答

0

你主要的問題是,你產生非常大的列表數量龐大的沒有真正的理由

for i in range(X_val.shape[0]): 
    # this line generates a 47151 x 5 x 10 array every time   
    pred_list_i = [y_pred_array[i] for y_pred in y_pred_array] 

發生了什麼事是迭代的第二numpy的數組迭代速度最慢(即最左邊的),所以每個列表理解運行在47K條目上。

稍好將

for i in range(X_val.shape[0]):   
    pred_list_i = [y_pred for y_pred in y_pred_array[i]] 
    val_list_i = [y_val for y_val in y_val_dummy[i]] 
    matching_preds = [pred.argmax(-1) == val.argmax(-1) for pred, val in zip(pred_list_i, val_list_i)] 
    correct_preds = int(np.all(matching_preds)) 

但你仍然複製了很多陣列沒有真正的目的。下面的代碼應該這樣做,沒有無用的複製。

correct_preds = 0.0 
for pred, val in zip(y_pred_array, y_val_dummy): 
    correct_preds += all(p.argmax(-1) == v.argmax(-1) 
         for p, v in zip(pred, val)) 
total_accuracy = correct_preds/x_val.shape[0] 

這假設您的準確預測準確性是準確的。 您可以完全避免顯式循環,只需撥打np.argmax即可,但您必須自行解決。