你有一個巨大的numpy數組,位於主機內存上。您希望能夠在CPU上並行處理它並將批次發送到設備。這是使用queues的好方案。
下面是一個簡單的例子,簡單地提取numpy的陣列的隨機切片(像你一樣),讓你在Python預處理與您喜愛的工具:
import numpy as np
import tensorflow as tf
def make_batch(x, y, batch_size):
rand_index = np.random.choice(x.shape[0], size=batch_size)
x_batch, y_batch = x[rand_index], y[rand_index]
# Do all your pre-processing here
# ...
return (x_batch, y_batch)
x = np.arange(10, dtype=np.float32)
y = np.arange(10, dtype=np.int32)
batch_size = 2
tf_make_batch = tf.py_func(make_batch, [x,y,batch_size], (tf.float32, tf.int32))
queue = tf.FIFOQueue(capacity=1000, dtypes=(tf.float32, tf.int32))
enqueue_op = queue.enqueue(tf_make_batch)
inputs = queue.dequeue()
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
with tf.Session() as sess:
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
for step in range(10):
print(sess.run(inputs))
coord.request_stop()
coord.join(enqueue_threads)
它採用了FIFOQueue
,因爲隨機抽樣已經發生在make_batch
。
當然要真正受益於多線程,make_batch
應該做更多的採樣。當您將一些重要的預處理添加到管道中時,您可能會開始發現顯着差異。
其重複的問題:https://stackoverflow.com/questions/45110098/tensorflow-next-batch-function-of-np-array/45110647#45110647 –