2017-07-23 24 views
0

我試圖使用Estimators而不是自行實施訓練循環。我正在使用MNIST數據上的自動編碼器。我有一個training_model_fn函數來構建包含輸入,模型,損失,優化器和摘要的訓練模型。我可以訓練它,一切都很順利,但是當我試圖只加載解碼器部分時 - 它失敗了。從model_dir部分加載tf.contrib.learn.Estimators(僅在自動編碼器設置中加載解碼器權重)

我希望解碼器模型得到encoded矢量作爲輸入,並運行通過網絡的相同解碼部分(與以前學習的權重)在最後生成decoded圖像。

我已經創建了一個與訓練一個分享一些代碼,只創建模型的相關部分的另一個decoded_model_fn功能,但是當我試圖加載Estimator有:

est = tf.contrib.learn.Estimator(model_fn=decoder_model_fn, model_dir=...) 
est.predict(input_fn=...) 

我得到以下錯誤:

...  
NotFoundError: Key ... not found in checkpoint ... 
... 

我認爲Estimator試圖從檢查點加載的所有變量,顯然我的解碼器模型沒有包含所有的人。

有誰知道我可以從存儲的會話中部分加載變量嗎?我期望ignore_unknowns標誌,但找不到任何類似的東西。

任何我應該如何使用Estimator自動編碼器模型的例子?

回答

0

OK ...回答自己的情況下,任何人得到這個在未來:

在我的案件的主要問題是,我錯誤地沒有把解碼器在同一variable_scope作爲大模型。變量沒有以相同的名稱被調用,因此未能加載。

發生在我身上的另一件事情,可能會導致問題 - 當我訓練原始模型時,我給它加上了numpy數組,我確保將其投射到float32,這樣權重被存儲爲float32。當我提供解碼器模型時,我使用了一些虛擬np陣列,它們是float64,因此tensorflow抱怨說它期望的數據比它在檢查點數據中的要多。花了我一些時間來弄清楚爲什麼會發生這種事......

相關問題