2017-08-25 29 views
2

無法取得張量的值:在運行 MNIST 數據集時,我想知道模型在每個訓練 batch 上的輸出究竟是什麼。以下是我的代碼(我還沒有添加優化器和損失函數):

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 

# Network / training hyper-parameters.
INPUT_NODE = 28 * 28  # pixels per flattened input image (784)
OUTPUT_NODE = 10      # number of output classes (digits 0 through 9)
LAYER_NODE = 500      # width of the single hidden layer
BATCH_SIZE = 100      # training examples fed per step
TRAINING_STEPS = 10   # number of batches processed by train()

def inference(input_tensor, avg_class, weight1, biase1, weight2, biase2):
    """Forward pass of a one-hidden-layer ReLU network.

    Args:
        input_tensor: input batch; presumably shaped (batch, INPUT_NODE) —
            it is right-multiplied by ``weight1``.
        avg_class: ``None`` to use the parameters as given, otherwise an
            object whose ``average(var)`` method supplies the value to use in
            place of each parameter (presumably a
            ``tf.train.ExponentialMovingAverage`` — confirm at call site).
        weight1, biase1: hidden-layer weight matrix and bias.
        weight2, biase2: output-layer weight matrix and bias.

    Returns:
        The output-layer pre-activation values (no softmax is applied here).
    """
    # Compare to None by identity (PEP 8); `==` may be overridden by avg_class.
    if avg_class is None:
        w1, b1, w2, b2 = weight1, biase1, weight2, biase2
    else:
        # Same call order as before: weight1, biase1, weight2, biase2.
        w1, b1, w2, b2 = (
            avg_class.average(param)
            for param in (weight1, biase1, weight2, biase2)
        )

    layer = tf.nn.relu(tf.matmul(input_tensor, w1) + b1)
    return tf.matmul(layer, w2) + b2


def train(mnist):
    """Build the two-layer network and print its outputs for a few training batches.

    No loss or optimizer is attached yet, so nothing is learned: each of the
    TRAINING_STEPS iterations only runs a forward pass on one batch of
    BATCH_SIZE training examples and prints the resulting values.

    Args:
        mnist: dataset object returned by ``input_data.read_data_sets``
            (loaded with one-hot labels by ``main``).
    """
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    weight1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER_NODE], stddev=0.1))
    biase1 = tf.Variable(tf.constant(0.1, shape=[LAYER_NODE]))
    weight2 = tf.Variable(tf.truncated_normal([LAYER_NODE, OUTPUT_NODE], stddev=0.1))
    biase2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

    out = inference(x, None, weight1, biase1, weight2, biase2)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # Feeds prepared for later validation/test evaluation; not used yet.
        validate_feed = {x: mnist.validation.images, y: mnist.validation.labels}
        test_feed = {x: mnist.test.images, y: mnist.test.labels}

        for _ in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            # `out` is a symbolic tf.Tensor: printing it only shows its name,
            # shape and dtype. Its value exists only as the return value of
            # sess.run, which needs a feed for placeholder `x` (which is also
            # why a bare `out.eval()` without a feed_dict fails).
            batch_out = sess.run(out, feed_dict={x: xs, y: ys})
            print(batch_out)

def main(arg=None):
    """Load MNIST with one-hot labels and hand it to ``train``.

    ``arg`` is accepted so ``tf.app.run`` can call this function; it is unused.
    """
    # NOTE(review): hard-coded, machine-specific dataset directory.
    data_dir = "/home/vincent/Tensorflow/MNIST/data/"
    train(input_data.read_data_sets(data_dir, one_hot=True))

if __name__ == "__main__":
    # Hand control to TensorFlow's app runner, which invokes main().
    tf.app.run()

我嘗試用 print(out) 打印,得到的卻是:

Tensor("add_1:0", shape=(?, 10), dtype=float32)

如果我想知道 out 的值,該怎麼做?我嘗試過 print(out.eval()),但它報錯了。

回答

2

out 是一個張量(Tensor)對象,直接 print 只會顯示它的名稱、形狀和類型;而 sess.run 返回的實際數值沒有被保存下來。如果你想獲得它的值,請把下面的前兩行替換為後兩行:

sess.run(out, feed_dict= {x:xs, y:ys}) 
print(out) 

res_out=sess.run(out, feed_dict= {x:xs, y:ys}) 
print(res_out)