2017-07-30 45 views
0

我正在嘗試訓練一個簡單的神經網絡,其中需要保存模型,加載新數據集並還原模型。此工作保存還原過程在3或4次迭代後消耗我的所有內存。這是我的代碼的相關部分。 runsess()函數在循環中迭代多次。Tensorflow-恢復模型會消耗後續迭代中的所有內存

num_steps = 50 

with tf.Session(graph=graph) as session: 
    tf.global_variables_initializer().run() 
    print("Initialized") 
    for step in range(num_steps): 
    offset = (step * batch_size) % (train_labels.shape[0] - batch_size) 
    batch_data = train_dataset[offset:(offset + batch_size)] 
    batch_labels = train_labels[offset:(offset + batch_size)] 
    feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels} 

    _, l, predictions = session.run([optimizer, loss, train_prediction], feed_dict=feed_dict) 

    if (step % 10 == 0): 
     print("Minibatch loss at step %d: %f" % (step, l)) 
     print("Minibatch accuracy: %.1f%%" % accuracy(predictions, batch_labels)) 
     print("Validation accuracy: %.1f%%\n" % accuracy(valid_prediction.eval(), valid_labels)) 


    print("Test accuracy: %.1f%%" % accuracy(test_prediction.eval(), test_labels)) 

    # Saving the session 
    saver = tf.train.Saver() 
    save_path = "./checkpoints/model.ckpt" 
    saver.save(session, save_path) 
    print("Model saved in file: %s" % save_path) 
    session.close() 


def runsess(graph,num_steps): 
    with tf.Session(graph=graph) as session: 
    saver = tf.train.import_meta_graph('./checkpoints/model.ckpt.meta') 
    saver.restore(session,tf.train.latest_checkpoint('./checkpoints/')) 
    tf.global_variables_initializer() 
    print("Initialized") 
    tf.get_default_graph() 
    for step in range(num_steps): 
     offset = (step * batch_size) % (train_labels.shape[0] - batch_size) 
     batch_data = train_dataset[offset:(offset + batch_size)] 
     batch_labels = train_labels[offset:(offset + batch_size)] 
     feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels} 

     _, l, predictions = session.run([optimizer, loss, train_prediction], feed_dict=feed_dict) 
     if (step % 10 == 0): 
     print("Minibatch loss at step %d: %f" % (step, l)) 
     print("Minibatch accuracy: %.1f%%" % accuracy(predictions, batch_labels)) 
     print("Validation accuracy: %.1f%%\n" % accuracy(valid_prediction.eval(), valid_labels)) 
    print("Test accuracy: %.1f%%" % accuracy(test_prediction.eval(), test_labels)) 

    saver.save(session, save_path) 
    session.close() 

看來我在runsess()中保存模型時犯了一個錯誤,但我不明白在哪裏以及如何。我該如何解決這個問題?

回答

0

問題是你調用的功能在你的圖形中創建新的操作,從而導致內存消耗。

首先,您應該有一個保護程序,您不能在循環中創建它,每個保存程序將創建分配操作。

saver = tf.train.import_meta_graph('./checkpoints/model.ckpt.meta') 
saver.restore(session,tf.train.latest_checkpoint('./checkpoints/')) 

所有的二,global_variables_initializer創建一個運算(OPS實際上很多)也一樣,但你怎麼稱呼它,甚至不存儲結果。儘管有這個名字,這個函數不會初始化變量 - 它創建了一個操作 - 只是在你的情況下將其刪除。

tf.global_variables_initializer() 

最簡單的方式瞭解其功能做修改您的圖形或不就是總是完成您的圖形一旦你完成建造它。這樣每次你調用一個創建新操作的函數 - 你都會得到一個異常告訴你,你就可以調試它。

tf.get_default_graph().finalize() 
+0

這非常有幫助。完成圖形顯示,每次迭代都會將同一圖形附加到前一個圖形。用Tensorboard製作圖表也顯示了這一點。儘管從循環中刪除保存功能會導致它在上次會話期間失去所有進度。我不清楚如何解決這個問題。 –