2017-03-27 23 views
0

我在TF中有一個分類模型,並且可以獲得下一個類(preds)的概率列表。現在我想選擇最高的元素(argmax)並且顯示其類標籤如何從TensorFlow預測中獲取類標籤

這可能看起來很愚蠢,但我怎樣才能得到匹配預測張量中位置的類標籤?

 feed_dict={g['x']: current_char} 
     preds, state = sess.run([g['preds'],g['final_state']], feed_dict) 
     prediction = tf.argmax(preds, 1) 

preds給我一個預測每個類的向量。當然,必須有一個簡單的方法來輸出最可能的類(標籤)?

我的模型中的一些信息:

x = tf.placeholder(tf.int32, [None, num_steps], name='input_placeholder') 
y = tf.placeholder(tf.int32, [None, 1], name='labels_placeholder') 
batch_size = batch_size = tf.shape(x)[0] 
x_one_hot = tf.one_hot(x, num_classes) 
rnn_inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in 
       tf.split(x_one_hot, num_steps, 1)] 

tmp = tf.stack(rnn_inputs) 
print(tmp.get_shape()) 
tmp2 = tf.transpose(tmp, perm=[1, 0, 2]) 
print(tmp2.get_shape()) 
rnn_inputs = tmp2 


with tf.variable_scope('softmax'): 
    W = tf.get_variable('W', [state_size, num_classes]) 
    b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0)) 


rnn_outputs = rnn_outputs[:, num_steps - 1, :] 
rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size]) 
y_reshaped = tf.reshape(y, [-1]) 
logits = tf.matmul(rnn_outputs, W) + b 
predictions = tf.nn.softmax(logits) 

回答

0

您可以使用tf.reduce_max()這一點。我會推薦你​​this answer。 讓我知道它是否有效 - 如果沒有,將進行編輯。

+0

謝謝,但如果我知道最大值,我仍然對哪個類標籤對應也感到困惑。 – dorien

+1

對不起。但是這給出了預測值(例如0.8)。我怎樣才能從那裏知道.8屬於哪一類? – dorien

+0

我也有同樣的疑問。我們稱之爲,我們得到給定類的最大概率。我們如何將這個概率轉換成標籤? –