2017-05-13 131 views
1

我已經在MNIST數據集上訓練了具有92%準確性的線性分類器。然後,我固定權重並對輸入圖像進行優化,使得8的softmax概率達到最大。但softmax損失不會降低到2.302(-log(1/10))以下,這意味着我的訓練毫無用處。我究竟做錯了什麼?Tensorflow:顯示MNIST數據集上的線性分類器的訓練權重

代碼訓練的權重:

import tensorflow as tf 
import numpy as np 
from tensorflow.examples.tutorials.mnist import input_data 

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
trX, trY, teX, teY = mnist.train.images, mnist.train.labels,  
mnist.test.images, mnist.test.labels 

X = tf.placeholder("float", [None, 784]) 
Y = tf.placeholder("float", [None, 10]) 

w = tf.Variable(tf.random_normal([784, 10], stddev=0.01)) 
b = tf.Variable(tf.zeros([10])) 

o = tf.nn.sigmoid(tf.matmul(X, w)+b) 

cost= tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=o, labels=Y)) 
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost) 
predict_op = tf.argmax(o, 1) 

sess=tf.Session() 
sess.run(tf.global_variables_initializer()) 
for i in range(100): 
    for start, end in zip(range(0, len(trX), 256), range(256, len(trX)+1, 256)): 
     sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]}) 
    print(i, np.mean(np.argmax(teY, axis=1) == sess.run(predict_op, feed_dict={X: teX}))) 

代碼用於訓練圖像的固定權數:

#Copy trained weights into W,B and pass them as placeholders to new model 
W=sess.run(w) 
B=sess.run(b) 

X=tf.Variable(tf.random_normal([1, 784], stddev=0.01)) 
Y=tf.constant([0, 0, 0, 0, 0, 0, 0, 0, 1, 0]) 

w=tf.placeholder("float") 
b=tf.placeholder("float") 

o = tf.nn.sigmoid(tf.matmul(X, w)+b) 

cost= tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=o, labels=Y)) 
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost) 
predict_op = tf.argmax(o, 1) 

sess.run(tf.global_variables_initializer()) 
for i in range(1000): 
    sess.run(train_op, feed_dict={w:W, b:B}) 
    if i%50==0: 
    sess.run(cost, feed_dict={w:W, b:B}) 
    print(i, sess.run(predict_op, feed_dict={w:W, b:B})) 

回答

1

你的網上的輸出你應該不叫tf.sigmoidsoftmax_cross_entropy_with_logits假定您的輸入是logits,即不受約束的實數。使用

o = tf.matmul(X, w)+b 

將您的準確度提高到92.8%。

通過此次修改,您的第二次培訓有效。成本達到0,雖然由此產生的圖像是什麼,但吸引力。

enter image description here

+0

感謝您的幫助:)。 –