2017-08-15 204 views
1

我對tf.train.string_input_producer的工作原理有些疑問。因此,假設我將filename_list作爲輸入參數提供給string_input_producer。然後,根據文檔https://www.tensorflow.org/programmers_guide/reading_data,這將創建一個FIFOQueue,我可以在其中設置時代號,隨機播放文件名等。因此,就我而言,我有4個文件名(「db1.tfrecords」,「db2.tfrecords」...)。我使用tf.train.batch來提供網絡批次的圖像。另外,每個文件名/數據庫都包含一組人員的圖像。第二個數據庫是針對第二個人的,等等。到目前爲止,我有以下代碼:用張量流確定tf.train.string_input_producer的時代數

tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"), 
          (common + "P21_db.tfrecords")] 

filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue') 
reader = tf.TFRecordReader() 

key, serialized_example = reader.read(filename_queue) 
features = tf.parse_single_example(
    serialized_example, 
    # Defaults are not specified since both keys are required. 
    features={ 
     'height': tf.FixedLenFeature([], tf.int64), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'annotation_raw': tf.FixedLenFeature([], tf.string) 
    }) 

image = tf.decode_raw(features['image_raw'], tf.uint8) 
height = tf.cast(features['height'], tf.int32) 
width = tf.cast(features['width'], tf.int32) 

image = tf.reshape(image, [height, width, 3]) 

annotation = tf.cast(features['annotation_raw'], tf.string) 

min_after_dequeue = 100 
num_threads = 4 
capacity = min_after_dequeue + num_threads * batch_size 
label_batch, images_batch = tf.train.batch([annotation, image], 
                 shapes=[[], [112, 112, 3]], 
                 batch_size=batch_size, 
                 capacity=capacity, 
                 num_threads=num_threads) 

最後,嘗試在自動編碼器的輸出,查看了重建圖像的時候,我得到了第一個從第一數據庫中的圖像,然後我開始從觀看圖像第二個數據庫等等。

我的問題:我怎麼知道我是否在同一個時代?如果我處於理智的時代,我如何合併一批來自我擁有的所有file_names的圖像?

最後,我試圖通過如下Session內評估局部變量打印出時代的價值:

epoch_var = tf.local_variables()[0] 

然後:

with tf.Session() as sess: 
    print(sess.run(epoch_var.eval())) # Here I got 9 as output. don't know y. 

任何幫助深表感謝!

+0

你可以使用'tf.python_io.tf_record_iterator'來計算記錄的數量,並給出批量的大小,你應該得到當前的epoch編號。雖然沒有得到你的第二個問題。 –

+0

@vijaym,這不是我所問的。我有'tf.train.string_input_producer',而不是'tf.python_io.tf_record_iterator'。 –

回答

0

所以我想到的是,使用tf.train.shuffle_batch_join解決了我的問題,因爲它開始洗牌不同數據集的圖像。換句話說,每個批次現在都包含來自所有數據集/文件名的圖像。這裏有一個例子:

def read_my_file_format(filename_queue): 
    reader = tf.TFRecordReader() 
    key, serialized_example = reader.read(filename_queue) 
    features = tf.parse_single_example(
     serialized_example, 
     # Defaults are not specified since both keys are required. 
     features={ 
      'height': tf.FixedLenFeature([], tf.int64), 
      'width': tf.FixedLenFeature([], tf.int64), 
      'image_raw': tf.FixedLenFeature([], tf.string), 
      'annotation_raw': tf.FixedLenFeature([], tf.string) 
     }) 

    # This is how we create one example, that is, extract one example from the database. 
    image = tf.decode_raw(features['image_raw'], tf.uint8) 
    # The height and the weights are used to 
    height = tf.cast(features['height'], tf.int32) 
    width = tf.cast(features['width'], tf.int32) 

    # The image is reshaped since when stored as a binary format, it is flattened. Therefore, we need the 
    # height and the weight to restore the original image back. 
    image = tf.reshape(image, [height, width, 3]) 

    annotation = tf.cast(features['annotation_raw'], tf.string) 
    return annotation, image 

def input_pipeline(filenames, batch_size, num_threads, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epoch, shuffle=False, 
                name='queue') 
    # Therefore, Note that here we have created num_threads readers to read from the filename_queue. 
    example_list = [read_my_file_format(filename_queue=filename_queue) for _ in range(num_threads)] 
    min_after_dequeue = 100 
    capacity = min_after_dequeue + num_threads * batch_size 
    label_batch, images_batch = tf.train.shuffle_batch_join(example_list, 
                  shapes=[[], [112, 112, 3]], 
                  batch_size=batch_size, 
                  capacity=capacity, 
                  min_after_dequeue=min_after_dequeue) 
    return label_batch, images_batch, example_list 

label_batch, images_batch, input_ann_img = \ 
    input_pipeline(tfrecords_filename_seq, batch_size, num_threads, num_epochs=num_epoch) 

現在這個是要打造一批讀者從FIFOQueue閱讀,每個讀者後都會有不同的解碼器。最後,在解碼圖像之後,它們將被饋送到在調用tf.train.shuffle_batch_join之後創建的另一個Queue中以向網絡饋送一批圖像。