2017-04-10

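How to save Keras training history for cross-validation (in a loop)?

For cross-validation, how can I save the training history for each of the different training and validation splits? I thought opening the pickle file in append mode would work, but it doesn't. If possible, could you also show me how to save all of the models? Right now I can only save the last trained model with model.save(file).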

import pickle
import numpy as np

historyfile = 'history.pickle'
f = open(historyfile, 'w')
f.close()
ind = 0
save = {}
cvscores = []  # per-fold scores, appended to below
for train, test in kfold.split(input,output): 
    ind = ind+1 
    #create model 
    model = model_FCN() 
    # fit the model 
    history = model.fit(input[list(train)], output[list(train)], batch_size = 16, epochs = 100, verbose =1, validation_data =(input[list(test)],output[list(test)])) 
    #save to file 
    try:
        f = open(historyfile, 'a')  # appending mode??
        save['cv' + str(ind)] = history.history
        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
        f.close()
    except Exception as e:
        print('Unable to save data to', historyfile, ':', e)

    scores = model.evaluate(MR_patch[list(test)], CT_patch[list(test)], verbose=0) 
    print("%s: %.2f" % (model.metrics_names[1], scores[1])) 
    cvscores.append(scores[1]) 
    print("cross validation stage: " + str(ind)) 

print("%.2f (+/- %.2f)" % (np.mean(cvscores), np.std(cvscores))) 
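Appending with pickle does work at the file level: every pickle.dump call on a file opened in append mode writes a separate record, but a single pickle.load reads back only the first record. (The code above also dumps the whole cumulative save dictionary on every iteration, so each record repeats the earlier folds.) A minimal sketch, assuming Python 3 and binary modes ('ab'/'rb'), of writing one record per fold and reading all of them back; the helper names append_record and load_all_records are hypothetical, and the loss values are placeholders, not real results.

```python
import os
import pickle
import tempfile


def append_record(path, record):
    # 'ab' appends a new, independent pickle record to the end of the file
    with open(path, 'ab') as f:
        pickle.dump(record, f, pickle.HIGHEST_PROTOCOL)


def load_all_records(path):
    # One pickle.load() returns only the first record, so keep loading
    # until the end of the file raises EOFError.
    records = []
    with open(path, 'rb') as f:
        while True:
            try:
                records.append(pickle.load(f))
            except EOFError:
                break
    return records


# Usage with hypothetical placeholder histories (one record per fold):
path = os.path.join(tempfile.mkdtemp(), 'history.pickle')
for ind in (1, 2, 3):
    append_record(path, {'cv' + str(ind): {'loss': [0.5, 0.4]}})

records = load_all_records(path)
print(len(records))  # → 3
```

The per-fold records can be merged into one dictionary afterwards with dict.update if needed.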

Answer


To save the model after every epoch of training and validation, you can use a callback.

For example:

from keras.callbacks import ModelCheckpoint
import os

output_directory = ''  # path to the output directory
# Inside the cross-validation loop: include the fold index in the file name
# so that checkpoints from different folds do not overwrite each other.
model_checkpoint = ModelCheckpoint(
    os.path.join(output_directory,
                 'fold{}_'.format(ind) + 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'))
model.fit(input[list(train)],
          output[list(train)],
          batch_size=16,
          epochs=100,
          verbose=1,
          validation_data=(input[list(test)], output[list(test)]),
          callbacks=[model_checkpoint])
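As far as I know, ModelCheckpoint fills the {…} placeholders in its file path with Python's str.format, using the epoch number and the logged metrics such as val_loss. The sketch below only simulates that substitution to show which file names result; the helper fold_checkpoint_template and the values epoch=3, val_loss=0.12345 are hypothetical.

```python
import os


def fold_checkpoint_template(output_directory, fold):
    # Hypothetical helper: prefix the Keras-style template with the fold index.
    # {epoch} and {val_loss} are left unfilled for the callback to substitute.
    return os.path.join(output_directory,
                        'fold{}_'.format(fold) + 'weights.{epoch:02d}-{val_loss:.2f}.hdf5')


template = fold_checkpoint_template('', 2)
# Simulate the substitution with made-up values:
print(template.format(epoch=3, val_loss=0.12345))  # → fold2_weights.03-0.12.hdf5
```

Prefixing the fold index keeps each fold's checkpoints distinguishable even when two folds reach the same epoch and a similar validation loss.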

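After each epoch, your model will be saved to a file. You can find more information about this callback in the documentation (https://keras.io/callbacks/).

If you want to save the model trained on each fold, you can simply add model.save(file) to your for loop: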

# Inside the for loop, after training on fold `ind`:
model.fit(input[list(train)],
          output[list(train)],
          batch_size=16,
          epochs=100,
          verbose=1,
          validation_data=(input[list(test)], output[list(test)]))
model.save(os.path.join(output_directory, 'fold_{}_model.hdf5'.format(ind)))

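To save the history: you can save it once, instead of appending it to the file on every iteration of the loop. After the for loop you should have a dictionary whose keys are the fold labels and whose values are each fold's history; save that dictionary like this: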

f = open(historyfile, 'wb') 
pickle.dump(save, f, pickle.HIGHEST_PROTOCOL) 
f.close() 
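A minimal round-trip sketch of that approach, assuming Python 3: the whole dictionary is written once in binary mode and read back with a single pickle.load. The loss values below are hypothetical placeholders standing in for each fold's history.history, not real training results; the with statements close the file even if an error occurs.

```python
import os
import pickle
import tempfile

# Hypothetical placeholders for history.history of each fold
# (in the real loop: save['cv' + str(ind)] = history.history).
save = {
    'cv1': {'loss': [0.9, 0.7], 'val_loss': [1.0, 0.8]},
    'cv2': {'loss': [0.8, 0.6], 'val_loss': [0.9, 0.85]},
}

historyfile = os.path.join(tempfile.mkdtemp(), 'history.pickle')

# Write the whole dictionary once, after the loop ('wb': pickle needs binary mode)
with open(historyfile, 'wb') as f:
    pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)

# Read it back later, e.g. in another session
with open(historyfile, 'rb') as f:
    loaded = pickle.load(f)

for fold, hist in sorted(loaded.items()):
    # Validation loss at the last epoch of each fold
    print(fold, hist['val_loss'][-1])
```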

Thank you very much for your patient answer!