2016-07-07 44 views
1

假設輸入到網絡的規模是一個placeholder具有可變批量大小,即:獲取變量批次尺寸

x = tf.placeholder(..., shape=[None, ...]) 

是有可能得到的x形狀已經饋送後? tf.shape(x)[0]仍然返回None

回答

8

如果x具有變量批量大小,獲得實際形狀的唯一方法是使用tf.shape()運算符。該運算符在tf.Tensor中返回一個符號值,因此它可以用作其他TensorFlow操作的輸入,但要獲取該形狀的具體Python值,則需要將其傳遞到Session.run()

x = tf.placeholder(..., shape=[None, ...]) 
batch_size = tf.shape(x)[0] # Returns a scalar `tf.Tensor` 

print x.get_shape()[0] # ==> "?" 

# You can use `batch_size` as an argument to other operators. 
some_other_tensor = ... 
some_other_tensor_reshaped = tf.reshape(some_other_tensor, [batch_size, 32, 32]) 

# To get the value, however, you need to call `Session.run()`. 
sess = tf.Session() 
x_val = np.random.rand(37, 100, 100) 
batch_size_val = sess.run(batch_size, {x: x_val}) 
print x_val # ==> "37" 
1

使用x.get_shape().as_list()可以得到張量x的形狀。要獲得第一維(批量大小),您可以使用x.get_shape().as_list()[0]