2017-04-02 50 views
1

我正在使用TF.LEARN和mnist數據。我以0.96的準確度訓練了我的神經網絡,但現在我不太確定如何預測一個值。獲取有關預測MNIST數據集的奇怪值

這裏是我的代碼..

#getting mnist data to a zip in the computer. 
mnist.SOURCE_URL = 'https://web.archive.org/web/20160117040036/http://yann.lecun.com/exdb/mnist/' 
trainX, trainY, testX, testY = mnist.load_data(one_hot=True) 


# Define the neural network 
def build_model(): 
    # This resets all parameters and variables 
    tf.reset_default_graph() 
    net = tflearn.input_data([None, 784]) 
    net = tflearn.fully_connected(net, 100, activation='ReLU') 
    net = tflearn.fully_connected(net, 10, activation='softmax') 
    net = tflearn.regression(net, optimizer='sgd', learning_rate=0.1, loss='categorical_crossentropy') 
    # This model assumes that your network is named "net"  
    model = tflearn.DNN(net) 
    return model 

# Build the model 
model = build_model() 

model.fit(trainX, trainY, validation_set=0.1, show_metric=True, batch_size=100, n_epoch=8) 

#Here is the problem 
#lets say I want to predict what my neural network will reply back if I put back the send value from my trainX 

the value of trainX[2] is 4

pred = model.predict([trainX[2]]) 
print(pred) 
#What I get is 
[[2.6109733880730346e-05, 4.549271125142695e-06, 1.8098366126650944e-05, 0.003199575003236532, 0.20630565285682678, 0.0003870908112730831, 4.902480941382237e-05, 0.006617342587560415, 0.018498118966817856, 0.764894425868988]] 

我要的是 - > 4

的問題是,我不知道如何使用此功能預測並放入trainX值以獲得預測。

回答

0

張量流的預測給你一個概率輸出。從pred獲得最大概率的標籤就足以獲得網絡的成熟度。

pred = np.argmax(pred, axis=1) 

在這種情況下是不是4,但9

np是numpy的模塊導入爲import numpy as np,但隨時與tf.argmax(pred, 1)來取代它使用tensorflow的argmax來代替。

+0

我做這個 預解碼= model.predict([trainX [5]) 打印(np.argmax(預解碼)) 得到了答案,但謝謝你告訴我關於tf.argmax(預解碼,1) –

+0

我想我不清楚我的問題,我只是想知道如何計算索引號,這基本上是通過使用np.argmax ...對不起,混亂。我非常感謝答案! –

+0

@MasnadNihit如果你只有一個預測,那麼'np.argmax'適合你。如果你有多個,那麼你需要'np.argmax(pred,1)'來同時獲得所有預測的所有索引。 –

0

你得到一個9,這是相當類似於4

什麼model.predict回報是得分,雖然結果數組中的第五值(第五值是4,因爲它開始一個零)得到一個相對較高的分數(0.26秒高) - 你的模型給最後一位數字(9)最高分數0.76。這隻意味着你的分類器在這裏有點不對勁 - 所以你應該考慮使用不同的分類器或者使用超參數。

+0

謝謝你benieve!我找出了問題所在。這裏是我的問題的最終解決方案 pred = model.predict([trainX [5]]) print(np.argmax(pred)) –

+0

@benieve,我想我不清楚我的問題,我只是想知道如何計算索引號,這基本上是通過使用np.argmax ...對不起,造成混亂。我非常感謝答案! –