2016-03-16

How to convert TFRecords to numpy arrays?

The main idea is to convert TFRecords into numpy arrays. Assume the TFRecord stores images. Specifically, I want to:

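  1. Read a TFRecord file and convert each image to a numpy array.
  2. Write the images to 1.jpg, 2.jpg, etc.
  3. At the same time, write the file names and labels to a text file like this:
    1.jpg 2
    2.jpg 4
    3.jpg 5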
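Step 1 starts by splitting the file into records. As a pure-Python sketch (not TensorFlow's own reader; the function name iter_tfrecord_payloads is hypothetical), this walks the documented TFRecord framing and yields each record's payload, which for this dataset would be one serialized tf.train.Example. It assumes a well-formed file and deliberately skips verifying the masked CRC-32C checksums, because the standard library has no CRC-32C.

```python
import struct
from typing import Iterator


def iter_tfrecord_payloads(path: str) -> Iterator[bytes]:
    """Yield the serialized payload of every record in a TFRecord file.

    Record layout (all integers little-endian):
        uint64 length
        uint32 masked CRC-32C of length
        byte   data[length]        # here: a serialized tf.train.Example
        uint32 masked CRC-32C of data
    This sketch reads but does NOT verify the two checksums.
    """
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return  # clean end of file
            if len(header) != 8:
                raise ValueError("truncated length field")
            (length,) = struct.unpack("<Q", header)
            if len(f.read(4)) != 4:  # masked CRC of the length (not checked)
                raise ValueError("truncated length checksum")
            data = f.read(length)
            if len(data) != length:
                raise ValueError("truncated record data")
            if len(f.read(4)) != 4:  # masked CRC of the data (not checked)
                raise ValueError("truncated data checksum")
            yield data
```

With TensorFlow installed, each payload could instead be decoded with tf.train.Example.FromString(data).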
    

I am currently using the following code:

import tensorflow as tf 
import os 

def read_and_decode(filename_queue): 
    reader = tf.TFRecordReader() 
    _, serialized_example = reader.read(filename_queue) 
    features = tf.parse_single_example(
     serialized_example, 
     # Defaults are not specified since both keys are required. 
     features={ 
      'image_raw': tf.FixedLenFeature([], tf.string), 
      'label': tf.FixedLenFeature([], tf.int64), 
      'height': tf.FixedLenFeature([], tf.int64), 
      'width': tf.FixedLenFeature([], tf.int64), 
      'depth': tf.FixedLenFeature([], tf.int64) 
     }) 
    image = tf.decode_raw(features['image_raw'], tf.uint8) 
    label = tf.cast(features['label'], tf.int32) 
    height = tf.cast(features['height'], tf.int32) 
    width = tf.cast(features['width'], tf.int32) 
    depth = tf.cast(features['depth'], tf.int32) 
    return image, label, height, width, depth 

with tf.Session() as sess: 
    filename_queue = tf.train.string_input_producer(["../data/svhn/svhn_train.tfrecords"]) 
    image, label, height, width, depth = read_and_decode(filename_queue) 
    image = tf.reshape(image, tf.pack([height, width, 3])) 
    image.set_shape([32,32,3]) 
    init_op = tf.initialize_all_variables() 
    sess.run(init_op) 
    print (image.eval()) 

For starters, I'm just trying to read at least one image. When I run this code, it just hangs.
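To show what parse_single_example and decode_raw extract from one record, here is a pure-Python sketch that decodes a serialized tf.train.Example into the same features that read_and_decode requests (image_raw, label, height, width, depth). It assumes the message layout from TensorFlow's feature.proto (Example.features = 1, Features.feature map = 1, Feature.bytes_list = 1, Feature.int64_list = 3), handles only bytes_list and int64_list (float_list is ignored), and uses hypothetical helper names (parse_example, to_hwc). It is an illustration, not how TensorFlow implements parsing.

```python
from typing import Dict, Iterator, List, Tuple, Union

Value = Union[bytes, int]


def _read_varint(buf: bytes, pos: int) -> Tuple[int, int]:
    """Decode one protobuf base-128 varint at `pos`; return (value, next_pos)."""
    result, shift = 0, 0
    while True:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7


def _fields(buf: bytes) -> Iterator[Tuple[int, int, Value]]:
    """Yield (field_number, wire_type, value) for each field of one message."""
    pos = 0
    while pos < len(buf):
        tag, pos = _read_varint(buf, pos)
        field, wire = tag >> 3, tag & 7
        if wire == 0:  # varint
            value, pos = _read_varint(buf, pos)
        elif wire == 2:  # length-delimited: bytes, sub-message or packed list
            length, pos = _read_varint(buf, pos)
            value, pos = buf[pos:pos + length], pos + length
        elif wire in (1, 5):  # fixed 64-/32-bit: kept as raw bytes
            size = 8 if wire == 1 else 4
            value, pos = buf[pos:pos + size], pos + size
        else:
            raise ValueError(f"unsupported wire type {wire}")
        yield field, wire, value


def _as_int64(u: int) -> int:
    """Reinterpret an unsigned 64-bit varint as a signed int64."""
    return u - (1 << 64) if u >= (1 << 63) else u


def _parse_feature(buf: bytes) -> List[Value]:
    """Decode a Feature: kind 1 = bytes_list, kind 3 = int64_list (float_list skipped)."""
    values: List[Value] = []
    for kind, _, lst in _fields(buf):
        for field, wire, v in _fields(lst):
            if field != 1:
                continue
            if kind == 1:
                values.append(v)
            elif kind == 3 and wire == 0:  # unpacked int64
                values.append(_as_int64(v))
            elif kind == 3 and wire == 2:  # packed int64 (how writers normally encode it)
                pos = 0
                while pos < len(v):
                    u, pos = _read_varint(v, pos)
                    values.append(_as_int64(u))
    return values


def parse_example(payload: bytes) -> Dict[str, List[Value]]:
    """Map feature name -> list of values for one serialized tf.train.Example."""
    out: Dict[str, List[Value]] = {}
    for field, _, features in _fields(payload):  # Example.features = field 1
        if field != 1:
            continue
        for f, _, entry in _fields(features):  # Features.feature map entries = field 1
            if f != 1:
                continue
            key, feature = "", b""
            for ef, _, v in _fields(entry):  # map entry: key = 1, value = 2
                if ef == 1:
                    key = v.decode("utf-8")
                elif ef == 2:
                    feature = v
            out[key] = _parse_feature(feature)
    return out


def to_hwc(raw: bytes, height: int, width: int, depth: int) -> List[List[List[int]]]:
    """Pure-Python analogue of tf.decode_raw(raw, tf.uint8) reshaped to [height, width, depth]."""
    if len(raw) != height * width * depth:
        raise ValueError("len(raw) != height * width * depth")
    return [[list(raw[(r * width + c) * depth:(r * width + c + 1) * depth])
             for c in range(width)] for r in range(height)]
```

With numpy available, to_hwc corresponds to np.frombuffer(raw, dtype=np.uint8).reshape(height, width, depth), which gives the numpy array the question is after.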

Answer


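Oops, that was a silly mistake on my part. I used a string_input_producer but forgot to start the queue runners, so nothing ever filled the filename queue and the read blocked forever.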
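The hang can be mimicked with Python's standard queue and threading modules. This is only an analogy, not TensorFlow's implementation: the consumer stands in for reader.read() waiting on filename_queue, the runner thread stands in for what tf.train.start_queue_runners launches, and a threading.Event plays the role of tf.train.Coordinator. The function run_pipeline and its parameters are hypothetical; the timeout exists only so the "forgot the runner" case raises queue.Empty instead of blocking forever.

```python
import queue
import threading
from typing import List, Sequence


def run_pipeline(filenames: Sequence[str], n: int, start_runner: bool = True,
                 timeout: float = 1.0) -> List[str]:
    """Dequeue `n` filenames the way the reader does and return them.

    With start_runner=False nothing ever fills the queue, so the consumer blocks;
    the timeout turns that into queue.Empty instead of hanging forever.
    """
    q: queue.Queue = queue.Queue(maxsize=32)  # bounded filename queue
    stop = threading.Event()                   # stands in for tf.train.Coordinator

    def runner() -> None:
        # Stands in for the enqueue thread started by tf.train.start_queue_runners:
        # cycles through the filenames forever (string_input_producer also shuffles
        # them by default; this sketch does not).
        i = 0
        while not stop.is_set():
            try:
                q.put(filenames[i % len(filenames)], timeout=0.05)
                i += 1
            except queue.Full:
                continue  # re-check the stop flag, then retry the same item

    threads = []
    if start_runner:
        threads.append(threading.Thread(target=runner, daemon=True))
        threads[-1].start()
    try:
        return [q.get(timeout=timeout) for _ in range(n)]
    finally:
        stop.set()         # like coord.request_stop()
        for t in threads:  # like coord.join(threads)
            t.join()
```

Calling run_pipeline(["a", "b"], 4) returns the filenames in order, while run_pipeline(["a"], 1, start_runner=False) raises queue.Empty after the timeout, which is the pure-Python version of the hang in the question.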

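with tf.Session() as sess:
    filename_queue = tf.train.string_input_producer(["../data/svhn/svhn_train.tfrecords"])
    image, label, height, width, depth = read_and_decode(filename_queue)
    # tf.pack was renamed tf.stack in TensorFlow 1.0
    image = tf.reshape(image, tf.pack([height, width, 3]))
    image.set_shape([32, 32, 3])
    # later releases replace tf.initialize_all_variables with tf.global_variables_initializer
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    # Start the threads that fill filename_queue; without them reader.read() blocks forever
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(1000):
        example, l = sess.run([image, label])
        print(example, l)
    # Ask the queue-runner threads to stop, then wait for them to finish
    coord.request_stop()
    coord.join(threads)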
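The loop above only prints each image and label; steps 2 and 3 of the question (write 1.jpg, 2.jpg, ... and a "<file> <label>" list) could be sketched as follows. The function export_images, its save_image hook and the default list name labels.txt are hypothetical choices; the JPEG encoding itself is left to the caller because the standard library cannot write JPEGs.

```python
import os
from typing import Any, Callable, Iterable, List, Tuple


def export_images(samples: Iterable[Tuple[Any, int]], out_dir: str,
                  save_image: Callable[[Any, str], None],
                  list_name: str = "labels.txt") -> List[str]:
    """Write samples as 1.jpg, 2.jpg, ... plus a "<file> <label>" list file.

    `save_image(array, path)` is a caller-supplied (hypothetical) hook that does
    the actual image encoding; this function only handles naming and the list.
    """
    os.makedirs(out_dir, exist_ok=True)
    names = []
    with open(os.path.join(out_dir, list_name), "w") as listing:
        for i, (image, label) in enumerate(samples, start=1):
            name = f"{i}.jpg"
            save_image(image, os.path.join(out_dir, name))
            listing.write(f"{name} {int(label)}\n")  # e.g. "1.jpg 2"
            names.append(name)
    return names
```

Inside the session above, one option is to pass (sess.run([image, label]) for _ in range(1000)) as samples and, assuming Pillow is installed, lambda arr, path: Image.fromarray(arr).save(path) as save_image; Pillow picks the JPEG format from the .jpg extension.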

I wish I could upvote this more. I've been searching everywhere for this!