2017-05-25 105 views
0

我想在某些數據上使用CNN,但由於我的模型的輸出是[1000,1000,4000]時應該是[ 1000,4000]。在這種情況下,前1000是批量大小,而4000是我擁有的類的數量,因爲這是一個分類問題。Tensorflow CNN形狀錯誤

我想我可能需要在我的fully_connected圖層後再次使用tf.reshape()函數來獲得正確的輸出,但我不太確定該如何做到這一點。我已經嘗試tf.reshape(輸出[-1,4000]),但仍然保持其他1000內。

這裏是我的代碼:

cnn_input = tf.reshape(input, [-1, 1000, 1]) 
    net = slim.conv2d(cnn_input, 128, [3]) 
    net = slim.pool(net, [2], "MAX") 
    output = slim.fully_connected(net, num_classes, activation_fn=tf.nn.softmax) 
    return output 

基本上,我的輸出要求是等級2的形狀,但由於某些原因,它的轉向了有3個維度。我需要輸出的形狀爲[1000,4000],批量大小爲x num_classes。

任何幫助將不勝感激。提前致謝!

順便說一下,我使用的是tf-slim庫。

編輯:在完全連接層之前,tf.flatten會爲此工作嗎?

回答

0

我遇到了同樣的錯誤。 documentation (line 1609)(與here鏈接)表明'fully_connected'操作應該使輸出變平,但它不會。正如你所建議的那樣,我剛剛在最後幾次完全連接的操作之前使用了slim.flatten,但我沒有具體的證據證明它可以工作。有了6個月沒有評論,我認爲其他人會比沒有好,但如果其他人有更多的見解,它將不勝感激。