2017-02-03 142 views
3

培訓我在大數據集上使用Keras(使用MagnaTagATune數據集進行音樂自動標記)。所以我嘗試使用fit_generator()函數與自定義數據生成器。但是在培訓過程中損失函數和指標的價值不會改變。看起來我的網絡根本沒有訓練。Keras:網絡不使用fit_generator()

當我使用fit()函數而不是fit_generator()時,一切正常,但我無法將整個數據集保存在內存中。

我試圖既Theano和TensorFlow後端

主代碼:

if __name__ == '__main__': 
    model = models.FCN4() 
    model.compile(optimizer='adam', 
        loss='binary_crossentropy', 
        metrics=['accuracy', 'categorical_accuracy', 'precision', 'recall']) 
    gen = mttutils.generator_v2(csv_path, melgrams_dir) 
    history = model.fit_generator(gen.generate(0,750), 
            samples_per_epoch=750, 
            nb_epoch=80, 
            validation_data=gen.generate(750,1000,False), 
            nb_val_samples=250) 
    # RESULTS SAVING 
    np.save(output_history, history.history) 
    model.save(output_model) 

類generator_v2:

genres = ['guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock', 'fast', 
     'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian', 'opera', 'male', 'singing', 
     'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet', 'flute', 'woman', 'male vocal', 'no vocal', 
     'pop', 'soft', 'sitar', 'solo', 'man', 'classic', 'choir', 'voice', 'new age', 'dance', 'male voice', 
     'female vocal', 'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice', 'choral'] 

def __init__(self, csv_path, melgrams_dir): 

    def get_dict_vals(dictionary, keys): 
     vals = [] 
     for key in keys: 
      vals.append(dictionary[key]) 
     return vals 

    self.melgrams_dir = melgrams_dir 
    with open(csv_path, newline='') as csvfile: 
     reader = csv.DictReader(csvfile, dialect='excel-tab') 
     self.labels = [] 
     for row in reader: 
      labels_arr = np.array(get_dict_vals(
       row, self.genres)).astype(np.int) 
      labels_arr = labels_arr.reshape((1, labels_arr.shape[0])) 
      if (np.sum(labels_arr) > 0): 
       self.labels.append((row['mp3_path'], labels_arr)) 
     self.size = len(self.labels) 


def generate(self, begin, end): 
    while(1): 
     for count in range(begin, end): 
      try: 
       item = self.labels[count] 
       mels = np.load(os.path.join(
        self.melgrams_dir, item[0] + '.npy')) 
       tags = item[1] 
       yield((mels, tags)) 
      except FileNotFoundError: 
       continue 

爲了製備用於配合陣列()函數我用這個代碼:

def TEST_get_data_array(csv_path, melgrams_dir): 
    gen = generator_v2(csv_path, melgrams_dir).generate(0,100) 
    item = next(gen) 
    x = np.array(item[0]) 
    y = np.array(item[1]) 
    for i in range(0,100): 
     item = next(gen.training) 
     x = np.concatenate((x,item[0]),axis = 0) 
     y = np.concatenate((y,item[1]),axis = 0) 
    return(x,y) 

對不起,如果我的代碼風格不好。謝謝你!

UPD 1: 我試着使用return(X,y),而不是yield(X,y)但沒有任何變化。我的新發電機類的

部分:

def generate(self): 
    if((self.count < self.begin) or (self.count >= self.end)): 
     self.count = self.begin 
    item = self.labels[self.count] 
    mels = np.load(os.path.join(self.melgrams_dir, item[0] + '.npy')) 
    tags = item[1] 
    self.count = self.count + 1 
    return((mels, tags)) 

def __next__(self): # fit_generator() uses this method 
    return self.generate() 

fit_generator電話:

history = model.fit_generator(tr_gen, 
           samples_per_epoch = tr_gen.size, 
           nb_epoch = 120, 
           validation_data = val_gen, 
           nb_val_samples = val_gen.size) 

日誌:

Epoch 1/120 
10554/10554 [==============================] - 545s - loss: 1.7240 - acc: 0.8922 
Epoch 2/120 
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820 
Epoch 3/120 
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820 
Epoch 4/120 
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820 
... etc (loss is always 1.8922; acc is always 0.8820) 
+0

在'用於範圍(開始,結束)'之前,您可能會洗牌您的數據。 –

回答

2

我有同樣的問題,因爲你與產量的方法。所以我只是存儲了當前的索引,並用return語句返回了一個批處理。

所以我只是用return (X, y)而不是yield (X,y)它工作。我不確定這是爲什麼。如果有人能夠闡明這一點,這將是很酷的。

編輯: 您需要將生成器傳遞給該函數,而不僅僅調用該函數。類似這樣的:

model.fit_generator(gen, samples_per_epoch=750, 
            nb_epoch=80, 
            validation_data=gen, 
            nb_val_samples=250) 

Keras會在調用數據的同時調用您的__next__函數。

+0

我試過了,但沒有任何變化。請檢查我是否正確理解你(我的代碼與'return'語句是在主文章的末尾)。謝謝! – Ladislao

+0

應該像這樣傳遞發電機時工作。如果不是,你可以發佈你的錯誤信息? –

+0

是的,我將我的生成器傳遞到'fit_generator'函數中。沒有例外或錯誤。問題是在培訓過程中損失函數的價值沒有變化(我已經將日誌添加到主要職位)。它看起來像網絡不刷新它的權重。這在模型中不是一個錯誤,因爲'fit'函數(使用數組而不是生成器)可以正常工作。 – Ladislao

0

在'生成'方法中,有一個while語句。

def generate(self, begin, end): 
    while(1): # this 
     for count in range(begin, end): 
      try: 
       # something 
       yield(...) 

      except FileNotFoundError: 
       continue 

我覺得不需要這種說法,所以

def generate(self, begin, end): 
    for count in range(begin, end): 
     try: 
      # something 
      yield(...) 

     except FileNotFoundError: 
      continue 
+0

它提出了一個例外: '文件 「/usr/local/lib/python3.4/dist-packages/keras/engine/training.py」,線1528,在fit_generator STR(generator_output)) ValueError異常:輸出生成器應該是一個元組(x,y,sample_weight)或(x,y)。發現:沒有' 發電機必須是無止境的,因爲它必須在下一個時代返回相同批次的數據 – Ladislao

+0

我明白了,對不起,我感激不盡。 – hmm

相關問題