2016-02-21 289 views
3

我做Tensorflow一些培訓,並使用保護拯救整個過程:tensorflow僅恢復變量

# ... define model 

# add a saver 
saver = tf.train.Saver() 

# ... run a session 
    # .... 
    # save the model 
    save_path = saver.save(sess,fileSaver) 

它工作正常,我可以使用完全相同的模型,並呼籲成功地還原整個過程:

saver.restore(sess, importSaverPath) 

現在我想只修改優化,同時保持模型常數的其餘部分(計算圖從優化保持不變分開):

# optimizer used before 
# optimizer = tf.train.AdamOptimizer 
# (learning_rate = learningRate).minimize(costPrediction) 
# the new optimizer I want to use 
optimizer = tf.train.RMSPropOptimizer 
    (learning_rate = learningRate, decay = 0.9, momentum = 0.1, 
    epsilon = 1e-5).minimize(costPrediction) 

我也想繼續從我保存的最後圖狀態(即,我想恢復我的變量的狀態並繼續使用另一個訓練算法)進行訓練。當然我不能使用:

saver.restore 

不再,因爲圖形已經改變。

所以我的問題是:有沒有一種方法可以在整個會話被保存時使用saver.restore命令恢復變量(或者甚至可能在以後使用,只有一部分變量)?我在API文檔和在線上查找了這樣的功能,但找不到任何可以幫助我實現它的示例/足夠詳細的解釋。

+0

考慮在圖中包含兩個優化器。然後你可以選擇在會議中打電話。切換優化器不需要保存/恢復,所以我想我會提到這一點,以確保您瞭解這一點。 – user728291

回答

2

通過將變量列表作爲var_list參數傳遞給構造函數Saver,可以恢復變量的子集。但是,當您更改優化程序時,可能已創建了其他變量(例如,動量累加器)以及與以前的優化程序相關的變量(如果有的話)已從模型中刪除。因此,僅使用舊的Saver對象進行恢復將不起作用,特別是如果使用默認構造函數構造它時,該構造函數使用tf.all_variables作爲var_list參數的參數。您必須在您的模型中創建的變量的子集上構建Saver對象,然後restore才能起作用。請注意,這會使新優化器創建的新變量未初始化,因此您必須顯式初始化它們。

1

我看到了同樣的問題。受凱文曼的回答啓發。我的解決方案是:

  1. 定義您的新圖形(這裏只有新的優化器相關變量與舊圖形不同)。

  2. 使用tf.global_variables()獲取所有變量。這將返回一個var list,我稱之爲g_vars。

  3. 使用tf.contrib.framework.get_variables_by_suffix('some variable filter')獲取所有與優化器相關的變量。過濾器可能是RMSProp或RMSPRrop_ *。這個函數返回一個名爲exclude_vars的var列表。

  4. 獲取g_vars中的變量,但不包含在exclude_vars中。只需[在exclude_vars項目的項目在g_vars如果項目不]使用

    瓦爾=

這些增值經銷商在新的和舊的圖表,你可以從舊的模式,現在恢復常見瓦爾。

+0

init_ckpt = tf.contrib.framework.assign_from_checkpoint_fn(check_point,vars) init_ckpt(sess) –

相關問題