使用張量流函數tf.train.shuffle_batch我們通過將tfrecord作爲隊列讀入內存並在隊列中進行混洗(如果得到正確的理解)來獲得混洗批處理。現在我有一個高度有序的tfrecords(相同標籤的圖片一起寫入)和一個非常大的數據集(約2,550,000圖片)。我想用一批隨機標籤給我的Vgg-net餵食,但它不可能和醜陋地將所有圖片讀入內存並被洗牌。有沒有解決這個問題的方法?如何從內存有限但大型數據集的tfrecords中獲取洗牌批次?
我想過,也許第一次做洗牌,然後寫他們入TFrecord,但我不能找出一種有效的方式這樣做......
我的數據保存在這樣:
這裏是我的代碼獲得TFRecords:
dst = "/Users/cory/Desktop/3_key_frame"
classes=[]
for myclass in os.listdir(dst):
if myclass.find('.DS_Store')==-1:
classes.append(myclass)
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
class_path = dst +'/' + name
#print(class_path)
for img_seq in os.listdir(class_path):
if img_seq.find('DS_Store')==-1:
seq_pos = class_path +'/' + img_seq
if os.path.isdir(seq_pos):
for img_name in os.listdir(seq_pos):
img_path = seq_pos +'/' + img_name
img = Image.open(img_path)
img = img.resize((64,64))
img_raw = img.tobytes()
#print (img,index)
example = tf.train.Example(features=tf.train.Features(feature={
"label":tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()