2017-04-30 138 views
-1

我正在使用tf.train.Saversaverestore保存和恢復TensorFlow模型。在恢復過程中,我正在加載新的輸入數據。該restore方法拋出這個錯誤:TensorFlow變量名稱 - 保存/恢復中的分配錯誤

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [1334,3] rhs shape= [1246,3] [[Node: save/Assign_6 = Assign[T=DT_FLOAT, _class=["loc:@Variable_2"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_2, save/RestoreV2_6)]]

這似乎是說,問題是出在Variable_2,但一個人如何確定哪些變量的代碼對應於Variable_2

回答

-1

當您創建一個新的變量得到它是一個獨特的名字。 Saver.restore在檢查點查看同名。如果你需要一些初始化的變量來自不同的關卡有不同的名稱,請看看tf.contrib.framework.init_from_checkpoint

+0

謝謝,但我不太以下錯誤。我使用的是保存檢查點來加載檢查點的相同代碼;在保存和恢復之間沒有創建新的變量。 –

0
  • 如果要恢復的模式,這樣就前饋,則該模型的形狀和型號杜彥武應該是一樣的,當你保存它
  • 所以上面的錯誤是,當你正在恢復的模型中的一個說其中保存有張形狀[1246,3],但您要指派給一個張量,其形狀爲[1334,3]
  • 明確知道哪些變量是指的名字,你可以指定唯一的名稱張量,例如a = tf.placeholder("float", [3, 3], name="tensor_a")
  • 所以,現在在恢復模式,你知道你的模型與NAME =「tensor_a」,這是3倍形狀的圖的張量3
  • 快速教程在代碼:

    # Create some variables. 
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer) 
    v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer) 
    
    inc_v1 = v1.assign(v1+1) 
    dec_v2 = v2.assign(v2-1) 
    
    # Add an op to initialize the variables. 
    init_op = tf.global_variables_initializer() 
    
    # Add ops to save and restore all the variables. 
    saver = tf.train.Saver() 
    
    # Later, launch the model, initialize the variables, do some work, and save the 
    # variables to disk. 
    with tf.Session() as sess: 
        sess.run(init_op) 
        # Do some work with the model. 
        inc_v1.op.run() 
        dec_v2.op.run() 
        # Save the variables to disk. 
        save_path = saver.save(sess, "/tmp/model.ckpt") 
        print("Model saved in file: %s" % save_path) 
    
    tf.reset_default_graph() 
    
    # Create some variables. 
    d1 = tf.get_variable("v1", shape=[3]) 
    d2 = tf.get_variable("v2", shape=[5]) 
    
    # Add ops to save and restore all the variables. 
    saver = tf.train.Saver() 
    
    # Later, launch the model, use the saver to restore variables from disk, and 
    # do some work with the model. 
    with tf.Session() as sess: 
        # Restore variables from disk. 
        saver.restore(sess, "/tmp/model.ckpt") 
        print("Model restored.") 
        # Check the values of the variables 
        print("v1 : %s" % d1.eval()) 
        print("v2 : %s" % d2.eval()) 
    
  • 如果你在上面的代碼D1注意到且v1具有相同的形狀,現在如果你改變任何國稅發可變形狀會扔給你一個錯誤,它類似於你越來越