我曾嘗試使用下面的代碼加載預訓練模型(Model 1
):如何不重新初始化Tensorflow中的預訓練加載模型?
def load_seq2seq_model(sess):
with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f:
saved_args = pickle.load(f)
# Initialize the model with saved args
model = Model1(saved_args)
#Inititalize Tensorflow saver
saver = tf.train.Saver()
# Checkpoint
ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path)
print('Loading model: ', ckpt.model_checkpoint_path)
# Restore the model at the checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
return model
現在,我要培養從零開始另一個模型(Model 2
)將採取Model 1
的輸出。但爲此,我需要定義一個會話並加載預先訓練的模型並初始化模型tf.initialize_all_variables()
。所以,預先訓練的模型也將被初始化。
任何人都可以告訴我如何正確地訓練Model 2
從預先訓練的模型Model 1
輸出後?
我試圖下面給出 -
with tf.Session() as sess:
# Initialize all the variables of the graph
seq2seq_model = load_seq2seq_model(sess)
sess.run(tf.initialize_all_variables())
.... Rest of the training code goes here....
您是否在嘗試初始化之前導入模型1? – Pop
我不知道確切的程序。我試過了。這也在起作用。但如果有人能告訴我正確的程序,我可以肯定。 –