2016-08-12 82 views
7

我對ML比較陌生,對TensorfFlow非常新。我花了很多時間在TensorFlow MINST教程以及https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/how_tos/reading_data上試圖弄清楚如何閱讀我自己的數據,但我感到有點困惑。在Tensorflow數據集中加載圖像

我在目錄/ images/0_Non /中有一堆圖像(.png)。我試圖將它們製作成一個TensorFlow數據集,然後我可以基本上將它作爲第一遍從MINST教程中運行。

import tensorflow as tf 

# Make a queue of file names including all the JPEG images files in the relative image directory. 
filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once("../images/0_Non/*.png")) 

image_reader = tf.WholeFileReader() 

# Read a whole file from the queue, the first returned value in the tuple is the filename which we are ignoring. 
_, image_file = image_reader.read(filename_queue) 

image = tf.image.decode_png(image_file) 

# Start a new session to show example output. 
with tf.Session() as sess: 
    # Required to get the filename matching to run. 
    tf.initialize_all_variables().run() 

    # Coordinate the loading of image files. 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    # Get an image tensor and print its value. 
    image_tensor = sess.run([image]) 
    print(image_tensor) 

    # Finish off the filename queue coordinator. 
    coord.request_stop() 
    coord.join(threads) 

我在理解這裏發生了什麼時有點麻煩。所以它看起來像image是一個張量和image_tensor是一個numpy數組?

如何將我的圖像存入數據集?我也試着沿着Iris的例子來說明,這個例子是爲了讓我到這裏的CSV:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/base.py,但不知道如何讓我的工作爲我的情況,我有一堆PNG的。

謝謝!

+0

您可以使用type(image)來找出類型。你的數據集格式/組織與MNIST示例有什麼不同?你能重新使用與MNIST示例加載數據相同的代碼嗎? –

+0

嗯。 MNIST示例看起來好像數據是以.tar.gz格式發佈的?如果我只是將我的png目錄設爲.tar.gz格式,這是否可以工作? – Vincent

回答

1

最近添加的tf.data API使得它更容易做到這一點:

import tensorflow as tf 

# Make a Dataset of file names including all the PNG images files in 
# the relative image directory. 
filename_dataset = tf.data.Dataset.list_files("../images/0_Non/*.png") 

# Make a Dataset of image tensors by reading and decoding the files. 
image_dataset = filename_dataset.map(lambda x: tf.decode_png(tf.read_file(x))) 

# NOTE: You can add additional transformations, like 
# `image_dataset.batch(BATCH_SIZE)` or `image_dataset.repeat(NUM_EPOCHS)` 
# in here. 

iterator = image_dataset.make_one_shot_iterator() 
next_image = iterator.get_next() 

# Start a new session to show example output. 
with tf.Session() as sess: 

    try: 

    while True: 
     # Get an image tensor and print its value. 
     image_array = sess.run([next_image]) 
     print(image_tensor) 

    except tf.errors.OutOfRangeError: 
    # We have reached the end of `image_dataset`. 
    pass 

注意,對於訓練,你需要從什麼地方得到的標籤。 Dataset.zip()轉換是將image_dataset與來自不同來源的標籤數據集組合在一起的可能方式。