2016-12-08 98 views
3

我使用tensorflow工具編寫神經網絡。 一切工作,現在我想出口我的神經網絡的最終權重,使一個單一的預測方法。 我該怎麼做?使用tensorflow導出神經網絡的權重

+0

https://nathanbrixius.wordpress.com/2016/05/24/checkpointing-and-reusing-tensorflow-models/ – martianwars

回答

2

您將需要在培訓結束時使用tf.train.Saver類保存模型。

在初始化Saver對象時,您需要傳遞一個您希望保存的所有變量的列表。最好的部分是你可以在不同的計算圖中使用這些保存的變量!

通過創建一個Saver對象,

# Assume you want to save 2 variables `v1` and `v2` 
saver = tf.train.Saver([v1, v2]) 

使用tf.Session對象保存您的變量,

saver.save(sess, 'filename'); 

當然,你可以添加像global_step其他詳細信息。

您可以使用restore()函數在未來恢復變量。恢復的變量將自動初始化爲這些值。

+1

是否可以獲取參數的原始數據?我想在另一個平臺上運行tensorflow-trained-model,我該怎麼做? –

+2

你可以使用'sess.run(權重)'得到權重的最終值,並將它們導出爲一個numpy數組,例如 – martianwars

+0

這就是我所需要的。另一個問題:我在網絡中使用了'tf.nn.rnn_cell.LSTMCell',我如何訪問'LSTMCell'對象的權重/偏差? –

0

上面的答案是保存/恢復會話快照的標準方法。但是,如果您想將網絡導出爲單個二進制文件以供其他tensorflow工具進一步使用,則需要執行更多步驟。

首先,freeze the graph。 TF提供相應的工具。我用這樣的:

#!/bin/bash -x 

# The script combines graph definition and trained weights into 
# a single binary protobuf with constant holders for the weights. 
# The resulting graph is suitable for the processing with other tools. 


TF_HOME=~/tensorflow/ 

if [ $# -lt 4 ]; then 
    echo "Usage: $0 graph_def snapshot output_nodes output.pb" 
    exit 0 
fi 

proto=$1 
snapshot=$2 
out_nodes=$3 
out=$4 

$TF_HOME/bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=$proto \ 
    --input_checkpoint=$snapshot \ 
    --output_graph=$out \ 
    --output_node_names=$out_nodes 

做完,你可以optimize it for inference,或使用any other tool