2017-05-24 67 views
0

我正在研究GAN,並決定使用HyperGAN來實現我的算法。它是使用TensorFlow的DCGAN封裝。 HyperGAN使用TF的檢查點方法保存輸出。如何獲得張量流模型的輸出值和輸入值?

後來,我試圖使用運行負載的模型:

import tensorflow as tf 
sess=tf.Session()  
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 
sess.run(tf.global_variables_initializer()) 

然而,由於它的一個GAN,它需要一個輸入潛在向量並輸出圖像。這是使用

out_image = sess.run(last_node, feed_dict(input_node: value)) 

完成,但因爲我裝的模式,我不知道是什麼的最後一個節點的名稱是什麼,輸入節點佔位符的名稱是。如何獲取用於創建圖形的名稱?我嘗試使用TensorBoard進行可視化,但該圖很大,因此卡住了。

回答

0

你應該儘量在圖形中打印張量清單:

with tf.Graph().as_default() as graph: 
.... 

count = 0 
for op in graph.get_operations(): 
    print op.values() 
    count+=1 
    if count == 50: 
     assert False 

爲了看圖形的第一個50個節點,你會看到這樣的事情:

(<tf.Tensor 'import/Placeholder_only:0' shape=<unknown> dtype=float32>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_max:0' shape=() dtype=float32>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_min:0' shape=() dtype=float32>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_const:0' shape=(512,) dtype=quint8>,) 
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53:0' shape=(512,) dtype=float32>,) 

我把計數放在那裏,因爲通常終端會輸出很多張數,以至於初始輸入節點名稱在終端中消失。

最後,乾脆註釋掉線計數使用:

#count = 0 
for op in graph.get_operations(): 
    print op.values() 
    #count+=1 
    #if count == 50: 
    # assert False 

拿到打印出來的最後幾個節點(即你的輸出節點)。

相關問題