2017-08-12 24 views
0

使用張量流函數tf.train.shuffle_batch我們通過將tfrecord作爲隊列讀入內存並在隊列中進行混洗(如果得到正確的理解)來獲得混洗批處理。現在我有一個高度有序的tfrecords(相同標籤的圖片一起寫入)和一個非常大的數據集(約2,550,000圖片)。我想用一批隨機標籤給我的Vgg-net餵食,但它不可能和醜陋地將所有圖片讀入內存並被洗牌。有沒有解決這個問題的方法?如何從內存有限但大型數據集的tfrecords中獲取洗牌批次?

我想過,也許第一次做洗牌,然後寫他們入TFrecord,但我不能找出一種有效的方式這樣做......

我的數據保存在這樣:

enter image description here

這裏是我的代碼獲得TFRecords:

dst = "/Users/cory/Desktop/3_key_frame" 

classes=[] 
for myclass in os.listdir(dst): 
    if myclass.find('.DS_Store')==-1: 
     classes.append(myclass) 


writer = tf.python_io.TFRecordWriter("train.tfrecords") 
for index, name in enumerate(classes): 
    class_path = dst +'/' + name 
    #print(class_path) 
    for img_seq in os.listdir(class_path): 
     if img_seq.find('DS_Store')==-1: 
      seq_pos = class_path +'/' + img_seq 
      if os.path.isdir(seq_pos): 
       for img_name in os.listdir(seq_pos): 
        img_path = seq_pos +'/' + img_name 
        img = Image.open(img_path) 
        img = img.resize((64,64)) 
        img_raw = img.tobytes() 
        #print (img,index) 
        example = tf.train.Example(features=tf.train.Features(feature={ 
         "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 
         'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) 
         })) 
        writer.write(example.SerializeToString()) 
writer.close() 

回答

0

假設您的數據存儲這樣的:

/path/to/images/LABEL_1/image001.jpg 
/path/to/images/LABEL_1/image002.jpg 
... 
/path/to/images/LABEL_10/image001.jpg 

獲取在一個平面列表中的所有文件名和洗牌他們:

import glob 
import random 
filenames = glob.glob('/path/to/images/**/*.jpg) 
random.shuffle(filenames) 

創建字典從標籤名稱去數字標籤:

class_to_index = {'LABEL_1':0, 'LABEL_2': 1} # more classes I assume... 

現在,您可以遍歷所有圖像和檢索標籤

writer = tf.python_io.TFRecordWriter("train.tfrecords") 
for f in filenames: 
    img = Image.open(f) 
    img = img.resize((64,64)) 
    img_raw = img.tobytes() 
    label = f.split('/')[-2] 
    example = tf.train.Example(features=tf.train.Features(feature={ 
        "label":tf.train.Feature(int64_list=tf.train.Int64List(value= class_to_index[label])), 
        'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) 
        })) 
       writer.write(example.SerializeToString()) 
writer.close() 

希望這有助於:)

1

我假設你有已知的標籤數據集的文件名和/或結構列表。 可能值得每次在每個類的基礎上迭代通過它們,每次取N量。本質上是交錯數據集,以便不存在順序問題。 如果我正確理解這一點,那麼您主要關心的是從TFRecord抽樣數據集時,您的數據的子集可能完全包含1個類,而不是一個好的表示?

如果其結構爲:

0 0 0 0 1 1 1 1 2 2 2 2 0 0 0 0 1 1 1 1 2 2 2 2 ... etc 

這可能使shuffle_batch更容易創建培訓更好樣品。

這是我遵循的解決方案,因爲似乎沒有附加的混洗參數,您可以指定保持集合中類標籤的均勻分佈。

相關問題