我是TensorFlow的新手,我正在格式化一些數據以饋入循環神經網絡。我的數據是由一個3d張量輸入佔位符x
。我想沿着第三維分裂x
,爲此我已經(注意:n_timesteps
相當於x
沿第三維的長度):TensorFlow - 拆分和擠壓
# Split the previous 3d tensor to get a list of 'n_timesteps' 2d tensors of
# shape (batch_size, features_dimension)
x = tf.split (x, n_timesteps, axis = 2)
雖然,正如我與numpy
嘗試:
x = np.split (x, n_timesteps, axis = 2)
如果x
是一個三維ndarray
然後np.split
將返回n_timesteps
陣列的列表與尺寸3,使得該第三維是單例。隨着numpy
我知道我可以使用np.squeeze
連同列表解析去除單維容易解決這個問題:
x = [np.squeeze(a, axis=2) for a in np.split(x, n_timesteps, axis=2)]
但我怎麼可以做同樣的TF?