2017-03-03 92 views
4

我有一個不平衡的多類數據集,我想使用fit_generatorclass_weight參數根據每個類的圖像數給這些類加權。我正在使用ImageDataGenerator.flow_from_directory從目錄加載數據集。是否可以自動從Keras的flow_from_directory中推斷出class_weight?

是否可以直接推斷ImageDataGenerator對象的class_weight參數?

+0

我不認爲這是可能的。你爲什麼不計算一次呢? –

回答

8

剛剛想出了實現這一點的方法。

from collections import Counter 
train_datagen = ImageDataGenerator() 
train_generator = train_datagen.flow_from_directory(...) 

counter = Counter(train_generator.classes)       
max_val = float(max(counter.values()))  
class_weights = {class_id : max_val/num_images for class_id, num_images in counter.items()}      

model.fit_generator(..., 
        class_weight=class_weights) 

train_generator.classes是每個圖像的類別列表。 Counter(train_generator.classes)創建每個班級中圖像數量的計數器。

請注意,這些權重可能不適合收斂,但您可以將其用作基於出現次數的其他類型權重的基礎。

這個答案的靈感來自於:https://github.com/fchollet/keras/issues/1875#issuecomment-273752868

+0

但train_generator.classes只返回一個類的列表,像一個集合,不是? –

+1

它返回每個圖像的類的列表。例如,如果我們有三個圖像,前兩個圖像來自第1類,最後一個來自第0類,'train_generator.classes'等於'[1,1,0]'。 –

+1

的確,剛去看源代碼:)幹得好 –

相關問題