2017-08-24 24 views
2

在tensorflow從以下6個文件產生的刮擦培訓:Tensorflow:如何.META,。數據和模型的.index文件轉換成一個graph.pb文件

  1. events.out.tfevents。 1503494436.06L7-BRM738
  2. model.ckpt-22480.meta
  3. 關卡
  4. model.ckpt-22480.data-00000-的-00001
  5. model.ckpt-22480.index
  6. graph.pbtxt

我想他們(或僅需要的)轉換成一個文件graph.pb能夠轉移到我的Android應用程序。

我試過腳本freeze_graph.py,但它需要輸入已經輸入的input.pb我沒有。 (我只有前面提到的這6個文件)。如何繼續獲取這一個freezed_graph.pb文件?我看到了幾條線索,但沒有一條爲我工作。

+1

在這裏看到:https://stackoverflow.com/questions/45433231/freezing-a-cnn-tensorflow-model-into-a-pb-file/45437684#45437684 –

+0

你是怎麼得到'graph.pbtxt'的?如果它是你模型的圖形,你可以用'freeze.py'來凍結它。 '.pbtxt'。 – velikodniy

+0

我在完成培訓後在訓練日誌中找到了graph.pbtxt文件。然而,在訓練結束之前,它被挽救了。在以前保存的圖形狀態下檢查它。對於從頭開始的培訓,我使用了腳本:train_image_classifier.py。對於培訓,我使用了我自己的圖片(.jpg),在使用腳本之前,我必須將其轉換爲.tfrecord文件build_image_data.py – Rafal

回答

1

你可以使用這個簡單的腳本來做到這一點。但是您必須指定輸出節點的名稱。

import tensorflow as tf 

meta_path = 'model.ckpt-22480.meta' # Your .meta file 
output_node_names = ['output:0'] # Output nodes 

with tf.Session() as sess: 

    # Restore the graph 
    saver = tf.train.import_meta_graph(meta_path) 

    # Load weights 
    saver.restore(sess,tf.train.latest_checkpoint('.')) 

    # Freeze the graph 
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
     sess, 
     sess.graph_def, 
     output_node_names) 

    # Save the frozen graph 
    with open('output_graph.pb', 'wb') as f: 
     f.write(frozen_graph_def.SerializeToString()) 
+1

有沒有簡單的方法來獲取輸出節點名稱? – rambossa

+0

我正在嘗試做同樣的事情。有沒有辦法找到輸出節點名稱? – blueether

+0

您可以使用[summarize_graph](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms#inspecting-graphs)實用程序。 – velikodniy

2

因爲它可能對別人有幫助,所以我在github上的答案後也回答這個問題;-)。 我想你可以嘗試這樣的事情(在tensorflow /蟒蛇freeze_graph腳本/工具):

python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "

這裏最重要的標誌是--input_binary = false作爲文件graph.pbtxt是文本格式。我認爲它對應於所需的graph.pb,它是二進制格式的等價物。

關於output_node_names,這對我來說真的很讓人困惑,因爲我仍然對這部分有些問題,但是您可以使用tensorflow中的summarize_graph腳本,它可以將pb或pbtxt作爲輸入。

問候,

斯蒂芬

相關問題