
How to load a trained tensorflow model

I am having all kinds of trouble loading a TensorFlow model to test it on some new data. When I trained the model, I used this:

save_model_file = 'my_saved_model' 
saver = tf.train.Saver() 
save_path = saver.save(sess, save_model_file) 

This appears to cause the following files to be created:

my_saved_model.meta 
checkpoint 
my_saved_model.index 
my_saved_model.data-00000-of-00001 

I don't know which of these files I am supposed to pay attention to.
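For reference, the .meta file holds the serialized graph definition, the .index and .data-00000-of-00001 files hold the trained variable values, and checkpoint is a small text file recording the most recent checkpoint prefix; the saver APIs take the common prefix my_saved_model rather than any single file. A quick way to see which variables were actually saved (a minimal sketch, assuming the checkpoint files sit in the current working directory):

import tensorflow as tf

# List the variables stored in the checkpoint and their shapes.
reader = tf.train.NewCheckpointReader('my_saved_model')
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)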

Now that the model is trained, I can't seem to load it or use it without throwing an exception. Here is what I am doing:

def neural_net_data_input(data_shape): 
    theshape=(None,)+tuple(data_shape) 
    return tf.placeholder(tf.float32,shape=theshape,name='x') 

def neural_net_label_input(n_out): 
    return tf.placeholder(tf.float32,shape=(None,n_out),name='one_hot_labels') 

def neural_net_keep_prob_input(): 
    return tf.placeholder(tf.float32,name='keep_prob') 

def do_generate_network(x): 
    # 
    # here is where i generate the network layer by layer. 
    # this code works fine so i am not showing it here 
    # 
    pass 

# 
# Now I want to restore the model 
# 
tf.reset_default_graph() 

input_data_shape=(32,32,1) 
final_num_outputs=43 

graph1 = tf.Graph() 
with graph1.as_default(): 
    x = neural_net_data_input(input_data_shape) 
    one_hot_labels = neural_net_label_input(final_num_outputs) 
    keep_prob=neural_net_keep_prob_input() 
    logits = do_generate_network(x) 
    # Name logits Tensor, so that it can be loaded from disk after training 
    logits = tf.identity(logits, name='logits') 
    # 
    # accuracy: we use this for validation testing 
    # 
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(one_hot_labels, 1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy') 

################################ 
# Evaluate 
################################ 

new_data=myutils.load_pickle_file(SOME_DATA_FILE_NAME) 
new_features=new_data['features'] 
new_one_hot_labels=new_data['labels'] 

print('Evaluating on new data...') 
with tf.Session(graph=graph1) as sess: 
    # Initializing the variables 
    sess.run(tf.global_variables_initializer()) 

    saver.restore(sess,save_model_file) 
    new_acc = sess.run(accuracy, feed_dict={x: new_features, one_hot_labels: new_one_hot_labels, keep_prob: 1.}) 
    print('Testing Accuracy For New Images: {}'.format(new_acc)) 

But when I do this, I get:

TypeError: Cannot interpret feed_dict key as Tensor: The name 'save/Const:0' refers to a Tensor which does not exist. The operation, 'save/Const', does not exist in the graph. 

So then I tried moving the session so that it wraps the graph construction, like this:

################################ 
# Evaluate 
################################ 

print('Evaluating on web data...') 
with tf.Session() as sess: 

    x = neural_net_data_input(input_data_shape) 
    one_hot_labels = neural_net_label_input(final_num_outputs) 
    keep_prob=neural_net_keep_prob_input() 
    logits = do_generate_network(x) 
    # Name logits Tensor, so that it can be loaded from disk after training 
    logits = tf.identity(logits, name='logits') 
    # 
    # accuracy: we use this for validation testing 
    # 
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(one_hot_labels, 1)) 
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy') 

    sess.run(tf.global_variables_initializer()) 

    my_save_dir="/home/carnd/CarND-Traffic-Sign-Classifier-Project" 
    load_model_meta_file=os.path.join(my_save_dir,"my_saved_model.meta") 
    load_model_path=os.path.join(my_save_dir,"my_saved_model") 
    new_saver = tf.train.import_meta_graph(load_model_meta_file) 
    new_saver.restore(sess, load_model_path) 

    web_acc = sess.run(accuracy, feed_dict={x: web_features, one_hot_labels: web_one_hot_labels, keep_prob: 1.}) 
    print('Testing Accuracy For Web Images: {}'.format(web_acc)) 

Now it runs without throwing an error, but the accuracy it prints is 0.02! This data is very similar to my training data, on which my accuracy was around 95%. So it seems I am somehow loading my model incorrectly.

What am I doing wrong?


I am using tensorflow 1.2 – Marc

Answer


Steps to load a trained model:

  1. Load the graph: You can load the graph using tf.train.import_meta_graph(). Example code:

    model_path = "my_saved_model" 
    inference_graph = tf.Graph() 
    with tf.Session(graph=inference_graph) as sess: 
        # Load the graph with the trained states 
        loader = tf.train.import_meta_graph(model_path+'.meta') 
        loader.restore(sess, model_path) 
    
  2. Get the tensors: Fetch the tensors needed for inference with get_tensor_by_name(). So, in your model, make sure you name the tensors you will need, so that you can look them up by name at inference time.

    # Get the tensors by their variable name 
    _accuracy = inference_graph.get_tensor_by_name('accuracy:0') 
    _x = inference_graph.get_tensor_by_name('x:0') 
    _y = inference_graph.get_tensor_by_name('y:0') 
    
  3. Test: this can then be done with the loaded tensors, e.g. sess.run(_accuracy, feed_dict={_x: ... , _y: ...}). A complete sketch combining the three steps is shown below the list.
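Putting the three steps together, a minimal end-to-end sketch could look like the following. It assumes the checkpoint prefix my_saved_model from the question, that the tensors were named x, one_hot_labels, keep_prob and accuracy in the training graph, and that new_features / new_one_hot_labels are already loaded:

    import tensorflow as tf

    model_path = 'my_saved_model'   # common prefix of the .meta/.index/.data files

    inference_graph = tf.Graph()
    with tf.Session(graph=inference_graph) as sess:
        # 1. Load the graph definition and the trained weights.
        loader = tf.train.import_meta_graph(model_path + '.meta')
        loader.restore(sess, model_path)

        # 2. Fetch the tensors by the names they were given at training time.
        _x = inference_graph.get_tensor_by_name('x:0')
        _labels = inference_graph.get_tensor_by_name('one_hot_labels:0')
        _keep_prob = inference_graph.get_tensor_by_name('keep_prob:0')
        _accuracy = inference_graph.get_tensor_by_name('accuracy:0')

        # 3. Test: no graph rebuilding and no global_variables_initializer();
        #    restore() already supplies the trained variable values.
        new_acc = sess.run(_accuracy, feed_dict={_x: new_features,
                                                 _labels: new_one_hot_labels,
                                                 _keep_prob: 1.0})
        print('Testing Accuracy For New Images: {}'.format(new_acc))

Note that the network is not rebuilt and the variables are not re-initialized here: import_meta_graph() recreates the ops and restore() fills in the trained weights. Rebuilding the layers and calling global_variables_initializer() before importing, as in the second attempt in the question, leaves the evaluation running on the freshly built (randomly initialized) copy of the graph while the restored weights go to the imported copy, which would explain an accuracy near chance level (1/43 ≈ 0.023).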
