2016-03-01 58 views
3
import tensorflow as tf 
sess = tf.Session() 

def add_to_batch(image): 

    print('Adding to batch') 
    image_batch = tf.train.shuffle_batch([image],batch_size=5,capacity=11,min_after_dequeue=1,num_threads=1) 

    # Add to summary 
    tf.image_summary('images',image_batch) 

    return image_batch 

def get_batch(): 

    # Create filename queue of images to read 
    filenames = [('/media/jessica/Jessica/TensorFlow/Practice/unlabeled_data_%d.png' % i) for i in range(11)] 
    filename_queue = tf.train.string_input_producer(filenames) 
    reader = tf.WholeFileReader() 
    key, value = reader.read(filename_queue) 

    # Read and process image 
    my_image = tf.image.decode_png(value) 
    my_image_float = tf.cast(my_image,tf.float32) 
    image_mean = tf.reduce_mean(my_image_float) 
    my_noise = tf.random_normal([96,96,3],mean=image_mean) 
    my_image_noisy = my_image_float + my_noise 
    print('Reading images') 

    return add_to_batch(my_image_noisy) 

def main(): 

    sess.run(tf.initialize_all_variables()) 
    tf.train.start_queue_runners(sess=sess) 
    writer = tf.train.SummaryWriter('/media/jessica/Jessica/TensorFlow/Practice/summary_logs', graph_def=sess.graph_def) 
    merged = tf.merge_all_summaries() 
    images = get_batch() 
    summary_str = sess.run(merged) 
    writer.add_summary(summary_str) 

嗨,TensorFlow shuffle_batch不工作

我試圖建立TensorFlow一個簡單的神經網絡。我正在嘗試分批加載我的輸入圖像。現在我正在測試11個圖像和batch_size = 5的代碼。最終我將處理100000個圖像。

這段代碼是從TensorFlow的cifar10.py例子中修改的。由於某種原因,我的代碼停止(不終止,它只是掛在那裏)tf.train.shuffle_batch([image],batch_size=5,capacity=1,min_after_dequeue=1,num_threads=1)

我試過batch_size,容量,min_after_dequeue等不同的組合,但我仍然不知道什麼是錯的。

任何幫助將不勝感激!謝謝!

+0

我編輯你的代碼來修復縮進(否則Python解釋器不會接受它)。讓我知道,如果這是不正確的! – mrry

回答

7

看起來問題出現,因爲聲明

tf.train.start_queue_runners(sess=sess) 

...執行已創建的任何隊列跑步之前。如果您在images = get_batch()之後移動此行,您的程序應該可以正常工作。

這裏有什麼問題? tf.train.shuffle_batch()函數內部使用tf.RandomShuffleQueue來產生隨機批次。目前,將元素放入該隊列的唯一方法是運行一個調用q.enqueue()操作的步驟。爲了使這更容易,TensorFlow有一個概念"queue runners",它是在您構建圖形時隱式收集的,然後通過致電tf.train.start_queue_runners()開始。但是,調用tf.train.start_queue_runners()僅啓動在該時間點已定義的隊列運行程序,因此它必須在創建隊列運行程序的代碼之後出現

相關問題