我想對張量應用過濾器並刪除不符合我的標準的值。例如,可以說我有一個張,看起來像這樣:從softmax中刪除低質量張量預測
softmax_tensor = [[ 0.05 , 0.05, 0.2, 0.7], [ 0.25 , 0.25, 0.3, 0.2 ]]
眼下,分類挑選張量argmax
預測:
predictions = [[3],[2]]
但是這ISN」這正是我想要的,因爲我放棄了有關該預測信心的信息。我寧願不做出預測,也不願做出錯誤的預測。所以,我希望做的是返回過濾張量,像這樣:
new_softmax_tensor = [[ 0.05 , 0.05, 0.2, 0.7]]
new_predictions = [[3]]
如果這是直線上升的蟒蛇,我有沒有問題:
new_softmax_tensor = []
new_predictions = []
for idx,listItem in enumerate(softmax_tensor):
# get two highest max values and see if they are far enough apart
M = max(listItem)
M2 = max(n for n in listItem if n!=M)
if M2 - M > 0.3: # just making up a criteria here
new_softmax_tensor.append(listItem)
new_predictions.append(predictions[idx])
但鑑於tensorflow上工作的張量,我不知道如何做到這一點 - 如果我這樣做,它會打破計算圖嗎?
A previous SO post建議使用tf.gather_nd,但在這種情況下,他們已經有了一個他們想要過濾的張量。我也看過tf.cond但仍不明白。我想很多其他人會從這個完全相同的解決方案中受益。
謝謝大家。
你的意思是這樣嗎? '' #Compute頂部2 SOFTMAX分數 actual_diff = tf.subtract(tf.nn.top_k(softmax_tensor中,k = 2,分類= TRUE)) #創建一個布爾張量,指出保持或之間的差不 cond_result = tf.cond(actual_diff
是的,可能是這樣的... –