2016-11-18 42 views
2

我使用tf.PaddingFIFOQueuetf.contrib.data.PaddedBatchDataset來饋入不同長度的序列和dequeue_many以獲得零填充的批量。從PaddingFIFOQueue獲得動態序列長度

有沒有一些通用的方法來獲得該批次的序列長度?

我目前的解決方案是明確提供序列長度作爲隊列的附加輸入,即我有像tf.PaddingFIFOQueue(names=["data", "seq_length"], ...)。我也可以使用tf.ones_like(),但我目前的方式似乎更便宜,更簡單。但我想知道這是否是規範的/標準的方式,或者是否有其他方法。

+0

你能用一些代碼來說明你的問題嗎?爲什麼生成的張量上的'.get_shape()'不適用於你的情況? – sygi

+0

@sygi:get_shape將返回(batch,max_length,...),因爲它是零填充的。那麼我現在如何獲得每個序列的長度? – Albert

+0

你能否假設原始句子沒有結尾0? – sygi

回答

0

您可以將您的dataseq_length組合成一個元組(或列表),然後將該元組推入隊列。

import tensorflow as tf 
sess = tf.InteractiveSession() 
q = tf.PaddingFIFOQueue(capacity=10, dtypes=[tf.int32, tf.int32], shapes=[[], [None]]) 
eq1 = q.enqueue([1, [1]]) 
eq2 = q.enqueue([2, [2,3]]) 
eq3 = q.enqueue([3, [4,5,6]]) 
dq = q.dequeue() 
sess.run(eq1) 
sess.run(eq2) 
sess.run(eq3) 
sess.run(dq) # [1, array([1], dtype=int32)] 
sess.run(dq) # [2, array([2, 3], dtype=int32)] 
sess.run(dq) # [3, array([4, 5, 6], dtype=int32)] 
+0

請注意,這正是我所描述的當前解決方案。所以基本上你的回答是,我已經描述過沒有更好的方法。 – Albert