I understand that there are advantages (especially as I scale up the models I build and the size of the datasets they train on) to using TensorFlow's new Dataset idiom as my data-feeding pipeline. However, I'm having trouble mapping my existing feed_dict-based code onto this new model. How do I convert feed-based TensorFlow code to use Dataset?
One problem I face is that I can't untangle how batching and epochs interact, or how either interacts with the logging and validation I do regularly.
For example, how would something like the following be reworked to use Dataset, restructuring if necessary?
# Load and process data into tensors of dimension (N, C_i) for input and (N, C_o) for output,
# where N is the number of examples, C_ is the number of channels, and the values are activations
train_x, train_y, valid_x, valid_y = load_data(file, [segments], ...)
train_size = len(train_x)

train_stats_feed = {input_activation: train_x, correct_output: train_y, is_train: False}
valid_stats_feed = {input_activation: valid_x, correct_output: valid_y, is_train: False}

with tf.Session(config=tf.ConfigProto(...)) as sess:
    sess.run(tf.global_variables_initializer())

    # Some analysis; not always done but the code needs to support it
    train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), 0)
    test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), 0)
    test_writer.add_summary(sess.run(gs_summary), 0)
    print(log_fmt.format(0, float(sess.run(accuracy, feed_dict=valid_stats_feed)),
                         float(sess.run(loss, feed_dict=valid_stats_feed))))

    for ep in range(epochs):
        # Slice the training data into random batches
        batch_indices = np.array_split(np.random.permutation(train_size), int(train_size/mb_size))
        for mini_batch_indices in batch_indices:
            sess.run(train_step, feed_dict={input_activation: train_x[mini_batch_indices],
                                            correct_output: train_y[mini_batch_indices], is_train: True})
            gs = int(sess.run(global_step))
            if gs % log_steps == 0:
                test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), gs)
                train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), gs)
                acc = float(sess.run(accuracy, feed_dict=valid_stats_feed))
                sess.run(validation_accuracy.assign(acc))
                print(log_fmt.format(gs, acc, float(sess.run(loss, feed_dict=valid_stats_feed))))
        print(ep_fmt.format(ep + 2))
        test_writer.add_summary(sess.run(gs_summary), ep + 1)
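For comparison, here is a minimal sketch of how the permutation-and-slice minibatch loop above maps onto the Dataset idiom. The data is a synthetic stand-in for load_data()'s output; only mb_size and epochs mirror the question's variables, and tf.compat.v1 is used so this TF 1.x-style session code also runs under TF 2.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x-style API, also available under TF 2
tf.disable_v2_behavior()

# Synthetic stand-ins for load_data()'s output; shapes (N, C_i) and (N, C_o)
train_x = np.random.rand(100, 4).astype(np.float32)
train_y = np.random.rand(100, 2).astype(np.float32)
mb_size, epochs = 10, 2

# shuffle() replaces np.random.permutation, batch() replaces np.array_split
dataset = (tf.data.Dataset.from_tensor_slices((train_x, train_y))
           .shuffle(buffer_size=len(train_x))
           .batch(mb_size))
iterator = tf.data.make_initializable_iterator(dataset)
next_x, next_y = iterator.get_next()  # tensors: used in place of the placeholders

with tf.Session() as sess:
    for ep in range(epochs):
        sess.run(iterator.initializer)  # reshuffles; one full pass = one epoch
        while True:
            try:
                bx, by = sess.run([next_x, next_y])
                # a real model would sess.run(train_step) here instead
            except tf.errors.OutOfRangeError:
                break  # end of epoch
```

The epoch boundary becomes explicit again: each sess.run(iterator.initializer) starts a fresh shuffled pass, and OutOfRangeError marks its end.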
Some definitions for the above that may not be obvious:
# Preliminaries
# Some basic preliminaries, the details of which are not important to the question
# Mostly pretty standard; obvious things omitted from MWE for brevity
global_step = tf.Variable(0, trainable=False, name='global_step')
validation_accuracy = tf.Variable(0.0, trainable=False, name='validation_accuracy', dtype=tf.float32)
is_train = tf.placeholder(tf.bool, [], name='is_train')
input_activation = tf.placeholder(tf.float32, shape=[None, in_nodes], name='inputs')
correct_output = tf.placeholder(tf.float32, shape=[None, out_nodes], name='correct_outputs')
network_output = tf.identity(out_activations)
correct_predictions = correct_fn(correct_output, network_output)
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
error = cost_fn(correct_output, network_output)
loss = error + FLAGS.regularization_weight * sum(tf.nn.l2_loss(w) for w in layer_weights)
train_step = tf.train.MomentumOptimizer(learning_rate, momentum=momentum).minimize(loss, global_step=global_step)
# Logging
train_writer = tf.summary.FileWriter(trainlogfile, tf.get_default_graph())
test_writer = tf.summary.FileWriter(testlogfile, tf.get_default_graph())
gs_summary = tf.summary.scalar('global_step_at_epoch', global_step)
merged = tf.summary.merge_all()
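As for interleaving training and validation, one way to express it with Dataset is a reinitializable iterator built from the datasets' common structure, so the same get_next() tensors can be re-pointed at either the training or the validation data. A sketch only, with synthetic stand-in data; the tensor names mirror the question's placeholders:

```python
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Synthetic stand-ins for the loaded data
train_x = np.random.rand(100, 4).astype(np.float32)
train_y = np.random.rand(100, 2).astype(np.float32)
valid_x = np.random.rand(20, 4).astype(np.float32)
valid_y = np.random.rand(20, 2).astype(np.float32)
mb_size = 10

train_ds = (tf.data.Dataset.from_tensor_slices((train_x, train_y))
            .shuffle(len(train_x)).batch(mb_size))
valid_ds = tf.data.Dataset.from_tensor_slices((valid_x, valid_y)).batch(len(valid_x))

# One iterator defined by structure, not tied to a particular dataset
iterator = tf.data.Iterator.from_structure(
    tf.data.get_output_types(train_ds), tf.data.get_output_shapes(train_ds))
input_activation, correct_output = iterator.get_next()  # replace the placeholders
train_init = iterator.make_initializer(train_ds)
valid_init = iterator.make_initializer(valid_ds)

with tf.Session() as sess:
    sess.run(valid_init)  # point the pipeline at validation data ...
    vx, vy = sess.run([input_activation, correct_output])  # e.g. run merged/accuracy
    sess.run(train_init)  # ... then back to training data
    bx, by = sess.run([input_activation, correct_output])
```

One caveat: each make_initializer run restarts its dataset from the beginning, so for logging every log_steps steps mid-epoch a feedable iterator (Iterator.from_string_handle) is usually the recommended alternative.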
It's not clear to me how this maps onto the above (e.g., how training and validation are interleaved, where loading from files fits in, etc.). – orome
That also produces various errors. – orome
I think you need a basic understanding before asking the question. –