2017-08-16 128 views
4

我發現了2種在Tensorflow中保存模型的方法:tf.train.Saver()SavedModelBuilder。但是,我無法找到使用型號加載第二種方式後的文檔。如何在tensorflow上加載和使用保存的模型?

注意:我想使用SavedModelBuilder的方式,因爲我在Python中訓練模型,並將使用它在另一種語言(Go)的服務時間,似乎SavedModelBuilder是在這種情況下唯一的方法。

這與tf.train.Saver()(第一種方式)的偉大工程:

model = tf.add(W * x, b, name="finalnode") 

# save 
saver = tf.train.Saver() 
saver.save(sess, "/tmp/model") 

# load 
saver.restore(sess, "/tmp/model") 

# IMPORTANT PART: REALLY USING THE MODEL AFTER LOADING IT 
# I CAN'T FIND AN EQUIVALENT OF THIS PART IN THE OTHER WAY. 

model = graph.get_tensor_by_name("finalnode:0") 
sess.run(model, {x: [5, 6, 7]}) 

tf.saved_model.builder.SavedModelBuilder()Readme定義,但加載與tf.saved_model.loader.load(sess, [], export_dir)模型)後,我不能在節點找回找到文檔(見"finalnode"在上面的代碼)

回答

4

少了什麼是signature

# Saving 
builder = tf.saved_model.builder.SavedModelBuilder(export_dir) 
builder.add_meta_graph_and_variables(sess, ["tag"], signature_def_map= { 
     "model": tf.saved_model.signature_def_utils.predict_signature_def(
      inputs= {"x": x}, 
      outputs= {"finalnode": model}) 
     }) 
builder.save() 

# loading 
with tf.Session(graph=tf.Graph()) as sess: 
    tf.saved_model.loader.load(sess, ["tag"], export_dir) 
    graph = tf.get_default_graph() 
    x = graph.get_tensor_by_name("x:0") 
    model = graph.get_tensor_by_name("finalnode:0") 
    print(sess.run(model, {x: [5, 6, 7, 8]})) 
0

Tensorflow的優選建築物的方式,並使用不同的語言模型是tensorflow serving

現在在您的情況下,您正在使用saver.save來保存模型。這樣可以節省一個meta文件,ckpt文件和一些其他文件來保存權重和網絡信息,訓練步驟等。這是您在訓練時保存的首選方式。

如果您現在已經完成了培訓,您應該使用SavedModelBuilder從您保存的文件中凍結圖表saver.save。該凍結圖包含一個pb文件幷包含所有網絡和權重。

這個冷凍模型應該用於tensorflow serving服務,然後其他語言可以使用該模型使用gRPC協議。

整個程序描述在this優秀教程。

+0

感謝您的回答和鏈接,但不回答這麼多,我的問題... – Thomas

+0

的鏈接*不*有答案的地方後,「最後的一步 - 保存模型「,但只有在您已經知道去哪裏看的時候才能找到它......它可以絕對更簡潔,但也要感謝鏈接和洞察力 –

相關問題