2017-08-02 444 views

回答

9

Keras本身並不包含任何手段來導出TensorFlow圖作爲協議緩衝區文件,但可以使用常規TensorFlow事業去做。 Here是博客文章,解釋瞭如何使用包含在TensorFlow,這是「典型」的方式是做該實用程序腳本freeze_graph.py做到這一點。

不過,我個人認爲不得不作出一個檢查點,然後運行一個外部腳本來得到一個模型,而是更願意從我自己的Python代碼做糟,所以我用這樣的功能:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): 
    """ 
    Freezes the state of a session into a pruned computation graph. 

    Creates a new computation graph where variable nodes are replaced by 
    constants taking their current value in the session. The new graph will be 
    pruned so subgraphs that are not necessary to compute the requested 
    outputs are removed. 
    @param session The TensorFlow session to be frozen. 
    @param keep_var_names A list of variable names that should not be frozen, 
          or None to freeze all the variables in the graph. 
    @param output_names Names of the relevant graph outputs. 
    @param clear_devices Remove the device directives from the graph for better portability. 
    @return The frozen graph definition. 
    """ 
    from tensorflow.python.framework.graph_util import convert_variables_to_constants 
    graph = session.graph 
    with graph.as_default(): 
     freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) 
     output_names = output_names or [] 
     output_names += [v.op.name for v in tf.global_variables()] 
     input_graph_def = graph.as_graph_def() 
     if clear_devices: 
      for node in input_graph_def.node: 
       node.device = "" 
     frozen_graph = convert_variables_to_constants(session, input_graph_def, 
                 output_names, freeze_var_names) 
     return frozen_graph 

這在freeze_graph.py實施的啓發。這些參數也與腳本類似。 session是TensorFlow會話對象。 keep_var_names如果你想保持一些變量沒有被凍結(例如,對於狀態模型)時才需要,所以一般不會。 output_names是一個帶有產生所需輸出的操作名稱的列表。 clear_devices只是刪除任何設備指令,使圖形更便攜。因此,對於一個典型的Keras model與一個輸出,你會做這樣的事情:

from keras import backend as K 

# Create, compile and train model... 

frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name]) 

然後您可以將圖形寫入文件像往常一樣tf.train.write_graph

tf.train.write_graph(frozen_graph, "some_directory", "my_model.pb", as_text=False) 
1

的freeze_session方法工作正常。但與保存到檢查點文件相比,使用TensorFlow附帶的freeze_graph工具似乎更簡單,因爲它更易於維護。所有你需要做的是以下兩個步驟:

首先,添加您的Keras代碼model.fit(...)後,培養你的模型:

from keras import backend as K 
import tensorflow as tf 
print(model.output.op.name) 
saver = tf.train.Saver() 
saver.save(K.get_session(), '/tmp/keras_model.ckpt') 

然後cd到您的TensorFlow根目錄下,運行:

python tensorflow/python/tools/freeze_graph.py \ 
--input_meta_graph=/tmp/keras_model.ckpt.meta \ 
--input_checkpoint=/tmp/keras_model.ckpt \ 
--output_graph=/tmp/keras_frozen.pb \ 
--output_node_names="<output_node_name_printed_in_step_1>" \ 
--input_binary=true 
相關問題