2016-12-15 78 views
2

我在神經網絡的張量流中創建了模型。 我保存了模型並將它恢復到另一個python文件中。張量流和預測中的恢復模型

的代碼如下:

def restoreModel(): 
    prediction = neuralNetworkModel(x) 
    tf_p = tensorFlow.nn.softmax(prediction) 
    temp = np.array([2,1,541,161124,3,3]) 
    temp = np.vstack(temp) 

    with tensorFlow.Session() as sess: 
     new_saver = tensorFlow.train.import_meta_graph('model.ckpt.meta') 
     new_saver.restore(sess, tensorFlow.train.latest_checkpoint('./')) 
     all_vars = tensorFlow.trainable_variables() 

     tensorFlow.initialize_all_variables().run() 
     sess.run(tensorFlow.initialize_all_variables()) 
     predict = sess.run([tf_p], feed_dict={ 
      tensorFlow.transpose(x): temp, 
      y : *** 
     }) 

當我想預測 「TEMP」 變量! X是矢量形狀,我「轉置」它以匹配形狀。 我不明白我需要在feed_dict變量中寫什麼。

回答

2

我回答遲到,但也許它仍然可以有用。 feed_dict用於爲張量流提供您希望佔位符采用的值。 fetchesrun的第一個參數)是您想要的結果列表。的feed_dict和按鍵的fetches元素必須是張量的名字(我沒有嘗試,雖然)或變量,你可以通過

graph = tf.get_default_graph() 
var = graph.get_operation_by_name('name_of_operation').outputs[0] 

也許得到graph.get_tensor_by_name('name_of_operation:0')作品也一樣,我沒有嘗試。

默認情況下,佔位符的名稱只是「Placeholder」,「Placeholder_1」等,按照圖形定義中的創建順序。