我對tf.train.string_input_producer
的工作原理有些疑問。因此,假設我將filename_list作爲輸入參數提供給string_input_producer
。然後,根據文檔https://www.tensorflow.org/programmers_guide/reading_data,這將創建一個FIFOQueue
,我可以在其中設置時代號,隨機播放文件名等。因此,就我而言,我有4個文件名(「db1.tfrecords」,「db2.tfrecords」...)。我使用tf.train.batch
來提供網絡批次的圖像。另外,每個文件名/數據庫都包含一組人員的圖像。第二個數據庫是針對第二個人的,等等。到目前爲止,我有以下代碼:用張量流確定tf.train.string_input_producer的時代數
tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"),
(common + "P21_db.tfrecords")]
filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue')
reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),
'annotation_raw': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
image = tf.reshape(image, [height, width, 3])
annotation = tf.cast(features['annotation_raw'], tf.string)
min_after_dequeue = 100
num_threads = 4
capacity = min_after_dequeue + num_threads * batch_size
label_batch, images_batch = tf.train.batch([annotation, image],
shapes=[[], [112, 112, 3]],
batch_size=batch_size,
capacity=capacity,
num_threads=num_threads)
最後,嘗試在自動編碼器的輸出,查看了重建圖像的時候,我得到了第一個從第一數據庫中的圖像,然後我開始從觀看圖像第二個數據庫等等。
我的問題:我怎麼知道我是否在同一個時代?如果我處於理智的時代,我如何合併一批來自我擁有的所有file_names的圖像?
最後,我試圖通過如下Session
內評估局部變量打印出時代的價值:
epoch_var = tf.local_variables()[0]
然後:
with tf.Session() as sess:
print(sess.run(epoch_var.eval())) # Here I got 9 as output. don't know y.
任何幫助深表感謝!
你可以使用'tf.python_io.tf_record_iterator'來計算記錄的數量,並給出批量的大小,你應該得到當前的epoch編號。雖然沒有得到你的第二個問題。 –
@vijaym,這不是我所問的。我有'tf.train.string_input_producer',而不是'tf.python_io.tf_record_iterator'。 –