2017-10-14 150 views

Unable to restore a pre-trained network in TensorFlow

I am stuck restoring a pre-trained network with TensorFlow....

import tensorflow as tf 
import os 
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 

sess=tf.Session() 
saver = tf.train.import_meta_graph('./model/20170512-110547/model-20170512-110547.meta') 
saver.restore(sess,'./model/20170512-110547/') 

I want to use this pre-trained network for face recognition and then add some layers for transfer learning. (I downloaded the model from here: https://github.com/davidsandberg/facenet)

When I run the code above, I get this error:

WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'model_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
Traceback (most recent call last):
  File "/Users/user/Desktop/desktop/Python/HCR/Transfer_face/test.py", line 7, in <module>
    saver.restore(sess,'./model/20170512-110547/')
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1560, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1124, in _run
    feed_dict_tensor, options, run_metadata)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1321, in _do_run
    options, run_metadata)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Unsuccessful TensorSliceReader constructor: Failed to find any matching files for ./model/20170512-110547/
     [[Node: save/RestoreV2_491 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_491/tensor_names, save/RestoreV2_491/shape_and_slices)]]

Caused by op u'save/RestoreV2_491', defined at:
  File "/Users/user/Desktop/desktop/Python/HCR/Transfer_face/test.py", line 6, in <module>
    saver = tf.train.import_meta_graph('./model/20170512-110547/model-20170512-110547.meta')
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1698, in import_meta_graph
    **kwargs)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.py", line 656, in import_scoped_meta_graph
    producer_op_list=producer_op_list)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 313, in import_graph_def
    op_def=op_def)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
    self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

NotFoundError (see above for traceback): Unsuccessful TensorSliceReader constructor: Failed to find any matching files for ./model/20170512-110547/
     [[Node: save/RestoreV2_491 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_491/tensor_names, save/RestoreV2_491/shape_and_slices)]]

I don't understand why the system cannot find the pre-trained data... The directory structure is as follows:

USER-NO-MacBook-Pro:Transfer_face user$ ls -R
model test.py

./model:
20170512-110547

./model/20170512-110547:
20170512-110547.pb
model-20170512-110547.ckpt-250000.index
model-20170512-110547.ckpt-250000.data-00000-of-00001
model-20170512-110547.meta


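Try using an older version of TensorFlow: "The saved meta_graph is possibly from an older release". The model was built with r0.12. – Maxim

Thanks. I tried versions 0.12 and 1.2.0 (the version listed in the requirements), but it still shows the same error.... –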


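Try passing the full absolute path to the model directory when you call `saver.restore()`, instead of the relative path `'./model/20170512-110547/'`. Older versions of TensorFlow (including 0.12, I think) had a bug where some APIs did not accept relative paths, but this should be fixed in the latest release. – mrry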

Answers


Import the .pb file.

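import tensorflow as tf
from tensorflow.python.framework import tensor_util

# Frozen graph from the question's directory listing
# (adjust the path if your working directory differs)
with tf.gfile.GFile('./model/20170512-110547/20170512-110547.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import into the default graph
tf.import_graph_def(graph_def)

# Print some data: in a frozen graph the weights are stored as Const nodes
wts = [n for n in graph_def.node if n.op == 'Const']

for n in wts:
    print(tensor_util.MakeNdarray(n.attr['value'].tensor))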

Linked questions:

Import a simple Tensorflow frozen_model.pb file and make prediction in C++

get the value weights from .pb file by Tensorflow

Related documentation: GraphDef


You need to pass the checkpoint prefix "./model/20170512-110547/model-20170512-110547.ckpt-250000" to `saver.restore()`, not the folder path.
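A minimal sketch of finding that prefix programmatically, assuming the TensorFlow V2 checkpoint layout shown in the question's listing: `<prefix>.index` plus `<prefix>.data-XXXXX-of-YYYYY` shards. `find_checkpoint_prefix` is a hypothetical helper, not a TensorFlow API. It also returns an absolute path, in line with mrry's comment about relative paths. Note that the listing shows no `checkpoint` state file, so `tf.train.latest_checkpoint(model_dir)` would likely return `None` here, which is why the sketch scans for `.index` files instead.

```python
import glob
import os


def find_checkpoint_prefix(model_dir):
    """Return the absolute prefix of the newest checkpoint in model_dir.

    Hypothetical helper (not part of TensorFlow). It assumes checkpoint files
    named '<name>.ckpt-<step>.index' with matching
    '<name>.ckpt-<step>.data-XXXXX-of-YYYYY' shards; Saver.restore() expects
    '<name>.ckpt-<step>', not the directory that contains it.
    """
    model_dir = os.path.abspath(model_dir)  # avoid relative-path issues
    index_files = glob.glob(os.path.join(model_dir, "*.ckpt-*.index"))
    if not index_files:
        raise IOError("no '*.ckpt-*.index' file found in %s" % model_dir)

    def step(path):
        # '.../model-20170512-110547.ckpt-250000.index' -> 250000
        return int(path[:-len(".index")].rsplit("-", 1)[1])

    newest = max(index_files, key=step)
    return newest[:-len(".index")]  # drop '.index' to get the prefix
```

With the question's layout this should yield an absolute path ending in `model-20170512-110547.ckpt-250000`, which can then be passed as `saver.restore(sess, find_checkpoint_prefix('./model/20170512-110547'))`. The step number is compared numerically so that, for example, `ckpt-90000` is not picked over `ckpt-250000` just because it sorts later as a string.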
