2016-04-28 82 views
2

我有以下情況:TensorFlow:按名稱訪問變量的內容

我已經建立,訓練並保存了我的網絡。現在,我正試圖恢復網絡並將權重矩陣可視化。

我知道變量的所有名稱,但我沒有分配給變量的python標記傳遞給會話進行評估。我如何檢索變量中的數據?

這裏是我的代碼的情況:

dataset_params = nn_params.mnist_dataset_params 
design = nn_designs.mnist_net_A_design 
## Build Housing Object 
mnist_nn = nn_class.CNN(**dataset_params) 
mnist_nn.build_net(design['design']) 
mnist_nn.__setattr__('saved_path',saved_model) 
mnist_nn_epoch_file = saved_model+'_epochs_completed.txt' 
mnist_nn.__setattr__('epoch_file',mnist_nn_epoch_file) 


# evaluate weight variables 
session = tf.Session() 
saver = tf.train.Saver() 
session.run(tf.initialize_all_variables()) 
saver.restore(session,saved_model) 




session.close() 

我要傳遞給會話,以拉出的權重是什麼? (體重名稱示例是:'conv_w_1')?

回答

6

您可以在此使用tf.get_collection()查找方法來獲取所需的變量做:

weight_var = tf.get_collection(tf.GraphKeys.VARIABLES, "conv_w_1")[0] 

weight_var_value = session.run(weight_var) 
0

或者你可以通過使用功能tf.get_default_graph()的結果get_tensor_by_name:

valua_of_conv_w_1 = session.run(tf.get_default_graph().get_tensor_by_name("conv_w_1:0"))