
Filling a queue from a Python iterator

I want to create a queue that is filled from an iterator. In the following MWE, however, the same value is always enqueued:

import tensorflow as tf 
import numpy as np 

# data 
imgs = [np.random.randn(i,i) for i in [2,3,4,5]] 

# iterate through data infinitely
def data_iterator():
    while True:
        for img in imgs:
            yield img

it = data_iterator() 

# create queue for data 
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64]) 

# feed next element from iterator 
enqueue_op = q.enqueue(list(next(it))) 

# setup queue runner 
numberOfThreads = 1 
qr = tf.train.QueueRunner(q, [enqueue_op] * numberOfThreads) 
tf.train.add_queue_runner(qr) 

# dequeue 
dequeue_op = q.dequeue() 
dequeue_op = tf.Print(dequeue_op, data=[dequeue_op], message="dequeue()") 

# We start the session as usual ... 
with tf.Session() as sess: 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(10):
        data = sess.run(dequeue_op)
        print(data)

    coord.request_stop()
    coord.join(threads)

Do I really have to use feed_dict? If so, how do I use it in combination with a QueueRunner?

Answer


When you run

enqueue_op = q.enqueue(list(next(it))) 

TensorFlow evaluates list(next(it)) exactly once, while the graph is being built. From then on it keeps that first list and enqueues the very same value into q on every run of enqueue_op. To avoid this you have to feed the data through a placeholder. Feeding a placeholder does not work with tf.train.QueueRunner, so run the enqueue loop in a thread of your own.
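A stripped-down sketch of that capture behaviour (TF 1.x; the toy scalar values and variable names here are mine, not part of the original example):

import tensorflow as tf
import numpy as np

vals = iter([np.float64(1.0), np.float64(2.0), np.float64(3.0)])

q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64])

# next(vals) runs right here, while the graph is built,
# so the constant 1.0 is baked into the enqueue op
enqueue_op = q.enqueue(next(vals))
dequeue_op = q.dequeue()

with tf.Session() as sess:
    for _ in range(3):
        sess.run(enqueue_op)                              # enqueues 1.0 every time
    print([sess.run(dequeue_op) for _ in range(3)])       # [1.0, 1.0, 1.0]

Applying the placeholder idea to the original example: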

import tensorflow as tf 
import numpy as np 
import threading 

# data 
imgs = [np.random.randn(i,i) for i in [2,3,4,5]] 

# iterate through data infinitely
def data_iterator():
    while True:
        for img in imgs:
            yield img

it = data_iterator() 

# create queue for data 
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64]) 

# feed next element from iterator 

img_p = tf.placeholder(tf.float64, [None, None]) 
enqueue_op = q.enqueue(img_p) 

dequeue_op = q.dequeue() 


with tf.Session() as sess: 
    coord = tf.train.Coordinator() 

    def enqueue_thread():
        with coord.stop_on_exception():
            while not coord.should_stop():
                sess.run(enqueue_op, feed_dict={img_p: list(next(it))})

    numberOfThreads = 1
    for i in range(numberOfThreads):
        threading.Thread(target=enqueue_thread).start()

    for i in range(3):
        data = sess.run(dequeue_op)
        print(data)
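The snippet above never shuts the enqueue thread down, so the process can hang on exit if that thread is blocked inside enqueue. A minimal sketch of a clean shutdown, assuming the same sess, q, coord, it, img_p, enqueue_thread and numberOfThreads as above (keeping handles on the started threads is my addition, not part of the original answer):

    # start the enqueue threads, keeping handles so they can be joined later
    enqueue_threads = []
    for i in range(numberOfThreads):
        t = threading.Thread(target=enqueue_thread)
        t.start()
        enqueue_threads.append(t)

    for i in range(3):
        print(sess.run(dequeue_op))

    # ask the threads to stop and cancel any enqueue that is still waiting
    # for free space in the queue; the resulting CancelledError is caught
    # by coord.stop_on_exception() inside enqueue_thread
    coord.request_stop()
    sess.run(q.close(cancel_pending_enqueues=True))
    for t in enqueue_threads:
        t.join()

Closing the queue with cancel_pending_enqueues=True is what actually wakes a thread that is blocked in sess.run(enqueue_op); coord.request_stop() alone only flips the flag that the while-loop checks before the next enqueue.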