2017-08-31 57 views
1

由於數據維度對於我的任務來說很大,因此32個樣本將消耗服務器中近9%的內存,其中總可用內存大約爲105G。所以我必須連續調用fit()在循環中。我也想通過連續調用fit()來儘早停止。如何讓Keras模型在不同的合適呼叫中儘早停止

但是,由於在Keras文檔中引入的回調方法僅適用於一個fit()調用。

在這種情況下,我該如何提前停止?

以下是我的代碼片段:

for sen_batch, cls_batch in train_data_gen: 

    sen_batch = np.array(sen_batch).reshape(-1, WORD_LENGTH, 50, 1) 
    cls_batch = np.array(cls_batch) 

    model.fit(x = sen_batch,y = cls_batch) 

    num_iterations += 1 

回答

3
  1. 使用fit_generator:你有發電機 - 你可以使用發電機教育訓練,而不是古典fit。此方法支持Callbacks,因此您可以使用keras.callbacks.EarlyStopping

  2. 當你不能使用fit_generator: 所以 - 首先是 - 你需要使用train_on_batch方法 - 作爲fit調用重置多模態(例如優化狀態)。

    train_on_batch方法返回一個損失值,但它不接受回調。所以你需要自己實施early stopping。你可以做到,例如像這樣:

    from six import next 
    
    patience = 4 
    best_loss = 1e6 
    rounds_without_improvement = 0 
    
    for epoch_nb in range(nb_of_epochs): 
        losses_list = list() 
        for batch in range(nb_of_batches): 
         x, y = next(train_data_gen) 
         losses_list.append(model.train_on_batch(x, y)) 
        mean_loss = sum(losses_list)/len(losses_list) 
    
        if mean_loss < best_loss: 
         best_loss = mean_loss 
         rounds_witout_improvement = 0 
        else: 
         rounds_without_improvement +=1 
    
        if rounds_without_improvement == patience: 
         break