我試圖使用遷移學習方法。下面是代碼的快照,其中我的代碼是學習在訓練數據:問題與Tensorflow保存和恢復模型
max_accuracy = 0.0
saver = tf.train.Saver()
for epoch in range(epocs):
shuffledRange = np.random.permutation(n_train)
y_one_hot_train = encode_one_hot(len(classes), Y_input)
y_one_hot_validation = encode_one_hot(len(classes), Y_validation)
shuffledX = X_input[shuffledRange,:]
shuffledY = y_one_hot_train[shuffledRange]
for Xi, Yi in iterate_mini_batches(shuffledX, shuffledY, mini_batch_size):
sess.run(train_step,
feed_dict={bottleneck_tensor: Xi,
ground_truth_tensor: Yi})
# Every so often, print out how well the graph is training.
is_last_step = (i + 1 == FLAGS.how_many_training_steps)
if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
train_accuracy, cross_entropy_value = sess.run(
[evaluation_step, cross_entropy],
feed_dict={bottleneck_tensor: Xi,
ground_truth_tensor: Yi})
validation_accuracy = sess.run(
evaluation_step,
feed_dict={bottleneck_tensor: X_validation,
ground_truth_tensor: y_one_hot_validation})
print('%s: Step %d: Train accuracy = %.1f%%, Cross entropy = %f, Validation accuracy = %.1f%%' %
(datetime.now(), i, train_accuracy * 100, cross_entropy_value, validation_accuracy * 100))
result_tensor = sess.graph.get_tensor_by_name(ensure_name_has_port(FLAGS.final_tensor_name))
probs = sess.run(result_tensor,feed_dict={'pool_3/_reshape:0': Xi[0].reshape(1,2048)})
if validation_accuracy > max_accuracy :
saver.save(sess, 'models/superheroes_model')
max_accuracy = validation_accuracy
print(probs)
i+=1
這裏是我的代碼,我在哪裏加載模型:
def load_model() :
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('models/superheroes_model.meta')
saver.restore(sess,tf.train.latest_checkpoint('models/'))
sess.run(tf.global_variables_initializer())
result_tensor = sess.graph.get_tensor_by_name(ensure_name_has_port(FLAGS.final_tensor_name))
X_feature = features[0].reshape(1,2048)
probs = sess.run(result_tensor,
feed_dict={'pool_3/_reshape:0': X_feature})
print probs
return sess
所以現在對同一數據點我在訓練和測試中獲得完全不同的結果。它不甚密切。在測試過程中,由於我有4個班級,因此我的概率接近25%。但在訓練期間,最高級別的概率是90%。
是否有任何問題,同時保存或恢復的模式?
解決的問題1/4概率的初始值代替。我正在訓練大量的時刻,所以在一些時間過後,可能性會下降。 – neel