2017-10-11 20 views
0

我試圖從保存的模型output_graph.pb中提取所有權重/偏差。output_graph.pb上的tf.GraphKeys.TRAINABLE_VARIABLES導致空列表

我讀的模式:

def create_graph(modelFullPath): 
    """Creates a graph from saved GraphDef file and returns a saver.""" 
    # Creates graph from saved graph_def.pb. 
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f: 
     graph_def = tf.GraphDef() 
     graph_def.ParseFromString(f.read()) 
     tf.import_graph_def(graph_def, name='') 

GRAPH_DIR = r'C:\tmp\output_graph.pb' 
create_graph(GRAPH_DIR) 

,並試圖此希望我能提取所有的權重/偏置 每一層內。

with tf.Session() as sess: 
    all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 
    print (len(all_vars)) 

但是,我得到的值爲len。

最終目標是提取權重和偏差並將其保存到文本文件/ np.arrays。

回答

1

tf.import_graph_def()函數沒有足夠的信息來重建tf.GraphKeys.TRAINABLE_VARIABLES集合(爲此,您需要一個MetaGraphDef)。但是,如果output.pb包含「凍結」GraphDef,則所有權重將存儲在圖中的tf.constant()節點中。提取它們,你可以這樣做以下:

create_graph(GRAPH_DIR) 

constant_values = {} 

with tf.Session() as sess: 
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"] 
    for constant_op in constant_ops: 
    constant_values[constant_op.name] = sess.run(constant_op.outputs[0]) 

注意constant_values可能會包含更多的價值不僅僅是權重,所以你可能需要通過op.name或一些其他標準進一步篩選。

相關問題