2016-07-28 35 views
0

我有一個帶有四個輸出標籤的tensorflow程序。我訓練了模型,現在正在用它評估單獨的數據。Tensorflow tf.nn.in_top_k錯誤目標[0]超出範圍

的問題是,在我使用的代碼

import tensorflow as tf 

import main 
import Process 
import Input 

eval_dir = "/Users/Zanhuang/Desktop/NNP/model.ckpt-30" 
checkpoint_dir = "/Users/Zanhuang/Desktop/NNP/checkpoint" 


def evaluate(): 
    with tf.Graph().as_default() as g: 
    images, labels = Process.eval_inputs() 
    forward_propgation_results = Process.forward_propagation(images) 
    init_op = tf.initialize_all_variables() 
    saver = tf.train.Saver() 
    top_k_op = tf.nn.in_top_k(forward_propgation_results, labels, 1) 

    with tf.Session(graph=g) as sess: 
    sess.run(init_op) 
    saver.restore(sess, eval_dir) 
    tf.train.start_queue_runners(sess=sess) 
    print(sess.run(top_k_op)) 

def main(argv=None): 
    evaluate() 

if __name__ == '__main__': 
    tf.app.run() 

總體而言,我只有一個班。

我的錯誤率,在這裏我介紹一個熱矩陣標籤代碼是在這裏:

def error(forward_propagation_results, labels): 
    labels = tf.one_hot(labels, 4) 
    tf.transpose(labels) 
    labels = tf.cast(labels, tf.float32) 
    mean_squared_error = tf.square(tf.sub(labels, forward_propagation_results)) 
    cost = tf.reduce_mean(mean_squared_error) 
    train = tf.train.GradientDescentOptimizer(learning_rate = 0.05).minimize(cost) 
    tf.histogram_summary('accuracy', mean_squared_error) 
    tf.add_to_collection('losses', cost) 

    tf.scalar_summary('LOSS', cost) 

    return train, cost 
+0

這個錯誤表明'forward_propagation_results'(假設是一個大小爲'bxc'的矩陣)和'labels'中的值(假設是一個長度爲「b」的矢量,所有的值都是< C')。你可以嘗試在失敗的行之前添加「print(forward_propagation_result,sess.run(labels))'語句嗎? – mrry

+0

當然:這是我的結果:(,array([40],dtype = int32)) –

回答

0

的問題是在你的labels張無效數據。從your commentlabels張量是包含單個值的向量:[40]。值40大於forward_propagation_result(即4)中的列數。

tf.nn.in_top_k(predictions, targets, k)運算具有以下行爲:

  • 對於每一行predictions[i, :]
    • result[i]爲真,如果predictions[i, targets[i]]是該行中的k個最大的元素中的一個;否則它是錯誤的。

沒有價值predictions[0, 40],因爲(您的評論節目)這樣的說法是1 x 4矩陣。因此TensorFlow會給你一個out of range錯誤。這表明您的評估數據是錯誤的,或者您應該使用不同的評估功能。

+0

你是什麼意思?它的格式與CIFAR 10完全相同。另外,您的意思是不同的評估功能。我只知道k_top –

+0

您正在饋送到網絡的評估示例假設類的標籤爲40.您的網絡爲每個示例生成4個輸出,其中'tf.nn.in_top_k()'解釋爲分數對於標籤0,1,2和3.評估數據是錯誤的(例如,如果您只希望有4個不同的標籤,可能會損壞),或者網絡本身是錯誤的(它應該產生至少41個輸出每個例子使標籤40有效)。 – mrry

+0

我看不到我的代碼在網絡中出錯的地方,但可能是數據。網絡訓練沒有問題,並減少了正確的方法。你說過使用不同的評估函數。你會怎麼做? –