0
我試圖從保存的模型output_graph.pb
中提取所有權重/偏差。output_graph.pb上的tf.GraphKeys.TRAINABLE_VARIABLES導致空列表
我讀的模式:
def create_graph(modelFullPath):
"""Creates a graph from saved GraphDef file and returns a saver."""
# Creates graph from saved graph_def.pb.
with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
GRAPH_DIR = r'C:\tmp\output_graph.pb'
create_graph(GRAPH_DIR)
,並試圖此希望我能提取所有的權重/偏置 每一層內。
with tf.Session() as sess:
all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print (len(all_vars))
但是,我得到的值爲len。
最終目標是提取權重和偏差並將其保存到文本文件/ np.arrays。