2017-06-26 50 views
2

我一直在使用tensorflow的LinearClassifier()班培養了邏輯迴歸模型的模型,並設置model_dir參數,它指定的位置在哪裏模型訓練過程中保存檢查站的metagrahps :如何從tensorflow高層API恢復訓練的LinearClassifier並作出預測

# Create temporary directory where metagraphs will evenually be saved 
model_dir = tempfile.mkdtemp() 

logistic_model = tf.contrib.learn.LinearClassifier(
    feature_columns=feature_columns, 
    n_classes=num_labels, model_dir=model_dir) 

我一直在閱讀有關從metagraphs恢復模式,但沒有發現任何有關如何使用高級API創建的模型這樣做。 LinearClassifier()的預測()函數,但我不能找到如何使用已通過關卡元圖恢復模型的實例來運行預測任何文件。我會如何去做這件事?一旦模型被恢復,我的理解是,我有tf.Sess對象,它缺乏所有建在LinearClassifier類的功能,像這樣的工作:

with tf.Session() as sess: 
    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') 
    new_saver.restore(sess, 'my-save-dir/my-model-10000') 
    # Run prediction algorithm... 

如何運行相同的預測所使用的高級API,使一個恢復模型預測算法?有沒有更好的方法來解決這個問題?

感謝您的輸入。

+0

建議的解決方案適合您嗎? –

+0

潛在的,我實現了你的建議的修復,但仍需要確認它的工作原理。現在超級淹沒,當我有機會時會報告回來。謝謝你的幫助。 –

回答

1

LinearClassifier()有「model_dir」 PARAM,如果在指向一個訓練模型將恢復模型。
在訓練過程中,你做:

logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir) 
classifier.fit(X_train, y_train, steps=10) 

在推理,LinearClassifier()將加載從給出的路徑訓練模型,並且不使用fit()方法,但調用predict()方法:

logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir) 
y_pred = classifier.predict(X_test) 
相關問題