1
我正在嘗試爲MNIST數據創建一個簡單的線性分類器,我無法讓自己的損失下降。可能是什麼問題呢? 這裏是我的代碼:Tensorflow線性分類器未訓練
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
class LinearClassifier(object):
def __init__(self):
print("LinearClassifier loading MNIST")
self._mnist = input_data.read_data_sets("mnist_data/", one_hot = True)
self._buildGraph()
def _buildGraph(self):
self._tf_TrainX = tf.placeholder(tf.float32, [None, self._mnist.train.images.shape[1]])
self._tf_TrainY = tf.placeholder(tf.float32, [None, self._mnist.train.labels.shape[1]])
self._tf_Weights = tf.Variable(tf.random_normal([784,10]), tf.float32)
self._tf_Bias = tf.Variable(tf.zeros([10]), tf.float32)
self._tf_Y = tf.nn.softmax(tf.matmul(self._tf_TrainX, self._tf_Weights) + self._tf_Bias)
self._tf_Loss = tf.reduce_mean(-tf.reduce_sum(self._tf_TrainY * tf.log(self._tf_Y), reduction_indices=[1]))
self._tf_TrainStep = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(self._tf_Loss)
self._tf_CorrectGuess = tf.equal(tf.argmax(self._tf_Y, 1), tf.arg_max(self._tf_TrainY, 1))
self._tf_Accuracy = tf.reduce_mean(tf.cast(self._tf_CorrectGuess, tf.float32))
self._tf_Initializers = tf.global_variables_initializer()
def train(self, epochs, batch_size):
self._sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
self._sess.run(self._tf_Initializers)
for i in range(epochs):
batchX, batchY = self._mnist.train.next_batch(batch_size)
self._loss, _, self._accurracy = self._sess.run([self._tf_Loss, self._tf_TrainStep, self._tf_Accuracy], feed_dict ={self._tf_TrainX: batchX, self._tf_TrainY: batchY})
print("Epoch: {0}, Loss: {1}, Accuracy: {2}".format(i, self._loss, self._accurracy))
當我通過運行這個:
lc = LinearClassifier()
lc.train(1000, 100)
...我GETT是這樣的:
Epoch: 969, Loss: 8.19491195678711, Accuracy: 0.17999999225139618
Epoch: 970, Loss: 9.09421157836914, Accuracy: 0.1899999976158142
....
Epoch: 998, Loss: 7.865959167480469, Accuracy: 0.17000000178813934
Epoch: 999, Loss: 9.281349182128906, Accuracy: 0.10999999940395355
可能是什麼原因TF .train.GradientDescentOptimizer沒有正確地訓練我的權重和偏差?
謝謝。你是對的。 –