2016-04-26 84 views
4

我使用Keras來預測時間序列。作爲標準,我使用20個時代。我想知道我的神經網絡爲20個時代的每一個預測了什麼。Python/Keras - 訪問ModelCheckpoint回調

通過使用model.predict我得到最後的預測。不過,我希望所有的預測,或者至少最後10個(有可接受的錯誤級別)。

要訪問我正在嘗試從Keras的ModelCheckpoint函數,但我有麻煩後訪問它。我使用下面的代碼:

model=Sequential() 

model.add(GRU(input_dim=col,init='uniform',output_dim=20)) 
model.add(Dense(10)) 
model.add(Dense(5)) 
model.add(Activation("softmax")) 
model.add(Dense(1)) 

model.compile(loss="mae", optimizer="RMSprop") 

checkpoint=ModelCheckpoint(filepath='/Users/Alex/checkpoint.hdf5') 

model.fit(X=predictor_train, y=target_train, nb_epoch=20, batch_size=batch,validation_split=0.1) #best validation split at 0.1 
model.evaluate(X=predictor_train, y=target_train,batch_size=batch,show_accuracy=True) 

print checkpoint 

客觀地說,我的問題是:

  • 我預計運行代碼後,我會找到一個命名的文件夾/用戶裏面checkpoint.hdf5文件/亞歷克斯,但我沒有。我錯過了什麼?

  • 當我打印checkpoint出我得到的是一個keras.callbacks.ModelCheckpoint object at 0x117471290。有沒有辦法打印我想要的東西?代碼如何看起來像?

你的幫助是非常讚賞:)

回答

8

有兩個問題在此代碼:

  • 您還沒有傳遞迴調模型的擬合方法。這是通過關鍵字參數「callbacks」完成的。
  • 的文件路徑應包含佔位符(如 「{劃時代:02D} - {val_loss:.2f}」。是與str.format由Keras以每個曆元保存到不同的文件中使用

所以,正確的版本應該是這樣的:

checkpoint = ModelCheckpoint(filepath='/Users/Alex/checkpoint-{epoch:02d}-{val_loss:.2f}.hdf5') 

model.fit(X=predictor_train, y=target_train, nb_epoch=20, 
     batch_size=batch,validation_split=0.1, callbacks=[checkpoint]) 

您也可以在分配給該關鍵字的列表中添加其他類型的回調

不幸的是,回調對象不存儲歷史信息,以便它不能從我身上恢復噸。

+0

有沒有辦法讓這個文件在CSV或TXT? hdf5是很難與... – abutremutante

+0

@abutremutante不,和HDF5是非常容易使用h5py,但爲什麼你需要使用它?您可以使用load_weights將權重加載到模型中 –