2016-03-03 27 views
4

似乎有很多關於TensorFlow的使用的開放問題,並且一些tensorflow的開發者在這裏在stackoverflow上處於活動狀態。這是另一個問題。我想在其他線程中使用numpy或其他不屬於TensorFlow的線程生成訓練數據。但是,我不想重複編譯整個TensorFlow源代碼。我只是在等待另一種方式。 「tf.py_func」似乎是一種解決方法。但缺失的示例|預取和預處理數據使用線程

這與[如何對預取數據使用-一個定製的Python功能功能於tensorflow] [1]

這裏是我的MnWE(minmal - 不工作 - 實例):

更新(現在有一個輸出,但一個競爭條件,太):

import numpy as np 
import tensorflow as tf 
import threading 
import os 
import glob 
import random 
import matplotlib.pyplot as plt 

IMAGE_ROOT = "/graphics/projects/data/mscoco2014/data/images/" 
files = ["train/COCO_train2014_000000178763.jpg", 
"train/COCO_train2014_000000543841.jpg", 
"train/COCO_train2014_000000364433.jpg", 
"train/COCO_train2014_000000091123.jpg", 
"train/COCO_train2014_000000498916.jpg", 
"train/COCO_train2014_000000429865.jpg", 
"train/COCO_train2014_000000400199.jpg", 
"train/COCO_train2014_000000230367.jpg", 
"train/COCO_train2014_000000281214.jpg", 
"train/COCO_train2014_000000041920.jpg"]; 


# -------------------------------------------------------------------------------- 

def pre_process(data): 
    """Pre-process image with arbitrary functions 
    does not only use tf.functions, but arbitrary 
    """ 
    # here is the place to do some fancy stuff 
    # which might be out of the scope of tf 
    return data[0:81,0,0].flatten() 


def populate_queue(sess, thread_pool, qData_enqueue_op): 
    """Put stuff into the data queue 
    is responsible such that there is alwaays data to process 
    for tensorflow 
    """ 
    # until somebody tell me I can stop ... 
    while not thread_pool.should_stop(): 
    # get a random image from MS COCO 
    idx = random.randint(0,len(files))-1 
    data = np.array(plt.imread(os.path.join(IMAGE_ROOT,files[idx]))) 
    data = pre_process(data) 

    # put into the queue 
    sess.run(qData_enqueue_op, feed_dict={data_input: data}) 


# a simple queue for gather data (just to keep it currently simple) 
qData   = tf.FIFOQueue(100, [tf.float32], shapes=[[9,9]]) 
data_input  = tf.placeholder(tf.float32) 
qData_enqueue_op = qData.enqueue([tf.reshape(data_input,[9,9])]) 
qData_dequeue_op = qData.dequeue() 
init_op   = tf.initialize_all_variables() 


with tf.Session() as sess: 
    # init all variables 
    sess.run(init_op) 
    # coordinate of pool of threads 
    thread_pool = tf.train.Coordinator() 
    # start fill in data 
    t = threading.Thread(target=populate_queue, args=(sess, thread_pool, qData_enqueue_op)) 
    t.start() 
    # Can I use "tf.train.start_queue_runners" here 
    # How to use multiple threads? 

    try: 
    while not thread_pool.should_stop(): 
     print "iter" 
     # HERE THE SILENCE BEGIN !!!!!!!!!!! 
     batch = sess.run([qData_dequeue_op]) 
     print batch 
    except tf.errors.OutOfRangeError: 
    print('Done training -- no more data') 
    finally: 
    # When done, ask the threads to stop. 
    thread_pool.request_stop() 

# now they should definetely stop 
thread_pool.request_stop() 
thread_pool.join([t]) 

我基本上有三個問題:

  • 這段代碼有什麼問題?它遇到了無盡的損失(這是不可調試的)。請參閱線「這裏靜寂開始...」
  • 如何擴展此代碼以使用更多的線程?
  • 是否值得轉換爲tf.Record可以在飛行中生成的大型數據集或數據?
+0

看着你的代碼,我猜測它不工作,因爲'data_input'沒有在'populate_queue()'中定義。當你嘗試在'feed_dict'中使用它時,Python會引發一個'NameError'並且線程將退出。 – mrry

+1

調用'request_stop()'的'populate_queue()'線程和調用'should_stop()'的主線程之間也存在爭用條件。主線程可能會嘗試從隊列中取出第11個元素並永久掛起。解決該問題的方法是使用['qData.close()'](https://www.tensorflow.org/versions/r0.7/api_docs/python/io_ops.html#QueueBase.close)操作表示沒有更多的元素將被添加(並且任何待處理的出隊應該被取消)。 – mrry

+0

感謝您的時間和評論。那麼類似tensorflow的代碼會是怎樣的呢?還是不打算在飛行中計算訓練數據?更確切的說:我應該如何解決比賽條件?問題是根本沒有錯誤信息。 – PatWie

回答

4

你在這一行錯誤:

t = threading.Thread(target=populate_queue, args=(sess, thread_pool, qData)) 

應該qData_enqueue_op而不是qData。否則,你的排隊操作失敗,你會被卡住試圖從大小爲0的隊列中退出,我看到這個當試圖運行你的代碼,並獲得

TypeError: Fetch argument <google3.third_party.tensorflow.python.ops.data_flow_ops.FIFOQueue object at 0x4bc1f10> of <google3.third_party.tensorflow.python.ops.data_flow_ops.FIFOQueue object at 0x4bc1f10> has invalid type <class 'google3.third_party.tensorflow.python.ops.data_flow_ops.FIFOQueue'>, must be a string or Tensor. (Can not convert a FIFOQueue into a Tensor or Operation.) 

關於其他問題:

  • 你不」因爲你沒有任何東西,所以在這個例子中需要啓動隊列跑步者。隊列運行器由輸入生成器創建,如string_input_producer,它基本上是FIFO隊列+邏輯來啓動線程。您正在通過啓動自己的線程來複制隊列運行器功能的50%,這些線程可以排隊運行。 (其他50%正在關閉隊列)
  • RE:轉換爲tf.record - Python有這個叫做Global Interpreter Lock的東西,這意味着Python代碼的兩位不能同時執行。在實踐中,大量的時間花在numpy C++代碼或IO操作(釋放GIL)的事實上可以緩解這種情況。所以我認爲這是一個檢查是否能夠使用Python預處理管道實現所需的並行性的問題。
+0

模仿隊列跑步者的原因是,目前還沒有明顯的方式讓我們說包括非tensorflow類操作,更重要的是跳過數據輸入或進行高級數據操作。我可以將此答案設置爲接受,因爲現在我可以看到輸出。你能否很快給出一個參考如何避免這些競爭條件。 – PatWie

+0

避免競爭的一個簡單方法是在排隊完成後運行'qData.close()',這樣就可以使用'OutOfRangeError'而不是永遠等待。同樣,你的邏輯設置的方式,你不能保證從隊列中讀取任何東西。假設您在開始讀取之前將所有10個條目都推送到隊列中,那麼您的讀取管道將完全跳過讀取循環。你可以完全忽略'should_stop'和'request_stop'邏輯,並且提高'intra_op_parallelism_threads'來補償將被卡住等待'enqueue'運行完成的TF線程 –

+0

我的意思是「session.run」qData.close ()節點 –