2017-06-03 211 views
1

我試圖使用遷移學習方法。下面是代碼的快照,其中我的代碼是學習在訓練數據:問題與Tensorflow保存和恢復模型

max_accuracy = 0.0 
    saver = tf.train.Saver() 
    for epoch in range(epocs): 
     shuffledRange = np.random.permutation(n_train) 
     y_one_hot_train = encode_one_hot(len(classes), Y_input) 
     y_one_hot_validation = encode_one_hot(len(classes), Y_validation) 
     shuffledX = X_input[shuffledRange,:] 
     shuffledY = y_one_hot_train[shuffledRange] 
     for Xi, Yi in iterate_mini_batches(shuffledX, shuffledY, mini_batch_size): 
      sess.run(train_step, 
        feed_dict={bottleneck_tensor: Xi, 
           ground_truth_tensor: Yi}) 
      # Every so often, print out how well the graph is training. 
      is_last_step = (i + 1 == FLAGS.how_many_training_steps) 
      if (i % FLAGS.eval_step_interval) == 0 or is_last_step: 
       train_accuracy, cross_entropy_value = sess.run(
        [evaluation_step, cross_entropy], 
        feed_dict={bottleneck_tensor: Xi, 
          ground_truth_tensor: Yi}) 
       validation_accuracy = sess.run(
        evaluation_step, 
        feed_dict={bottleneck_tensor: X_validation, 
          ground_truth_tensor: y_one_hot_validation}) 
       print('%s: Step %d: Train accuracy = %.1f%%, Cross entropy = %f, Validation accuracy = %.1f%%' % 
        (datetime.now(), i, train_accuracy * 100, cross_entropy_value, validation_accuracy * 100)) 
       result_tensor = sess.graph.get_tensor_by_name(ensure_name_has_port(FLAGS.final_tensor_name)) 
       probs = sess.run(result_tensor,feed_dict={'pool_3/_reshape:0': Xi[0].reshape(1,2048)}) 
       if validation_accuracy > max_accuracy : 
        saver.save(sess, 'models/superheroes_model') 
        max_accuracy = validation_accuracy 
        print(probs) 
      i+=1 

這裏是我的代碼,我在哪裏加載模型:

def load_model() : 
    sess=tf.Session()  
    #First let's load meta graph and restore weights 
    saver = tf.train.import_meta_graph('models/superheroes_model.meta') 
    saver.restore(sess,tf.train.latest_checkpoint('models/')) 
    sess.run(tf.global_variables_initializer()) 
    result_tensor = sess.graph.get_tensor_by_name(ensure_name_has_port(FLAGS.final_tensor_name)) 
    X_feature = features[0].reshape(1,2048)   
    probs = sess.run(result_tensor, 
         feed_dict={'pool_3/_reshape:0': X_feature}) 
    print probs 
    return sess 

所以現在對同一數據點我在訓練和測試中獲得完全不同的結果。它不甚密切。在測試過程中,由於我有4個班級,因此我的概率接近25%。但在訓練期間,最高級別的概率是90%。
是否有任何問題,同時保存或恢復的模式?

+0

解決的問題1/4概率的初始值代替。我正在訓練大量的時刻,所以在一些時間過後,可能性會下降。 – neel

回答

2

要小心 - 你調用

saver.restore(sess,tf.train.latest_checkpoint('models/')) 

我以前就做過類似的呼籲後

sess.run(tf.global_variables_initializer()) 

,我認爲重置所有訓練的權重/偏見/等。在恢復的模型中。

如果必須,調用初始化恢復模型之前,如果你需要初始化具體從恢復模型的東西,做獨立。

+0

它確實有幫助,但概率僅從0.25增加到了0.3。我檢查了所有的訓練點。 – neel

+0

我跟着這個保存和恢復模型https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model – neel

+0

你爲什麼不你之前記下該重點型號值保存它(重量/偏差/等),並將它們與恢復後得到的結果進行比較?我之前做過這個練習,對於一些變量作爲一個完整的檢查,對我來說它是可以的,但是誰知道。 –

2

刪除sess.run(tf.global_variables_initializer())在功能load_model,如果你這樣做,你的所有訓練的參數將與該會爲每個類

+0

它的確從0.25到只有0.3的幫助,但概率增大。我檢查了所有的訓練點。 – neel

+0

我跟着這個保存和恢復模型https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model – neel