我想在某些數據上使用CNN,但由於我的模型的輸出是[1000,1000,4000]時應該是[ 1000,4000]。在這種情況下,前1000是批量大小,而4000是我擁有的類的數量,因爲這是一個分類問題。Tensorflow CNN形狀錯誤
我想我可能需要在我的fully_connected圖層後再次使用tf.reshape()函數來獲得正確的輸出,但我不太確定該如何做到這一點。我已經嘗試tf.reshape(輸出[-1,4000]),但仍然保持其他1000內。
這裏是我的代碼:
cnn_input = tf.reshape(input, [-1, 1000, 1])
net = slim.conv2d(cnn_input, 128, [3])
net = slim.pool(net, [2], "MAX")
output = slim.fully_connected(net, num_classes, activation_fn=tf.nn.softmax)
return output
基本上,我的輸出要求是等級2的形狀,但由於某些原因,它的轉向了有3個維度。我需要輸出的形狀爲[1000,4000],批量大小爲x num_classes。
任何幫助將不勝感激。提前致謝!
順便說一下,我使用的是tf-slim庫。
編輯:在完全連接層之前,tf.flatten會爲此工作嗎?