有點晚了,但我仍會試着回答。
- 是
tf.train.export_meta_graph
和tf.train.import_meta_graph
正確的工具呢?
我會這麼說。請注意,當您通過tf.train.Saver
保存模型時,將隱式調用tf.train.export_meta_graph
。要點是:
# create the model
...
saver = tf.train.Saver()
with tf.Session() as sess:
...
# save graph and variables
# if you are using global_step, the saver will automatically keep the n=5 latest checkpoints
saver.save(sess, save_path, global_step)
然後恢復:
save_path = ...
latest_checkpoint = tf.train.latest_checkpoint(save_path)
saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
注意,不是調用tf.train.import_meta_graph
你也可以叫你用來創建擺在首位模型的原始代碼段。但是,我認爲使用import_meta_graph
這種方式更優雅,即使您無法訪問創建它的代碼,也可以恢復模型。
- 我需要包括在
collection_list
,我想包含在快照中所有變量的名字嗎? (最簡單的對我來說將是包括一切。)
號但是問題是有點混亂:在export_meta_graph
的collection_list
並不意味着是一個變量列表,但集合(即列表字符串鍵)。
集合非常方便,例如,所有訓練的變量都通過調用自動包含在集合tf.GraphKeys.TRAINABLE_VARIABLES
,你可以得到:
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
或
tf.trainable_variables() # defaults to the default graph
如果恢復後,您需要訪問比你訓練的變量等中介結果,我發現它將它們放入定製集合中非常方便,如下所示:
...
input_ = tf.placeholder(tf.float32, shape=[64, 64])
....
tf.add_to_collection('my_custom_collection', input_)
此集合會自動存儲(除非您特別指定不要通過在的參數collection_list
中省略此集合的名稱)。所以,你可以恢復後只檢索input_
佔位符如下:
...
with tf.Session() as sess:
saver.restore(sess, latest_checkpoint)
input_ = tf.get_collection_ref('my_custom_collection')[0]
- 的Tensorflow文檔說:「如果沒有指定
collection_list
,所有集合在模型將被導出「這是否意味着如果我在collection_list
中沒有指定變量,那麼模型中的所有變量都會導出,因爲它們在默認集合中?
是的。再次注意collection_list
是一個收集而不是變量列表的微妙細節。事實上,如果您只希望保存某些變量,則可以在構造對象時指定這些變量。從tf.train.Saver.__init__
的文檔:
"""Creates a `Saver`.
The constructor adds ops to save and restore variables.
`var_list` specifies the variables that will be saved and restored. It can
be passed as a `dict` or a list:
* A `dict` of names to variables: The keys are the names that will be
used to save or restore the variables in the checkpoint files.
* A list of variables: The variables will be keyed with their op name in
the checkpoint files.
- 的Tensorflow文檔說:「爲了使Python對象被序列化和從MetaGraphDef,所述Python類必須實現
to_proto()
和from_proto()
方法,以及使用register_proto_function與系統 註冊。「這是否意味着to_proto()
和 from_proto()
必須阿迪僅限於我已定義的類和 想要導出?如果我只使用標準的Python數據類型(int, float,list,dict),那麼這是不是相關的?
我從來沒有使用這個功能,但我會說你的解釋是正確的。