0
我在TF中有一個分類模型,並且可以獲得下一個類(preds)的概率列表。現在我想選擇最高的元素(argmax)並且顯示其類標籤。如何從TensorFlow預測中獲取類標籤
這可能看起來很愚蠢,但我怎樣才能得到匹配預測張量中位置的類標籤?
feed_dict={g['x']: current_char}
preds, state = sess.run([g['preds'],g['final_state']], feed_dict)
prediction = tf.argmax(preds, 1)
preds給我一個預測每個類的向量。當然,必須有一個簡單的方法來輸出最可能的類(標籤)?
我的模型中的一些信息:
x = tf.placeholder(tf.int32, [None, num_steps], name='input_placeholder')
y = tf.placeholder(tf.int32, [None, 1], name='labels_placeholder')
batch_size = batch_size = tf.shape(x)[0]
x_one_hot = tf.one_hot(x, num_classes)
rnn_inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in
tf.split(x_one_hot, num_steps, 1)]
tmp = tf.stack(rnn_inputs)
print(tmp.get_shape())
tmp2 = tf.transpose(tmp, perm=[1, 0, 2])
print(tmp2.get_shape())
rnn_inputs = tmp2
with tf.variable_scope('softmax'):
W = tf.get_variable('W', [state_size, num_classes])
b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))
rnn_outputs = rnn_outputs[:, num_steps - 1, :]
rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
y_reshaped = tf.reshape(y, [-1])
logits = tf.matmul(rnn_outputs, W) + b
predictions = tf.nn.softmax(logits)
謝謝,但如果我知道最大值,我仍然對哪個類標籤對應也感到困惑。 – dorien
對不起。但是這給出了預測值(例如0.8)。我怎樣才能從那裏知道.8屬於哪一類? – dorien
我也有同樣的疑問。我們稱之爲,我們得到給定類的最大概率。我們如何將這個概率轉換成標籤? –