2016-10-01 66 views
5

我想訓練我的Tensorflow模型,凍結一個快照,然後以前饋模式(無需進一步培訓)使用新的輸入數據運行它。問題:Tensorflow Metagraph基礎知識

  1. tf.train.export_meta_graphtf.train.import_meta_graph正確的工具嗎?
  2. 我需要在collection_list中包含我想要包含在快照中的所有變量的名稱嗎? (對我來說最簡單的就是包含所有內容)
  3. Tensorflow文檔說:「如果沒有指定collection_list,則將導出模型中的所有集合。」這是否意味着如果我在collection_list中沒有指定變量,那麼模型中的所有變量都會導出,因爲它們處於默認集合中?
  4. Tensorflow文檔說:「爲了使Python對象序列化到MetaGraphDef或從MetaGraphDef序列化,Python類必須實現to_proto()和from_proto()方法,並使用register_proto_function將它們註冊到系統」那意味着to_proto()from_proto()必須僅添加到我已定義並希望導出的類中?如果我只使用標準的Python數據類型(int,float,list,dict),那麼這是不是相關的?

在此先感謝。

回答

2

有點晚了,但我仍會試着回答。

  1. tf.train.export_meta_graphtf.train.import_meta_graph正確的工具呢?

我會這麼說。請注意,當您通過tf.train.Saver保存模型時,將隱式調用tf.train.export_meta_graph。要點是:

# create the model 
... 
saver = tf.train.Saver() 
with tf.Session() as sess: 
    ... 
    # save graph and variables 
    # if you are using global_step, the saver will automatically keep the n=5 latest checkpoints 
    saver.save(sess, save_path, global_step) 

然後恢復:

save_path = ... 
latest_checkpoint = tf.train.latest_checkpoint(save_path) 
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta') 
with tf.Session() as sess: 
    saver.restore(sess, latest_checkpoint) 

注意,不是調用tf.train.import_meta_graph你也可以叫你用來創建擺在首位模型的原始代碼段。但是,我認爲使用import_meta_graph這種方式更優雅,即使您無法訪問創建它的代碼,也可以恢復模型。


  • 我需要包括在collection_list,我想包含在快照中所有變量的名字嗎? (最簡單的對我來說將是包括一切。)
  • 號但是問題是有點混亂:在export_meta_graphcollection_list並不意味着是一個變量列表,但集合(即列表字符串鍵)。

    集合非常方便,例如,所有訓練的變量都通過調用自動包含在集合tf.GraphKeys.TRAINABLE_VARIABLES,你可以得到:

    tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 
    

    tf.trainable_variables() # defaults to the default graph 
    

    如果恢復後,您需要訪問比你訓練的變量等中介結果,我發現它將它們放入定製集合中非常方便,如下所示:

    ... 
    input_ = tf.placeholder(tf.float32, shape=[64, 64]) 
    .... 
    tf.add_to_collection('my_custom_collection', input_) 
    

    此集合會自動存儲(除非您特別指定不要通過在的參數collection_list中省略此集合的名稱)。所以,你可以恢復後只檢索input_佔位符如下:

    ... 
    with tf.Session() as sess: 
        saver.restore(sess, latest_checkpoint) 
        input_ = tf.get_collection_ref('my_custom_collection')[0] 
    

  • 的Tensorflow文檔說:「如果沒有指定collection_list,所有集合在模型將被導出「這是否意味着如果我在collection_list中沒有指定變量,那麼模型中的所有變量都會導出,因爲它們在默認集合中?
  • 是的。再次注意collection_list是一個收集而不是變量列表的微妙細節。事實上,如果您只希望保存某些變量,則可以在構造對象時指定這些變量。從tf.train.Saver.__init__的文檔:

    """Creates a `Saver`. 
    
        The constructor adds ops to save and restore variables. 
    
        `var_list` specifies the variables that will be saved and restored. It can 
        be passed as a `dict` or a list: 
    
        * A `dict` of names to variables: The keys are the names that will be 
         used to save or restore the variables in the checkpoint files. 
        * A list of variables: The variables will be keyed with their op name in 
         the checkpoint files. 
    

  • 的Tensorflow文檔說:「爲了使Python對象被序列化和從MetaGraphDef,所述Python類必須實現 to_proto()from_proto()方法,以及使用register_proto_function與系統 註冊。「這是否意味着to_proto()from_proto()必須阿迪僅限於我已定義的類和 想要導出?如果我只使用標準的Python數據類型(int, float,list,dict),那麼這是不是相關的?
  • 我從來沒有使用這個功能,但我會說你的解釋是正確的。