我正在嘗試使用張量流來訓練使用交通標誌圖像的神經網絡(LeNet)。我想檢查預處理技術對nn性能的影響。因此,我預處理了圖像,並將結果(訓練圖像,驗證圖像,testimages,最終testimages)作爲一個字典存儲。當改變輸入數據時,Tensorflow模型沒有訓練
然後我試圖遍歷該字典,然後使用tensorflow的訓練和驗證操作如下
import tensorflow as tf
from sklearn.utils import shuffle
output_data = []
EPOCHS = 5
BATCH_SIZE = 128
rate = 0.0005
for key in finalInputdata.keys():
for procTypes in range(0,(len(finalInputdata[key]))):
if np.shape(finalInputdata[key][procTypes][0]) !=():
X_train = finalInputdata[key][procTypes][0]
X_valid = finalInputdata[key][procTypes][1]
X_test = finalInputdata[key][procTypes][2]
X_finaltest = finalInputdata[key][procTypes][3]
x = tf.placeholder(tf.float32, (None, 32, 32,np.shape(X_train)[-1]))
y = tf.placeholder(tf.int32, (None))
one_hot_y = tf.one_hot(y,43)
# Tensor Operations
logits = LeNet(x,np.shape(X_train)[-1])
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,one_hot_y)
softmax_probability = tf.nn.softmax(logits)
loss_operation = tf.reduce_mean(cross_entropy)
optimizer = tf.train.AdamOptimizer(learning_rate=rate)
training_operation = optimizer.minimize(loss_operation)
correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(one_hot_y,1))
accuracy_operation = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Pipeline for training and evaluation
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
num_examples = len(X_train)
print("Training on %s images processed as %s" %(key,dict_fornames['proctypes'][procTypes]))
print()
for i in range(EPOCHS):
X_train, y_train = shuffle(X_train, y_train)
for offset in range(0, num_examples, BATCH_SIZE):
end = offset + BATCH_SIZE
batch_x, batch_y = X_train[offset:end], y_train[offset:end]
sess.run(training_operation, feed_dict = {x: batch_x, y: batch_y})
training_accuracy = evaluate(X_train,y_train)
validation_accuracy = evaluate(X_valid, y_valid)
testing_accuracy = evaluate(X_test, y_test)
final_accuracy = evaluate(X_finaltest, y_finalTest)
print("EPOCH {} ...".format(i+1))
print("Training Accuracy = {:.3f}".format(training_accuracy))
print("Validation Accuracy = {:.3f}".format(validation_accuracy))
print()
output_data.append({'EPOCHS':EPOCHS, 'LearningRate':rate, 'ImageType': 'RGB',\
'PreprocType': dict_fornames['proctypes'][0],\
'TrainingAccuracy':training_accuracy, 'ValidationAccuracy':validation_accuracy, \
'TestingAccuracy': testing_accuracy})
sess.close()
的評價功能如下
def evaluate(X_data, y_data):
num_examples = len(X_data)
total_accuracy = 0
sess = tf.get_default_session()
for offset in range(0,num_examples, BATCH_SIZE):
batch_x, batch_y = X_data[offset:offset+BATCH_SIZE], y_data[offset:offset+BATCH_SIZE]
accuracy = sess.run(accuracy_operation, feed_dict = {x:batch_x, y:batch_y})
total_accuracy += (accuracy * len(batch_x))
return total_accuracy/num_examples
一旦我執行程序,它對數據集的第一次迭代有效,但從第二次迭代開始,網絡不會訓練並繼續爲所有其他迭代執行此操作。
Training on RGB images processed as Original
EPOCH 1 ...
Training Accuracy = 0.525
Validation Accuracy = 0.474
EPOCH 2 ...
Training Accuracy = 0.763
Validation Accuracy = 0.682
EPOCH 3 ...
Training Accuracy = 0.844
Validation Accuracy = 0.723
EPOCH 4 ...
Training Accuracy = 0.888
Validation Accuracy = 0.779
EPOCH 5 ...
Training Accuracy = 0.913
Validation Accuracy = 0.795
Training on RGB images processed as Mean Subtracted Data
EPOCH 1 ...
Training Accuracy = 0.056
Validation Accuracy = 0.057
EPOCH 2 ...
Training Accuracy = 0.057
Validation Accuracy = 0.057
EPOCH 3 ...
Training Accuracy = 0.057
Validation Accuracy = 0.056
EPOCH 4 ...
Training Accuracy = 0.058
Validation Accuracy = 0.056
EPOCH 5 ...
Training Accuracy = 0.058
Validation Accuracy = 0.058
Training on RGB images processed as Normalized Data
EPOCH 1 ...
Training Accuracy = 0.058
Validation Accuracy = 0.054
EPOCH 2 ...
Training Accuracy = 0.058
Validation Accuracy = 0.054
EPOCH 3 ...
Training Accuracy = 0.058
Validation Accuracy = 0.054
EPOCH 4 ...
Training Accuracy = 0.058
Validation Accuracy = 0.054
EPOCH 5 ...
Training Accuracy = 0.058
Validation Accuracy = 0.054
但是,如果我重新啓動內核並使用任何數據類型(任何迭代),它都可以工作。我發現我必須清除圖表或者爲多個數據類型運行多個會話,但是我不清楚如何實現這一點。我嘗試使用tf.reset_default_graph()
,但似乎沒有任何效果。有人能指出我正確的方向嗎?
感謝
您能否更具體地瞭解問題所在? – drpng
正如我在標題中已經提到的那樣,網絡不會訓練。即使經過多次迭代,訓練精度仍然保持在0.05 – Mechanic