One possible answer is to use `make_template`.
It is exercised in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/kernel_tests/template_test.py, which basically shows that this can be done:

    import tensorflow as tf  # written against the TensorFlow 1.x graph API

    training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
    test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

    tf.set_random_seed(1234)

    def test_line(x):
        # "w" and "b" are created on the first call; the template reuses them afterwards.
        m = tf.get_variable("w", shape=[],
                            initializer=tf.truncated_normal_initializer())
        b = tf.get_variable("b", shape=[],
                            initializer=tf.truncated_normal_initializer())
        return x * m + b

    # tf.make_template is the public alias of template.make_template used in the test file.
    line_template = tf.make_template("line", test_line)

    # Both calls build separate ops but share the same two variables.
    train_prediction = line_template(training_input)
    test_prediction = line_template(test_input)

    train_loss = tf.reduce_mean(tf.square(train_prediction - training_output))
    test_loss = tf.reduce_mean(tf.square(test_prediction - test_output))

    optimizer = tf.train.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(train_loss)

    with tf.Session() as sess:
        # Deprecated in later 1.x releases in favour of tf.global_variables_initializer().
        sess.run(tf.initialize_all_variables())
        initial_test_loss = sess.run(test_loss)
        sess.run(train_op)
        final_test_loss = sess.run(test_loss)

        # Parameters are tied, so the loss should have gone down when we trained it.
        # (The original test method uses self.assertLess here.)
        assert final_test_loss < initial_test_loss

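
If it helps to see why the two predictions end up tied, here is a rough pure-Python analogue of the idea. It is only a sketch and not how TensorFlow implements templates: `make_shared_template` and the `get_variable` argument it passes in are hypothetical names made up for this illustration. The point is just that the wrapped function creates its variables on the first call and gets the same objects back on every later call.

```python
# Conceptual analogue only: this is NOT TensorFlow's implementation.
# make_shared_template and get_variable are hypothetical names for illustration.

def make_shared_template(func):
    """Wrap func so that variables it requests are created once and then reused."""
    store = {}  # variable name -> one-element list holding the current value

    def get_variable(name, initial_value):
        # The first request creates the variable; later requests return the same object.
        if name not in store:
            store[name] = [initial_value]
        return store[name]

    def templated(xs):
        return func(get_variable, xs)

    templated.variables = store
    return templated


def line(get_variable, xs):
    m = get_variable("w", 0.5)
    b = get_variable("b", 0.0)
    return [x * m[0] + b[0] for x in xs]


line_template = make_shared_template(line)
train_prediction = line_template([1.0, 2.0, 3.0, 4.0])
test_before = line_template([5.0, 6.0])

# Stand-in for a training step: update the single shared slope in place.
line_template.variables["w"][0] = 2.0
test_after = line_template([5.0, 6.0])
```

Because both calls read the same stored `w` and `b`, changing the slope once changes what the "test" call returns, which is the same reason the test loss moves after running `train_op` above.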
[`tf.cond` is not a solution.](https://groups.google.com/a/tensorflow.org/forum/#!msg/discuss/mLrt5qc9_uU/sGNbC7GpAwAJ) It ends up evaluating both of the two queues, whichever side is selected. – mdaoust
`tf.cond` can still be used, but it is hacky. Another approach is `QueueBase.from_list`. See: https://github.com/tensorflow/tensorflow/issues/2514 - sigh.. – TimZaman