2017-08-20 75 views
1

更新:我發現下面的代碼在使用tensorflow-cpu時能夠正常工作。使用tensorflow-gpu時,問題僅存在。我怎樣才能使它工作?在Tensorflow中保存模型不能在GPU下工作?

我在我的代碼中找不到問題 - 我試圖保存我的變量,然後重新加載它們,並且它們看起來不會從保存的模型加載。

我會注意到,如果我在同一個python運行中進行保存和加載(沒有過程結束並運行測試腳本),它們會加載。我的問題是,當我訓練模式時,這不起作用 - >保存 - >處理結束 - >用測試標誌再次運行腳本 - >模型沒有錯誤地加載,但結果好像不是。

代碼:

運行#1

# creating LSTM model... 

with tf.Session() as sess: 
    saver = tf.train.Saver() 

    # training... 

    save_path = saver.save(sess, "./saved_models/model.ckpt") 
    print("Model saved in file: %s" % save_path) 

運行#2

# creating the same exact LSTM model... 

with tf.Session() as sess: 
    saver = tf.train.Saver() 

    saver.restore(sess, "./saved_models/model.ckpt") 
    print("Model restored.") 

    # testing... 

如果我跑這兩個片段背靠背,我得到所需的輸出 - 模型訓練預測一個微不足道的序列,並在測試過程中正確預測它。如果我分別運行這兩個片段,則模型會在測試期間預測錯誤的序列。

更新:我被建議嘗試導入MetaGraph,它也不工作。代碼:

運行#1

# creating model... 

tf.add_to_collection('a', net.a) 
# adding nodes ... 
tf.add_to_collection('z', net.z) 

with tf.Session() as sess: 
    saver = tf.train.Saver() 
    # training... 
    save_path = saver.save(sess, "./saved_models/my-model") 
    print("Model saved in file: %s" % save_path) 

運行#2

with tf.Session() as sess: 
    new_saver = tf.train.import_meta_graph('./saved_models/my-model.meta') 
    new_saver.restore(sess, './saved_models/my-model') 

    net.a = tf.get_collection('a')[0] 
    # adding nodes ... 
    net.z = tf.get_collection('z')[0] 

    # testing... 

上面的代碼運行正常 - 但測試集結果發現這並不是培訓後(並再次,如果我跑這兩個片段在同一個Python實例中,它工作正常)。

這應該是相當微不足道的,我只是不能得到它的工作。歡迎任何幫助。具體來說,我並不需要保存整個圖表 - 只是變量(其中一些在LSTM單元內)。

+0

嘗試導入元圖(tf.train.import_meta_graph)。 https://www.tensorflow.org/programmers_guide/meta_graph –

+0

這也不適用於我。在問題中添加了相關的代碼... –

回答

1

我遇到了同樣的問題,我想你使用tf.Variable()吧? 嘗試將其更改爲tf.get_variable()。它對我有用:)

+0

我正在使用get_variable –