2017-09-23 59 views
1

而使用以下Keras Python代碼:如何在使用Keras的數據生成器時檢索標籤信息?

for x_batch,y_batch in datagen.flow_from_directory(
    directory = os.path.join(dataset_root_path,dataset_train_path), 
    target_size = (520,520), 
    class_mode = 'binary', 
    batch_size = 1 
): 

我得到了x_batch和y_batch numpy的陣列,所述y_batch numpy的陣列被編碼成數0.0或1.0,因爲我使用的「二進制」 class_mode,但是,通過這種方式,我失去了關於該樣本的真實標籤的信息,例如「貓」或「狗」。如何根據輸出'1.0'和'0.0'檢索標籤信息?

回答

0

我建議來實例化datagenerator和列車fit_generator

train_gen = datagen.flow_from_directory(...) 
model.fit_generator(train_gen, ...) 

然後,您可以訪問(以及其他屬性)的類指數與train_gen.class_indices

相關問題