2
我有使用RNN多類分類和這裏是我的RNN主代碼:Tensorflow混淆矩陣
def RNN(x, weights, biases):
x = tf.unstack(x, input_size, 1)
lstm_cell = rnn.BasicLSTMCell(num_unit, forget_bias=1.0, state_is_tuple=True)
stacked_lstm = rnn.MultiRNNCell([lstm_cell]*lstm_size, state_is_tuple=True)
outputs, states = tf.nn.static_rnn(stacked_lstm, x, dtype=tf.float32)
return tf.matmul(outputs[-1], weights) + biases
logits = RNN(X, weights, biases)
prediction = tf.nn.softmax(logits)
cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(cost)
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
我不得不所有輸入至6類分類和各類別是
happy = [1, 0, 0, 0, 0, 0]
angry = [0, 1, 0, 0, 0, 0]
neutral = [0, 0, 1, 0, 0, 0]
excited = [0, 0, 0, 1, 0, 0]
embarrassed = [0, 0, 0, 0, 1, 0]
sad = [0, 0, 0, 0, 0, 1]
的問題是使用tf.confusion_matrix()
函數I不能打印混淆矩陣:一個熱碼標籤作爲後續組成。
是否有任何方法使用這些標籤打印混淆矩陣?
如果不是,只有當我需要打印混淆矩陣時,如何將單熱代碼轉換爲整數索引?