2016-12-01 140 views
1

我爲輸入數據定義了一個x = tf.placeholder("float", shape=[None, 784])。之後,我需要知道x形狀的第一個值是批量大小。我提取了x.get_shape().as_list()[0]的值,但我得到了None。你能告訴我應該如何正確提取它?非常感謝!如何在Tensorflow中提取佔位符張量的形狀值?

編輯:

我現在已經使用tf.get_shape()但它會導致另一個錯誤。在我的代碼,我已經定義了一個deconv功能可按:

def deconv(X, W, b, output_shape): 
    X += b 
    return tf.nn.conv2d_transpose(X, W, output_shape, strides=[1, 1, 1, 1]) 

如果我設置batch_size以這種方式intbatch_size = 50,的deconv函數的調用效果很好如下:

W_conv2_T = tf.ones([5, 5, 32, 64]) 
pool1_tr = deconv(conv2_tr, W_conv2_T, tf.zeros([64]), [batch_size, 14, 14, 32]) 

conv2_tr的形狀是[50, 14, 14, 64]。並且由此產生的形狀pool1_tr[50, 14, 14, 32]。但是如果我設置batch_size = tf.get_shape(x)[0],conv2_tr的形狀是[None, 14, 14, 64]並且所得到的形狀pool1_tr變成[None, None, None, None]。這個bug非常奇怪。你能幫我解決這個問題嗎?提前致謝!

回答

4

對於佔位符中的行數,值None表示它在運行時可能會有所不同,因此您必須使用tf.shape(x)將形狀設置爲tf.Tensor。下面的代碼應該可以工作:

batch_size = tf.shape(x)[0] 
+0

非常感謝您的回答!不幸的是,我得到了另一個相關的奇怪的錯誤,我不知道這一點。我編輯了這個問題並添加了更多信息。你可以看看並幫助我嗎?謝謝! – southdoor

+1

它看起來像'tf.nn.conv2d_transpose()'的形狀推斷是相當薄弱的。 (它看起來像我寫的:P。)感謝您提請我們注意 - 我們將準備一個修復程序,同時您可以使用該形狀的已知尺寸調用'pool1_tr.set_shape()'。 – mrry

相關問題