Prefetching and pre-processing data using threads

There seem to be many open questions about how to use TensorFlow, and some of the TensorFlow developers are active here on Stack Overflow, so here is one more question. I want to generate training data in other threads using numpy or something else that is not part of TensorFlow. However, I do not want to recompile the entire TensorFlow source; I am simply looking for another way. `tf.py_func` seems to be a workaround, but it lacks an example.
This is related to [How to prefetch data using a custom python function in tensorflow][1].
Here is my MnWE (minimal non-working example):
Update (it now produces output, but there is a race condition, too):
import numpy as np
import tensorflow as tf
import threading
import os
import glob
import random
import matplotlib.pyplot as plt
IMAGE_ROOT = "/graphics/projects/data/mscoco2014/data/images/"
files = ["train/COCO_train2014_000000178763.jpg",
         "train/COCO_train2014_000000543841.jpg",
         "train/COCO_train2014_000000364433.jpg",
         "train/COCO_train2014_000000091123.jpg",
         "train/COCO_train2014_000000498916.jpg",
         "train/COCO_train2014_000000429865.jpg",
         "train/COCO_train2014_000000400199.jpg",
         "train/COCO_train2014_000000230367.jpg",
         "train/COCO_train2014_000000281214.jpg",
         "train/COCO_train2014_000000041920.jpg"]
# --------------------------------------------------------------------------------
def pre_process(data):
    """Pre-process an image with arbitrary functions.

    May use arbitrary Python here, not only tf functions.
    """
    # here is the place to do some fancy stuff
    # which might be out of the scope of tf
    return data[0:81, 0, 0].flatten()
def populate_queue(sess, thread_pool, qData_enqueue_op):
    """Put stuff into the data queue.

    Is responsible for ensuring that there is always data
    for TensorFlow to process.
    """
    # until somebody tells me I can stop ...
    while not thread_pool.should_stop():
        # get a random image from MS COCO
        idx = random.randint(0, len(files)) - 1
        data = np.array(plt.imread(os.path.join(IMAGE_ROOT, files[idx])))
        data = pre_process(data)
        # put into the queue
        sess.run(qData_enqueue_op, feed_dict={data_input: data})
# a simple queue for gathering data (just to keep it simple for now)
qData = tf.FIFOQueue(100, [tf.float32], shapes=[[9,9]])
data_input = tf.placeholder(tf.float32)
qData_enqueue_op = qData.enqueue([tf.reshape(data_input,[9,9])])
qData_dequeue_op = qData.dequeue()
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
    # init all variables
    sess.run(init_op)
    # coordinator for the pool of threads
    thread_pool = tf.train.Coordinator()
    # start filling in data
    t = threading.Thread(target=populate_queue, args=(sess, thread_pool, qData_enqueue_op))
    t.start()

    # Can I use "tf.train.start_queue_runners" here?
    # How to use multiple threads?

    try:
        while not thread_pool.should_stop():
            print "iter"
            # HERE THE SILENCE BEGIN !!!!!!!!!!!
            batch = sess.run([qData_dequeue_op])
            print batch
    except tf.errors.OutOfRangeError:
        print('Done training -- no more data')
    finally:
        # When done, ask the threads to stop.
        thread_pool.request_stop()

    # now they should definitely stop
    thread_pool.request_stop()
    thread_pool.join([t])
I basically have three questions:

- What is wrong with this code? It hangs forever (and that is not debuggable). See the line "HERE THE SILENCE BEGIN ...".
- How can I extend this code to use more threads?
- Is it worth converting to tf.Record for large datasets, or for data that can be generated on the fly?
Looking at your code, I guess it does not work because `data_input` is not defined in `populate_queue()`. When you try to use it in the `feed_dict`, Python will raise a `NameError` and the thread will exit. – mrry
There is also a race condition between the thread in `populate_queue()` that calls `request_stop()` and the main thread that calls `should_stop()`: the main thread may try to dequeue an 11th element and hang forever. The way to fix that is to use the [`qData.close()`](https://www.tensorflow.org/versions/r0.7/api_docs/python/io_ops.html#QueueBase.close) operation to signal that no more elements will be added (and that any pending dequeues should be cancelled). – mrry
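The shutdown pattern mrry suggests (close the queue so that pending dequeues are cancelled instead of hanging) can be illustrated with a plain-Python analogue. This is only a hypothetical sketch: it uses the standard library's `queue` module and a sentinel value to stand in for TensorFlow's `qData.close()` / `OutOfRangeError` mechanism, with made-up names (`producer`, `consume_all`, `SENTINEL`):

```python
import queue
import threading

SENTINEL = object()  # stands in for qData.close(): "no more data will arrive"

def producer(q, items):
    # put all the data into the queue, then signal shutdown
    for item in items:
        q.put(item)
    q.put(SENTINEL)  # the consumer sees this instead of blocking forever

def consume_all(q):
    # drain the queue until the shutdown signal arrives
    results = []
    while True:
        item = q.get()
        if item is SENTINEL:  # analogous to OutOfRangeError after close()
            break
        results.append(item)
    return results

q = queue.Queue(maxsize=100)
t = threading.Thread(target=producer, args=(q, list(range(5))))
t.start()
batches = consume_all(q)
t.join()
print(batches)  # [0, 1, 2, 3, 4]
```

The key point is that the producer, not a shared stop flag, decides when the stream ends, so the consumer can never block on an element that will never arrive.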
Thanks for your time and the comments. What would idiomatic TensorFlow code for this look like? Or is computing training data on the fly simply not intended? More precisely: how should I resolve the race condition? The problem is that there is no error message at all. – PatWie
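For the "more threads" part of the question, the same sentinel pattern extends to several producers: start all workers, join them, and only then close the queue exactly once. Again a pure-Python sketch with hypothetical names, not the TensorFlow API (there, `tf.train.QueueRunner` plays the role of this hand-rolled thread pool):

```python
import queue
import threading

NUM_PRODUCERS = 4
SENTINEL = object()  # single shutdown signal, sent after ALL producers finish

def producer(q, worker_id):
    # each worker pushes its own pre-processed "data" items
    for i in range(3):
        q.put((worker_id, i))

q = queue.Queue(maxsize=100)
threads = [threading.Thread(target=producer, args=(q, w))
           for w in range(NUM_PRODUCERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()      # wait until every producer is done ...
q.put(SENTINEL)   # ... then "close" the queue exactly once

items = []
while True:
    item = q.get()
    if item is SENTINEL:
        break
    items.append(item)

print(len(items))  # 12
```

If the sentinel were sent by the first worker to finish, the consumer could stop while other workers are still producing; closing once, after the join, avoids that race.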