2017-06-07 189 views
0

永遠困我想通過批量培養我的模型批,因爲我無法找到任何例子,如何正確地做到這一點。就我所能做的事情而言,我的任務是在Tensorflow中逐批地訓練模型。Tensorflow:教育訓練由一批sess.run

queue=tf.FIFOQueue(capacity=50,dtypes=[tf.float32,tf.float32],shapes=[[10],[2]]) 
enqueue_op=queue.enqueue_many([X,Y]) 
dequeue_op=queue.dequeue() 

qr=tf.train.QueueRunner(queue,[enqueue_op]*2) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    X_train_batch,y_train_batch=tf.train.batch(dequeue_op,batch_size=2) 
    coord=tf.train.Coordinator() 
    enqueue_threads=qr.create_threads(sess,coord,start=True) 
    sess.run(tf.local_variables_initializer()) 
    for epoch in range(100): 
     print("inside loop1") 
     for iter in range(5): 
      print("inside loop2") 
      if coord.should_stop(): 
       break 
      batch_x,batch_y=sess.run([X_train_batch,y_train_batch]) 
      print("after sess.run") 
      print(batch_x.shape) 
      _=sess.run(optimizer,feed_dict={x_place:batch_x,y_place:batch_y}) 
     coord.request_stop() 
     coord.join(enqueue_threads) 

,輸出,

inside loop1 
inside loop2 

正如你所看到的, 這永遠困在它運行batch_x,batch_y=sess.run([X_train_batch,y_train_batch])線。 我不知道我該如何解決這個問題,或者這是逐批地訓練模型的正確方法?

+0

是輸出真的「內循環1,內循環1」或是「內循環1,內循環2」?其次,在我看來,你最後兩行縮進了一點,應該與「紀元」一致。 – Wontonimo

+0

抱歉的錯字,現在編輯,我找到了解決方案,現在編輯問題.. –

回答

1

經過幾個小時的搜索,我自己找到了解決方案。所以,我現在在下面回答我自己的問題。 的隊列由後臺線程,其創建,當你調用tf.train.start_queue_runners()如果你不調用這個方法,後臺線程將無法啓動,隊列將保持爲空,訓練運會無限期地阻塞等待輸入填補。

FIX: 就在訓練循環之前調用tf.train.start_queue_runners(sess)。 像我這樣做:

queue=tf.FIFOQueue(capacity=50,dtypes=[tf.float32,tf.float32],shapes=[[10],[2]]) 
enqueue_op=queue.enqueue_many([X,Y]) 
dequeue_op=queue.dequeue() 

qr=tf.train.QueueRunner(queue,[enqueue_op]*2) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    X_train_batch,y_train_batch=tf.train.batch(dequeue_op,batch_size=2) 
    coord=tf.train.Coordinator() 
    enqueue_threads=qr.create_threads(sess,coord,start=True) 
    tf.train.start_queue_runners(sess) 
    for epoch in range(100): 
     print("inside loop1") 
     for iter in range(5): 
      print("inside loop2") 
      if coord.should_stop(): 
       break 
      batch_x,batch_y=sess.run([X_train_batch,y_train_batch]) 
      print("after sess.run") 
      print(batch_x.shape) 
      _=sess.run(optimizer,feed_dict={x_place:batch_x,y_place:batch_y}) 
     coord.request_stop() 
     coord.join(enqueue_threads)