2016-04-16 76 views
22

我有形狀[None, 9, 2](其中None是批量)的張量流的輸入。張量流中的拼合批量

要執行進一步的操作(例如matmul),我需要將其轉換爲[None, 18]形狀。怎麼做?

回答

2

您可以使用動態整形在運行期間通過獲取批量維度的值,並將整個新維度集合計算爲tf.reshape。以下是將平面列表重新整形爲方形矩陣而不知道列表長度的示例。

tf.reset_default_graph() 
sess = tf.InteractiveSession("") 
a = tf.placeholder(dtype=tf.int32) 
# get [9] 
ashape = tf.shape(a) 
# slice the list from 0th to 1st position 
ashape0 = tf.slice(ashape, [0], [1]) 
# reshape list to scalar, ie from [9] to 9 
ashape0_flat = tf.reshape(ashape0,()) 
# tf.sqrt doesn't support int, so cast to float 
ashape0_flat_float = tf.to_float(ashape0_flat) 
newshape0 = tf.sqrt(ashape0_flat_float) 
# convert [3, 3] Python list into [3, 3] Tensor 
newshape = tf.pack([newshape0, newshape0]) 
# tf.reshape doesn't accept float, so convert back to int 
newshape_int = tf.to_int32(newshape) 
a_reshaped = tf.reshape(a, newshape_int) 
sess.run(a_reshaped, feed_dict={a: np.ones((9))}) 

您應該看到

array([[1, 1, 1], 
     [1, 1, 1], 
     [1, 1, 1]], dtype=int32) 
+0

我沒有看到任何方法'tf.batch'在此解決方案或Tensorflow ... – Muneeb

34

你可以用tf.reshape()輕鬆地做到這一點不知道批量大小。

x = tf.placeholder(tf.float32, shape=[None, 9,2]) 
shape = x.get_shape().as_list()  # a list: [None, 9, 2] 
dim = numpy.prod(shape[1:])   # dim = prod(9,2) = 18 
x2 = tf.reshape(x, [-1, dim])   # -1 means "all" 

在最後一行的-1意味着整個列無論batchsize是在運行什麼。你可以在tf.reshape()看到它。


更新:形狀= [無,3,無]

由於@kbrose。對於未定義多於1維的情況,我們可以使用tf.shape()tf.reduce_prod()

x = tf.placeholder(tf.float32, shape=[None, 3, None]) 
dim = tf.reduce_prod(tf.shape(x)[1:]) 
x2 = tf.reshape(x, [-1, dim]) 

tf.shape()返回一個可以在運行時評估的形狀張量。 tf.get_shape()和tf.shape()之間的區別可以看出in the doc

我也試過tf.contrib.layers.flatten()在另一個。第一種情況最簡單,但不能處理第二種情況。

+0

這種運作良好,如果你知道所有的其他尺寸的大小,但不會,如果其他方面有不明尺寸。例如。 'x = tf.placeholder(tf.float32,shape = [None,9,None])' – kbrose

+1

thanks @kbrose。我已經更新了案例的答案。 – weitang114

+0

@ weitang114太棒了! – kbrose

8
flat_inputs = tf.contrib.layers.flatten(inputs)