2
我試圖使用shuffle.batch
批處理從.csv文件加載的訓練數據。但是,當我運行代碼時,它似乎不起作用。它沒有顯示任何錯誤,但沒有完成。TensorFlow:shuffle_batch沒有顯示任何錯誤,但沒有完成
那麼,你能告訴我我的代碼有什麼問題嗎?
此外,什麼是適合的容量值和min_after_dequeue
?
import tensorflow as tf
import numpy as np
test_label = []
in_label = []
iris_TRAINING = "iris_training.csv"
iris_TEST = "iris_test.csv"
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=iris_TRAINING, target_dtype=np.int, features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=iris_TEST, target_dtype=np.int, features_dtype=np.float32)
x_train, x_test, y_train, y_test = training_set.data, test_set.data, training_set.target, test_set.target
for n in y_train:
targets = np.zeros(3)
targets[int(n)] = 1 # one-hot pixs[0] is label and then use that number as index of one-hot
in_label.append(targets) #store all of label (one-hot)
training_label = np.asarray(in_label)
for i in y_test:
test_targets = np.zeros(3)
test_targets[int(i)] = 1 # one-hot pixs[0] is label and then use that number as index of one-hot
test_label.append(test_targets)
test_label = np.asarray(test_label)
x = tf.placeholder(tf.float32, [None,4]) #generate placeholder to store value of features for training
W = tf.Variable(tf.zeros([4, 3])) #weight
b = tf.Variable(tf.zeros([3])) #bias
y = tf.matmul(x, W) + b
y_ = tf.placeholder(tf.float32, [None, 3]) #generate placeholder to store value of labels
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.InteractiveSession()
# Train
tf.initialize_all_variables().run()
for i in range(5):
batch_xt, batch_yt = tf.train.shuffle_batch([x_train,training_label],batch_size=10,capacity=200,min_after_dequeue=10)
sess.run(train_step, feed_dict={x: batch_xt.eval(), y_: batch_yt.eval()})
print(i)
# Test trained model
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: x_test, y_: test_label}))