我在tensorflow中編寫了一個定製的內核操作來讀取csv格式的數據。爲什麼自定義讀操作只能在test_session上運行
它在TestCase中正常工作,sess
對象返回test_session()
函數。
當我轉到正常代碼時,讀取器每次都返回相同的結果。然後,我在MyOp:Compute
函數的開頭部分進行了一些調試打印。看起來在第一次運行後,sess.run(myop)
根本不會調用MyOp:Compute
函數。
然後我回到我的測試情況下,如果我更換一個tf.Session()
代替self.test_session()
會話對象,它沒有以同樣的方式。
任何人都有這個想法嗎?
分享更多的細節,這是我小的演示代碼: https://github.com/littleDing/mini_csv_reader
測試用例:
def testSimple(self):
input_data_schema, feas, batch_size = self.get_simple_format()
iter_op = ops.csv_iter('./sample_data.txt', input_data_schema, feas, batch_size=batch_size, label='label2')
with self.test_session() as sess:
label,sign = sess.run(iter_op)
print label
self.assertAllEqual(label.shape, [batch_size])
self.assertAllEqual(sign.shape, [batch_size, len(feas)])
self.assertAllEqual(sum(label), 2)
self.assertAllEqual(sign[0,:], [7,0,4,1,1,1,5,9,8])
label,sign = sess.run(iter_op)
self.assertAllEqual(label.shape, [batch_size])
self.assertAllEqual(sign.shape, [batch_size, len(feas)])
self.assertAllEqual(sum(label), 1)
self.assertAllEqual(sign[0,:], [9,9,3,1,1,1,5,4,8])
正常的呼叫:
def testing_tf():
path = './sample_data.txt'
input_data_schema, feas, batch_size = get_simple_format()
with tf.device('/cpu:0'):
n_data_op = tf.placeholder(dtype=tf.float32)
iter_op = ops.csv_iter(path, input_data_schema, feas, batch_size=batch_size, label='label2')
init_op = [tf.global_variables_initializer(), tf.local_variables_initializer() ]
with tf.Session() as sess:
sess.run(init_op)
n_data = 0
for batch_idx in range(3):
print '>>>>>>>>>>>>>> before run batch', batch_idx
## it should be some debug printing here, but nothing come out when batch_idx>0
label,sign = sess.run(iter_op)
print '>>>>>>>>>>>>>> after run batch', batch_idx
## the content of sign remain the same every time
print sign
if len(label) == 0:
break
你能分享特定的代碼來理解更加清晰 –
請儘量分擔問題的最小工作示例。沒有能夠說明您的問題的代碼,我們不可能提供幫助。 – Engineero
@Engineero我已更新最低工作代碼,您想要檢查嗎? –