2017-02-28 46 views
4

我是TensorFlow的新手,我正在格式化一些數據以饋入循環神經網絡。我的數據是由一個3d張量輸入佔位符x。我想沿着第三維分裂x,爲此我已經(注意:n_timesteps相當於x沿第三維的長度):TensorFlow - 拆分和擠壓

# Split the previous 3d tensor to get a list of 'n_timesteps' 2d tensors of 
# shape (batch_size, features_dimension) 
x = tf.split (x, n_timesteps, axis = 2) 

雖然,正如我與numpy嘗試:

x = np.split (x, n_timesteps, axis = 2) 

如果x是一個三維ndarray然後np.split將返回n_timesteps陣列的列表與尺寸3,使得該第三維是單例。隨着numpy我知道我可以使用np.squeeze連同列表解析去除單維容易解決這個問題:

x = [np.squeeze(a, axis=2) for a in np.split(x, n_timesteps, axis=2)] 

但我怎麼可以做同樣的TF?

回答

0

嘗試使用Tensorflow的擠壓函數(tf.squeeze)和Tensorflow的掃描函數(tf.scan)而不是列表理解。

tf.scan(lambda a, x_i: tf.squeeze(x_i, [2]), x, initializer=tf.constant(0, shape=[n_dim0, n_dim1]))