我爲輸入數據定義了一個x = tf.placeholder("float", shape=[None, 784])
。之後,我需要知道x
形狀的第一個值是批量大小。我提取了x.get_shape().as_list()[0]
的值,但我得到了None
。你能告訴我應該如何正確提取它?非常感謝!如何在Tensorflow中提取佔位符張量的形狀值?
編輯:
我現在已經使用tf.get_shape()
但它會導致另一個錯誤。在我的代碼,我已經定義了一個deconv
功能可按:
def deconv(X, W, b, output_shape):
X += b
return tf.nn.conv2d_transpose(X, W, output_shape, strides=[1, 1, 1, 1])
如果我設置batch_size
以這種方式int
:batch_size = 50
,的deconv
函數的調用效果很好如下:
W_conv2_T = tf.ones([5, 5, 32, 64])
pool1_tr = deconv(conv2_tr, W_conv2_T, tf.zeros([64]), [batch_size, 14, 14, 32])
conv2_tr
的形狀是[50, 14, 14, 64]
。並且由此產生的形狀pool1_tr
是[50, 14, 14, 32]
。但是如果我設置batch_size = tf.get_shape(x)[0]
,conv2_tr
的形狀是[None, 14, 14, 64]
並且所得到的形狀pool1_tr
變成[None, None, None, None]
。這個bug非常奇怪。你能幫我解決這個問題嗎?提前致謝!
非常感謝您的回答!不幸的是,我得到了另一個相關的奇怪的錯誤,我不知道這一點。我編輯了這個問題並添加了更多信息。你可以看看並幫助我嗎?謝謝! – southdoor
它看起來像'tf.nn.conv2d_transpose()'的形狀推斷是相當薄弱的。 (它看起來像我寫的:P。)感謝您提請我們注意 - 我們將準備一個修復程序,同時您可以使用該形狀的已知尺寸調用'pool1_tr.set_shape()'。 – mrry