我認爲問題源於使用tf.contrib.data.Dataset
(支持重新初始化)和tf.train.batch_join()
(它使用TensorFlow隊列和隊列運行器,因此不支持重新初始化)。
我不完全清楚你的代碼在做什麼,但我認爲你可以實現整個管道作爲Dataset
。替換下面的代碼片段:
my_iterator = MyIterator(iterations=iterations)
dataset = ds.Dataset.from_generator(my_iterator,
output_types=my_iterator.output_types,
output_shapes=my_iterator.output_shapes)
#dataset = dataset.repeat(count=repetitions)
iterator = dataset.make_initializable_iterator()
next_elem = iterator.get_next()
#change constant to 1 or 2 or something to see that the batching is more predictable
ripple_adds = [(tf.stack((next_elem[0], next_elem[1] + constant)),)
for constant in ripple_add_coefficients]
batch = tf.train.batch_join(ripple_adds, batch_size=batch_size,
enqueue_many=False, name="sink_queue")
...的東西,如下列:
my_iterator = MyIterator(iterations=iterations)
dataset = tf.contrib.data.from_generator(my_iterator,
output_types=my_iterator.output_types,
output_shapes=my_iterator.output_shapes)
def ripple_add_map_func(x, y):
return (tf.contrib.data.Dataset.range(num_ripples)
.map(lambda r: tf.stack([x, y + r])))
dataset = dataset.flat_map(ripple_add_map_func).batch(batch_size)
iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()
請添加一個可運行的小樣本,顯示您的輸入管道,這樣我們就可以重現您的問題 – GPhilo
這應該可以使用Tensorflow存儲庫的master分支運行,這是數據集的from_iterator函數所需的。如果這個例子不適用於該版本,我可以修復它。 –