2016-12-16 61 views
0

我想實現一個輸入管道到我的模型,從TFRecords讀取二進制文件; 每個二進制文件包含一個例子(圖片,標籤,我需要的其他東西)tensorflow輸入管道:樣本被讀取不止一次

我有一個文件路徑列表的文本文件;然後:

  1. 我讀取文本文件作爲列表,我將其提供給string_input_producer()以生成一個隊列;
  2. 我喂隊列到TFRecordReader讀取串行化實例,我解碼二進制數據
  3. 我使用shuffle_batch()來安排的實施例成批
  4. 我使用批次評價我模型

問題是,事實證明,可以多次讀取同一個示例,並且可能根本不會訪問某些示例; 我將步數設置爲圖像總數除以批量大小;所以我希望在最後一步結束時所有的輸入例子都被訪問過了,但事實並非如此;相反,一些人不止一次,有些人從不(隨機);這使得我的測試評估完全unrealiable

如果任何人有什麼我做錯了一個提示,請讓我知道

簡化我的模型測試代碼的版本低於; 謝謝!

def my_input(file_list, batch_size) 

    filename = [] 
    f = open(file_list, 'r') 
    for line in f: 
     filename.append(params.TEST_RECORDS_DATA_DIR + line[:-1]) 

    filename_queue = tf.train.string_input_producer(filename) 

    reader = tf.TFRecordReader() 
    _, serialized_example = reader.read(filename_queue) 

    features = tf.parse_single_example(
     serialized_example, 
     features={ 
      'image_raw': tf.FixedLenFeature([], tf.string), 
      'label_raw': tf.FixedLenFeature([], tf.string), 
      'name': tf.FixedLenFeature([], tf.string) 
      }) 

    image = tf.decode_raw(features['image_raw'], tf.uint8) 
    image.set_shape(params.IMAGE_HEIGHT*params.IMAGE_WIDTH*3) 
    image = tf.reshape(image, (params.IMAGE_HEIGHT,params.IMAGE_WIDTH,3)) 
    image = tf.cast(image, tf.float32)/255.0 
    image = preprocess(image) 

    label = tf.decode_raw(features['label_raw'], tf.uint8) 
    label.set_shape(params.NUM_CLASSES) 

    name = features['name'] 

    images, labels, image_names = tf.train.batch([image, label, name], 
      batch_size=batch_size, num_threads=2, 
      capacity=1000 + 3 * batch_size, min_after_dequeue=1000) 

    return images, labels, image_names 


def main() 

    with tf.Graph().as_default(): 

     # call input operations 
     images, labels, image_names = my_input(file_list=params.TEST_FILE_LIST, batch_size=params.BATCH_SIZE) 

     # load a trained model and make predictions  
     prediction = infer(images, labels, image_names) 

     with tf.Session() as sess: 

      for step in range(params.N_STEPS): 
       prediction_values = sess.run([prediction]) 
       # process output 

    return 

回答

0

我的猜測是,tf.train.string_input_producer(filename)設置產生的文件名無限,如果你批次在多個實例(2)線程,它可能是一個線程已經開始處理該文件的第二時間的情況下,而另一個還沒有設法完成第一輪。讀取每個例子只有一個,使用方法:

tf.train.string_input_producer(filename, num_epochs=1) 

,並在會議開始初始化局部變量:

sess.run(tf.initialize_local_variables()) 
+0

感謝回答; – bfra

+0

我設置的線程數爲1,num_epochs = 1,我如你所說初始化局部變量;那麼我得到以下錯誤:W tensorflow/core/framework/op_kernel.cc:968]超出範圍:FIFOQueue'_1_batch/fifo_queue'被關閉並且有177個示例後的元素(請求1,當前大小爲0) – bfra

+0

恰好是要評估的例子數量的一半;有什麼概念我可能會丟失 – bfra