2017-03-01 23 views
0

我想使用CNN來解決deblurring的任務,並且我有訓練數據,它是一個包含文件名的png圖像和相應文本文件的目錄。在Tensorflow中傳輸CNN模型時,如何從目錄讀取圖像作爲輸入和輸出?

由於數據太大而無法用一步添加到內存中,並且是否有任何API或某種方法使我可以讀取blury圖像作爲輸入,並將其作爲預期結果讀取到地面實況以訓練?

我花了不少時間來解決這個問題,但我糊塗了在網上API介紹後讀的API。

+0

您正在尋找這樣的:http://stackoverflow.com/a/36947632/2505209?以圖片爲例以及標籤。 – hars

回答

0

該方法並不困惑。 tensorflow提供TFrecords文件以充分利用內存。

def create_cord(): 

    writer = tf.python_io.TFRecordWriter("train.tfrecords") 
    for index in xrange(66742): 
     blur_file_name = get_file_name(index, True) 
     orig_file_name = get_file_name(index, False) 
     blur_image_path = cwd + blur_file_name 
     orig_image_path = cwd + orig_file_name 

     blur_image = Image.open(blur_image_path) 
     orig_image = Image.open(orig_image_path) 

     blur_image = blur_image.resize((IMAGE_HEIGH, IMAGE_WIDTH)) 
     orig_image = orig_image.resize((IMAGE_HEIGH, IMAGE_WIDTH)) 

     blur_image_raw = blur_image.tobytes() 
     orig_image_raw = orig_image.tobytes() 
     example = tf.train.Example(features=tf.train.Features(feature={ 
     "blur_image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[blur_image_raw])), 
     'orig_image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[orig_image_raw])) 
    })) 
    writer.write(example.SerializeToString()) 
    writer.close() 

讀取數據集:

def read_and_decode(filename): 
    filename_queue = tf.train.string_input_producer([filename]) 

    reader = tf.TFRecordReader() 
    _, serialized_example = reader.read(filename_queue) 
    features = tf.parse_single_example(serialized_example, 
            features={ 
             'blur_image_raw': tf.FixedLenFeature([], tf.string), 
             'orig_image_raw': tf.FixedLenFeature([], tf.string), 
            }) 

    blur_img = tf.decode_raw(features['blur_image_raw'], tf.uint8) 
    blur_img = tf.reshape(blur_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3]) 
    blur_img = tf.cast(blur_img, tf.float32) * (1./255) - 0.5 

    orig_img = tf.decode_raw(features['blur_image_raw'], tf.uint8) 
    orig_img = tf.reshape(orig_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3]) 
    orig_img = tf.cast(orig_img, tf.float32) * (1./255) - 0.5 

    return blur_img, orig_img 


if __name__ == '__main__': 

    # create_cord() 

    blur, orig = read_and_decode("train.tfrecords") 
    blur_batch, orig_batch = tf.train.shuffle_batch([blur, orig], 
               batch_size=3, capacity=1000, 
               min_after_dequeue=100) 
    init = tf.global_variables_initializer() 
    with tf.Session() as sess: 
     sess.run(init) 
    # 啓動隊列 
     threads = tf.train.start_queue_runners(sess=sess) 
     for i in range(3): 
      v, l = sess.run([blur_batch, orig_batch]) 
      print(v.shape, l.shape) 
相關問題