2017-06-01 24 views
0

我是tensorflow的新手,我有張量(字符串類型),其中存儲了我想用於訓練模型的所有必需圖像的圖像路徑。如何從tensorflow中的字符串張量讀取數據集名稱

問題:如何讀取張量進行排隊然後進行批處理。

我的做法是:是給我的錯誤

img_names = dataset['f0'] 
    file_length = len(img_names) 
    type(img_names) 
    tf_img_names = tf.stack(img_names) 
    filename_queue = tf.train.string_input_producer(tf_img_names, num_epochs=num_epochs, shuffle=False) 
    wd=getcwd() 
    print('In input pipeline') 
    tf_img_queue = tf.FIFOQueue(file_length,dtypes=[tf.string]) 
    col_Image = tf_img_queue.dequeue(filename_queue) 
    ### Read Image 
    img_file = tf.read_file(wd+'/'+col_Image) 
    image = tf.image.decode_png(img_file, channels=num_channels) 
    image = tf.cast(image, tf.float32)/255. 
    image = tf.image.resize_images(image,[image_width, image_height]) 
    min_after_dequeue = 100 
    capacity = min_after_dequeue + 3 * batch_size 
    image_batch, label_batch = tf.train.batch([image, onehot], batch_size=batch_size, capacity=capacity, allow_smaller_final_batch = True, min_after_dequeue=min_after_dequeue) 

錯誤:類型錯誤:預期的字符串或緩衝區」

我不知道如果我的做法是正確與否

回答

0

你不必創建另一個隊列。您可以定義一個讀取器,爲您解除出列元素。你可以嘗試下面的內容並評論一下。

reader = tf.IdentityReader() 
key, value = reader.read(filename_queue) 
dir = tf.constant(wd) 
path = tf.string_join([dir,tf.constant("/"),value]) 
img_file = tf.read_file(path) 

,並檢查你喂正確的路徑,做

print(sess.run(img_file)) 

尋找您的反饋意見。

+0

如果我已閱讀使用numpy的'genfromtxt' csv文件就像 'csv_file = np.genfromtxt(args.dataset,分隔符= '',skip_header = 1,usecols =(0,1,2,3 ,4,5),dtype = None) 如何使用string_input_producer排隊和批處理文件.... –

+0

您不必使用numpy閱讀。使用tf.TextLineReader()並將每行解析爲所需內容並加載圖像。看看這個:https://stackoverflow.com/questions/37091899/how-to-actually-read-csv-data-in-tensorflow – hars

相關問題