2017-10-04 197 views
1

我曾嘗試使用下面的代碼加載預訓練模型(Model 1):如何不重新初始化Tensorflow中的預訓練加載模型?

def load_seq2seq_model(sess): 


    with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f: 
     saved_args = pickle.load(f) 

    # Initialize the model with saved args 
    model = Model1(saved_args) 

    #Inititalize Tensorflow saver 
    saver = tf.train.Saver() 

    # Checkpoint 
    ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path) 
    print('Loading model: ', ckpt.model_checkpoint_path) 

    # Restore the model at the checkpoint 
    saver.restore(sess, ckpt.model_checkpoint_path) 
    return model 

現在,我要培養從零開始另一個模型(Model 2)將採取Model 1的輸出。但爲此,我需要定義一個會話並加載預先訓練的模型並初始化模型tf.initialize_all_variables()。所以,預先訓練的模型也將被初始化。

任何人都可以告訴我如何正確地訓練Model 2從預先訓練的模型Model 1輸出後?

我試圖下面給出 -

with tf.Session() as sess: 
    # Initialize all the variables of the graph 
    seq2seq_model = load_seq2seq_model(sess) 
    sess.run(tf.initialize_all_variables()) 
    .... Rest of the training code goes here.... 
+0

您是否在嘗試初始化之前導入模型1? – Pop

+0

我不知道確切的程序。我試過了。這也在起作用。但如果有人能告訴我正確的程序,我可以肯定。 –

回答

0

正在使用金丹不需要初始化恢復所有變量。因此,不使用tf.initialize_all_variables(),您可以使用tf.variables_initializer(var_list)僅初始化第二個網絡的權重。

要獲得第二網絡的總權重的列表,你可以創建一個變量範圍Model 2網絡:

with tf.variable_scope("model2"): 
    model2 = Model2(...) 

然後使用

model_2_variables_list = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES, 
    scope="model2" 
) 

得到的變量列表Model 2網絡。最後,您可以創建第二個網絡的初始化程序:

init2 = tf.variables_initializer(model_2_variables_list) 

with tf.Session() as sess: 
    # Initialize all the variables of the graph 
    seq2seq_model = load_seq2seq_model(sess) 
    sess.run(init2) 
    .... Rest of the training code goes here.... 
+0

NotFoundError即將到來。似乎,它也試圖加載模型2。 –

+0

改變加載的順序解決了我猜的問題! –

+0

@AvijitDasgupta運行'saver.restore(sess,ckpt.model_checkpoint_path)'時是否顯示NotFoundError?如果是,請確保在初始化第二個網絡之前初始化保存器('saver = tf.train.Saver()')*(在'model = Model1(...)'之後但在'model = Model2 ...)')。否則,保存器將嘗試加載導致NotFoundError的兩個網絡。 – BlueSun

相關問題