我正在將圖像讀入我的TF網絡,但我還需要關聯的標籤以及它們。TF slice_input_producer不使張量保持同步
所以我試圖按照this answer,但輸出的標籤實際上並不匹配我在每批中獲得的圖像。
我的圖像名稱格式爲dir/3.jpg
,所以我只是從圖像文件名中提取標籤。
truth_filenames_np = ...
truth_filenames_tf = tf.convert_to_tensor(truth_filenames_np)
# get the labels
labels = [f.rsplit("/", 1)[1] for f in truth_filenames_np]
labels_tf = tf.convert_to_tensor(labels)
# *** This line should make sure both input tensors are synced (from my limited understanding)
# My list is also already shuffled, so I set shuffle=False
truth_image_name, truth_label = tf.train.slice_input_producer([truth_filenames_tf, labels_tf], shuffle=False)
truth_image_value = tf.read_file(truth_image_name)
truth_image = tf.image.decode_jpeg(truth_image_value)
truth_image.set_shape([IMAGE_DIM, IMAGE_DIM, 3])
truth_image = tf.cast(truth_image, tf.float32)
truth_image = truth_image/255.0
# Another key step, where I batch them together
truth_images_batch, truth_label_batch = tf.train.batch([truth_image, truth_label], batch_size=mb_size)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(epochs):
print "Epoch ", i
X_truth_batch = truth_images_batch.eval()
X_label_batch = truth_label_batch.eval()
# Here I display all the images in this batch, and then I check which file numbers they actually are.
# BUT, the images that are displayed don't correspond with what is printed by X_label_batch!
print X_label_batch
plot_batch(X_truth_batch)
coord.request_stop()
coord.join(threads)
我做錯了什麼,或者slice_input_producer沒有真正確保它的輸入張量是同步的嗎?
旁白:
我也注意到,當我得到tf.train.batch批次,該批次中的元素,我把它原來的列表是彼此相鄰,但批訂單ISN」 t按原始順序排列。例如:如果我的數據是[「dir/1.jpg」,「dir/2.jpg」,「dir/3.jpg」,「dir/4.jpg」,「dir/5.jpg」,dir/6.jpg「],然後我可以批量處理(batch_size = 2)[」dir/3.jpg「,」dir/4.jpg「],然後批處理[」dir/1.jpg「,」dir/2.jpg「],然後是最後一個 因此,這使得很難甚至只是使用FIFO隊列作爲標籤,因爲訂單將不符合批次訂單的要求。
能否請您編輯代碼,以再現該問題的最低限度的入隊嘗試?如在,刪除所有圖像處理,看看圖像/標籤是否洗牌 - 因爲它是我們不能運行這個代碼,除非我們有文件 –