0
我正在嘗試訓練一個簡單的神經網絡,其中需要保存模型,加載新數據集並還原模型。此工作保存還原過程在3或4次迭代後消耗我的所有內存。這是我的代碼的相關部分。 runsess()函數在循環中迭代多次。Tensorflow-恢復模型會消耗後續迭代中的所有內存
num_steps = 50
with tf.Session(graph=graph) as session:
tf.global_variables_initializer().run()
print("Initialized")
for step in range(num_steps):
offset = (step * batch_size) % (train_labels.shape[0] - batch_size)
batch_data = train_dataset[offset:(offset + batch_size)]
batch_labels = train_labels[offset:(offset + batch_size)]
feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels}
_, l, predictions = session.run([optimizer, loss, train_prediction], feed_dict=feed_dict)
if (step % 10 == 0):
print("Minibatch loss at step %d: %f" % (step, l))
print("Minibatch accuracy: %.1f%%" % accuracy(predictions, batch_labels))
print("Validation accuracy: %.1f%%\n" % accuracy(valid_prediction.eval(), valid_labels))
print("Test accuracy: %.1f%%" % accuracy(test_prediction.eval(), test_labels))
# Saving the session
saver = tf.train.Saver()
save_path = "./checkpoints/model.ckpt"
saver.save(session, save_path)
print("Model saved in file: %s" % save_path)
session.close()
def runsess(graph,num_steps):
with tf.Session(graph=graph) as session:
saver = tf.train.import_meta_graph('./checkpoints/model.ckpt.meta')
saver.restore(session,tf.train.latest_checkpoint('./checkpoints/'))
tf.global_variables_initializer()
print("Initialized")
tf.get_default_graph()
for step in range(num_steps):
offset = (step * batch_size) % (train_labels.shape[0] - batch_size)
batch_data = train_dataset[offset:(offset + batch_size)]
batch_labels = train_labels[offset:(offset + batch_size)]
feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels}
_, l, predictions = session.run([optimizer, loss, train_prediction], feed_dict=feed_dict)
if (step % 10 == 0):
print("Minibatch loss at step %d: %f" % (step, l))
print("Minibatch accuracy: %.1f%%" % accuracy(predictions, batch_labels))
print("Validation accuracy: %.1f%%\n" % accuracy(valid_prediction.eval(), valid_labels))
print("Test accuracy: %.1f%%" % accuracy(test_prediction.eval(), test_labels))
saver.save(session, save_path)
session.close()
看來我在runsess()中保存模型時犯了一個錯誤,但我不明白在哪裏以及如何。我該如何解決這個問題?
這非常有幫助。完成圖形顯示,每次迭代都會將同一圖形附加到前一個圖形。用Tensorboard製作圖表也顯示了這一點。儘管從循環中刪除保存功能會導致它在上次會話期間失去所有進度。我不清楚如何解決這個問題。 –