我試圖使用Estimators
而不是自行實施訓練循環。我正在使用MNIST數據上的自動編碼器。我有一個training_model_fn
函數來構建包含輸入,模型,損失,優化器和摘要的訓練模型。我可以訓練它,一切都很順利,但是當我試圖只加載解碼器部分時 - 它失敗了。從model_dir部分加載tf.contrib.learn.Estimators(僅在自動編碼器設置中加載解碼器權重)
我希望解碼器模型得到encoded
矢量作爲輸入,並運行通過網絡的相同解碼部分(與以前學習的權重)在最後生成decoded
圖像。
我已經創建了一個與訓練一個分享一些代碼,只創建模型的相關部分的另一個decoded_model_fn
功能,但是當我試圖加載Estimator
有:
est = tf.contrib.learn.Estimator(model_fn=decoder_model_fn, model_dir=...)
est.predict(input_fn=...)
我得到以下錯誤:
...
NotFoundError: Key ... not found in checkpoint ...
...
我認爲Estimator
試圖從檢查點加載的所有變量,顯然我的解碼器模型沒有包含所有的人。
有誰知道我可以從存儲的會話中部分加載變量嗎?我期望ignore_unknowns
標誌,但找不到任何類似的東西。
任何我應該如何使用Estimator
自動編碼器模型的例子?