import tensorflow as tf
sess = tf.Session()
def add_to_batch(image):
print('Adding to batch')
image_batch = tf.train.shuffle_batch([image],batch_size=5,capacity=11,min_after_dequeue=1,num_threads=1)
# Add to summary
tf.image_summary('images',image_batch)
return image_batch
def get_batch():
# Create filename queue of images to read
filenames = [('/media/jessica/Jessica/TensorFlow/Practice/unlabeled_data_%d.png' % i) for i in range(11)]
filename_queue = tf.train.string_input_producer(filenames)
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
# Read and process image
my_image = tf.image.decode_png(value)
my_image_float = tf.cast(my_image,tf.float32)
image_mean = tf.reduce_mean(my_image_float)
my_noise = tf.random_normal([96,96,3],mean=image_mean)
my_image_noisy = my_image_float + my_noise
print('Reading images')
return add_to_batch(my_image_noisy)
def main():
sess.run(tf.initialize_all_variables())
tf.train.start_queue_runners(sess=sess)
writer = tf.train.SummaryWriter('/media/jessica/Jessica/TensorFlow/Practice/summary_logs', graph_def=sess.graph_def)
merged = tf.merge_all_summaries()
images = get_batch()
summary_str = sess.run(merged)
writer.add_summary(summary_str)
我試圖建立TensorFlow一個簡單的神經網絡。我正在嘗試分批加載我的輸入圖像。現在我正在測試11個圖像和batch_size = 5的代碼。最終我將處理100000個圖像。
這段代碼是從TensorFlow的cifar10.py例子中修改的。由於某種原因,我的代碼停止(不終止,它只是掛在那裏)tf.train.shuffle_batch([image],batch_size=5,capacity=1,min_after_dequeue=1,num_threads=1)
我試過batch_size,容量,min_after_dequeue等不同的組合,但我仍然不知道什麼是錯的。
任何幫助將不勝感激!謝謝!
我編輯你的代碼來修復縮進(否則Python解釋器不會接受它)。讓我知道,如果這是不正確的! – mrry