2017-10-12 57 views
0

我想用scikit-learn的cross_val_score()函數對我的Keras神經網絡進行交叉驗證。如何在scikit-learn的cross_val_score()中每次摺疊後運行函數?

問題是,在每次摺疊後不僅結果被記住,而且整個Keras模型。所以我想在每次摺疊後用K.clear_session()來清除這個模型。但這只是上下文的細節。

我的主要問題是:如何在scikit-learn的cross_val_score()每次摺疊後運行自定義函數?換句話說:可以運行在每次摺疊後應該運行的回調?或者還有其他解決方法?

回答

0

您可以創建一個自定義回調函數,並重新編寫此回調函數的on_train_end(self,logs = {})方法。這種新方法將在每個培訓步驟結束時完成。類似的東西:

class CustomCall(Callback): 

    def __init__(self): 
     super(CustomCall, self).__init__() 

    def on_epoch_begin(self, epoch, logs={}): 
     return 

    def on_epoch_end(self, epoch, logs={}): 
     return 

    def on_batch_begin(self, batch, logs={}): 
     return 

    def on_train_end(self, logs={}): 
     # Stuff here 
     print('\n Delete previous trained model : ') 
     K.clear_session() 
     return 
+0

不幸的是,問題是,K.clear_session()必須在評估模型後調用,而不是在cross_val_score()內部訓練之後調用。所以我必須在交叉摺疊結束時調用K.clear_session(),而不是在Keras訓練結束時。 –