2017-02-24 67 views
0

我對Tensorflow非常非常全新,需要編寫一個腳本來測試從檢查點文件恢復的模型上的單個示例。測試恢復的張量流模型的一般方法

我想知道是否有一種通用的方法來爲恢復的模型構建測試函數,而無需知道模型的所有細節。

此外,在下面的代碼的最後一部分,這看起來像我朝着正確的方向?如果是這樣,那麼如何在不瞭解模型細節的情況下構建「y」?

import tensorflow as tf 
from tensorflow.python import pywrap_tensorflow 
import numpy as np 
from fuel.datasets.hdf5 import H5PYDataset 

ckpt_path='ckt/mnist/mnist_2017_02_23_17_22_50/mnist_2017_02_23_17_22_50_5000.ckpt' 

############################## 
#### Initialize Variables #### 
############################## 

reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path) 
var_to_shape_map = reader.get_variable_to_shape_map() 
var=[0]*len(var_to_shape_map) 
i=0 
for key in var_to_shape_map: 
    var[i] = tf.Variable(reader.get_tensor(key), name=key) 
    #print("tensor_name: ", key) 
    #print(reader.get_tensor(key)) 
    i=i+1 
initialize=tf.global_variables_initializer() 

############################### 
####### Restore Model ######### 
############################### 

saver = tf.train.Saver() 
sess = tf.Session() 
saver.restore(sess, ckpt_path) 

############################### 
##### Get Example to Test ##### 
############################### 

test_set = H5PYDataset('../CNN3D/data/bmnist.hdf5', which_sets=('test',)) 
handle = test_set.open() 
for i in range(0,100): 
    test_data = test_set.get_data(handle, slice(i, i+1)) 
    if test_data[1][0][0]==8: 
     model_idx=i 
test_data = test_set.get_data(handle, slice(model_idx,model_idx+1)) 
data = tf.Variable(np.asarray(test_data[0][0][0]), name='data') 

############################### 
######## Test Example ######### 
############################### 

x = tf.placeholder(tf.float32,shape=[28,28]) 
y = ??? 
sess.run(initialize) 
result=sess.run(y, feed_dict={x: data}) 
print result 

回答

0

Estimator類有一個方便的實用程序,從而,如果模型纏的估計,裝載並從中預測是很容易。總體而言,如果沒有某種協調,這將很難。