Many thanks to @yaroslav-bulatov for pointing me in the right direction.
It turns out my biggest problem was the queue runners. When I replaced the filename queue with a FIFOQueue and enqueued filenames into it manually, things worked, but because I was also using a shuffle_batch queue I ran into trouble when I tried to empty the filename queue before moving on to the next directory. I couldn't clear that queue without it either locking up or breaking, so the best I could manage was to let it fill with the new images while still holding leftovers from the previous directory - obviously no good! In the end I replaced it with a RandomShuffleQueue and again enqueued items manually, in the same way as the filenames. I think this shuffles the images well enough without being overkill for this problem. There are no threads, and everything became much simpler once I did away with them.
My final solution is included below. Any suggestions welcome!
import os
import tensorflow as tf
import numpy as np
from itertools import cycle
output_dir = '/my/output/dir'
my_dirs = [
    [
        '/path/to/datasets/blacksquares/black_square_100x100.png',
        '/path/to/datasets/blacksquares/black_square_200x200.png',
        '/path/to/datasets/blacksquares/black_square_300x300.png'
    ],
    [
        '/path/to/datasets/whitesquares/white_square_100x100.png',
        '/path/to/datasets/whitesquares/white_square_200x200.png',
        '/path/to/datasets/whitesquares/white_square_300x300.png',
        '/path/to/datasets/whitesquares/white_square_400x400.png'
    ],
    [
        '/path/to/datasets/mixedsquares/black_square_200x200.png',
        '/path/to/datasets/mixedsquares/white_square_200x200.png'
    ]
]
# set vars
patch_size = (100, 100, 1)
batch_size = 20
queue_capacity = 1000
# setup filename queue
filename_queue = tf.FIFOQueue(
    capacity=queue_capacity,
    dtypes=tf.string,
    shapes=[[]]
)
filenames_placeholder = tf.placeholder(dtype='string', shape=(None))
filenames_enqueue_op = filename_queue.enqueue_many(filenames_placeholder)
# read file and preprocess
image_reader = tf.WholeFileReader()
key, file = image_reader.read(filename_queue)
uint8image = tf.image.decode_png(file)
cropped_image = tf.random_crop(uint8image, patch_size) # take a random 100x100 crop
float_image = tf.div(tf.cast(cropped_image, tf.float32), 255) # put pixels in the [0,1] range
# setup shuffle batch queue for training images
images_queue = tf.RandomShuffleQueue(
    capacity=queue_capacity,
    min_after_dequeue=0,  # allow the queue to become completely empty (as we need to empty it)
    dtypes=tf.float32,
    shapes=patch_size
)
images_enqueue_op = images_queue.enqueue(float_image)
# setup simple computation - calculate an average image patch
input = tf.placeholder(shape=(None,) + patch_size, dtype=tf.float32)
avg_image = tf.Variable(np.random.normal(loc=0.5, scale=0.5, size=patch_size).astype(np.float32))
loss = tf.nn.l2_loss(tf.sub(avg_image, input))
train_op = tf.train.AdamOptimizer(2.).minimize(loss)
# start session and initialize variables
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
# note - no need to start any queue runners as I've done away with them
for dir_index, image_paths in enumerate(my_dirs):
    image_paths_cycle = cycle(image_paths)
    # reset the optimisation and training vars
    sess.run(tf.initialize_all_variables())
    num_epochs = 1000
    for i in range(num_epochs):
        # keep the filename queue at capacity
        size = sess.run(filename_queue.size())
        image_paths = []
        while size < queue_capacity:
            image_paths.append(next(image_paths_cycle))
            size += 1
        sess.run(filenames_enqueue_op, feed_dict={filenames_placeholder: image_paths})
        # keep the shuffle batch queue at capacity
        size = sess.run(images_queue.size())
        while size < queue_capacity:
            sess.run([images_enqueue_op])
            size += 1
        # get next (random) batch of training images
        batch = images_queue.dequeue_many(batch_size).eval()
        # run train op
        _, result, loss_i = sess.run([train_op, avg_image, loss], feed_dict={input: batch})
        print('Iteration {:d}. Loss: {:.2f}'.format(i, loss_i))
        # early stopping :)
        if loss_i < 0.05:
            break
    # empty filename queue and verify empty
    size = sess.run(filename_queue.size())
    sess.run(filename_queue.dequeue_many(size))
    size = sess.run(filename_queue.size())
    assert size == 0
    # empty batch queue and verify empty
    size = sess.run(images_queue.size())
    sess.run(images_queue.dequeue_many(size))
    size = sess.run(images_queue.size())
    assert size == 0
    # save the average image output
    result_image = np.clip(result * 255, 0, 255).astype(np.uint8)
    with open(os.path.join(output_dir, 'result_' + str(dir_index)), 'wb') as result_file:
        result_file.write(tf.image.encode_png(result_image).eval())
print('Happy days!')
exit(0)
This outputs result_0 - a black square, result_1 - a white square, and result_2 - a (mostly) grey square.
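The "keep the queue at capacity" pattern in the training loop above doesn't actually depend on TensorFlow; a framework-free sketch of the same idea, with the standard library's queue.Queue standing in for the TF queue (names and values are hypothetical), looks like this:

```python
import queue
from itertools import cycle

QUEUE_CAPACITY = 5  # stands in for queue_capacity above


def top_up(q, source_cycle, capacity):
    """Enqueue items from an endless cycle until the queue is at capacity."""
    size = q.qsize()
    while size < capacity:
        q.put(next(source_cycle))
        size += 1


filename_q = queue.Queue(maxsize=QUEUE_CAPACITY)
paths = cycle(['a.png', 'b.png', 'c.png'])

top_up(filename_q, paths, QUEUE_CAPACITY)
print(filename_q.qsize())  # 5 - filled to capacity

# drain two items, then top up again; the cycle continues where it left off
filename_q.get()
filename_q.get()
top_up(filename_q, paths, QUEUE_CAPACITY)
print(filename_q.qsize())  # back to 5
```

Because the source is an itertools.cycle, the queue can be topped up forever without ever hitting the end of the file list, which is exactly what keeps the blocking dequeue in the training loop from hanging.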
Thanks for your suggestion! I tried to get it working the way you described, but it kept hanging when I tried to get a batch. It sounds like this was caused by one queue feeding into another (since 'filename_queue' feeds into the 'shuffle_batch' queue)? – lopsided
It hangs when you try to enqueue into a full queue or dequeue from an empty queue. Try 'sess.run([filename_queue.size()])' to see how many elements are in the queue –
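This blocking behaviour is the same as Python's own queue.Queue, which makes it easy to demonstrate outside TensorFlow (a small sketch with hypothetical values; timeouts stand in for the indefinite hang):

```python
import queue

q = queue.Queue(maxsize=2)
q.put('x')
q.put('y')  # queue is now full

# put() on a full queue blocks forever; with a timeout it raises queue.Full instead
try:
    q.put('z', timeout=0.1)
except queue.Full:
    print('enqueue on a full queue would hang')

q.get()
q.get()  # queue is now empty

# get() on an empty queue blocks forever; with a timeout it raises queue.Empty
try:
    q.get(timeout=0.1)
except queue.Empty:
    print('dequeue from an empty queue would hang')

print(q.qsize())  # 0 - checking the size first avoids the blocking call
```

Checking the size before enqueueing or dequeueing, as in the solution above, is the single-threaded way to avoid both hangs.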
You also need to start the queue runners, because tf.batch is also a queue –
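For readers unfamiliar with what a queue runner does: it is essentially a background thread that keeps a queue fed while the main thread consumes from it. The idea (not the TF API itself, just a stdlib sketch with hypothetical names) can be illustrated like this:

```python
import queue
import threading
from itertools import cycle

q = queue.Queue(maxsize=4)


def runner(q, items, stop_event):
    """Background 'queue runner': keep enqueueing until asked to stop."""
    src = cycle(items)
    while not stop_event.is_set():
        try:
            q.put(next(src), timeout=0.05)  # blocks briefly while the queue is full
        except queue.Full:
            continue


stop = threading.Event()
t = threading.Thread(target=runner, args=(q, ['a', 'b', 'c'], stop), daemon=True)
t.start()

# the consumer just dequeues; the runner refills the queue behind the scenes
batch = [q.get() for _ in range(6)]
print(batch)  # ['a', 'b', 'c', 'a', 'b', 'c']

stop.set()
t.join()
```

This is the threaded alternative to the manual top-up loops in the solution above: the queue never runs dry, but you take on the coordination (and shutdown) complexity that the single-threaded version avoids.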