2016-12-30 47 views
1

考慮下面的代碼:Tensorflow:節能resoring會議 - 多個變量

import tensorflow as tf 

with tf.Session() as sess: 
    var = tf.Variable(42, name='var') 
    sess.run(tf.global_variables_initializer()) 
    tf.train.export_meta_graph('file.meta') 

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('file.meta') 
    print sess.run(var) 

我在說ValueError: At least two variables have the same name: varsaver = tf.train.import_meta_graph('file.meta')得到一個錯誤。

我怎樣才能解決這個問題?無論如何在導入metagraph時覆蓋計算圖?

編輯:

我下面的代碼已經到了:

import tensorflow as tf 

file_name = "./file" 

with tf.Session() as sess: 
    var = tf.Variable(42, name='my_var') 
    sess.run(tf.global_variables_initializer()) 

    saver = tf.train.Saver() 
    saver.save(sess,file_name) 
    saver.export_meta_graph(file_name + '.meta') 

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph(file_name + '.meta') 
    saver.restore(sess, file_name) 
    print(sess.run(var)) 

    # new code that fails: 
    saver = tf.train.Saver() 
    saver.save(sess,file_name) 
    saver.export_meta_graph(file_name + '.meta') 

此打印正確的價值var,但是當我來救圖中的第二次,我得到的相同的原始錯誤:ValueError: At least two variables have the same name: var

+0

嘿@湯姆,你對我的回答滿意嗎? – martianwars

+0

嘿@martianwars,看我的編輯 – Tom

+0

嘿@martianwars編輯包括'reset_default_graph' – Tom

回答

2

在這種情況下,您正在將變量加載到已定義變量的默認圖形中。因此, 您需要在導入之前重置TensorFlow圖。

使用tf.reset_default_graph()執行此操作。在導入之前。查看Exporting and Importing a MetaGraph下的「在默認圖表中導入」部分。

當然,您將不得不使用tf.get_variable()重新定義變量var。試試這個代碼,

import tensorflow as tf 

with tf.Session() as sess: 
    var = tf.Variable(42, name='var') 
    sess.run(tf.global_variables_initializer()) 
    tf.train.export_meta_graph('file.meta') 
tf.reset_default_graph() 
with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('file.meta') 
    var = tf.global_variables()[0] 
    sess.run(tf.initialize_all_variables()) 
    print sess.run(var) 

您的中間代碼是不工作的原因是,tf.get_variable()正在創造它正在隨機初始化一個新的變量。首先確保你首先執行tf.get_variable_scope().reuse_variables()。 看看Understanding tf.get_variable()

不幸的是,使用tf.Variable()創建的變量不能直接與tf.get_variable()重複使用。看看這個comment和這個comment確切地知道爲什麼。因此,如果您希望將來重用該變量,則需要使用tf.get_variable()創建變量。