2017-10-18 29 views
0

我在tensorflow中編寫了一個定製的內核操作來讀取csv格式的數據。爲什麼自定義讀操作只能在test_session上運行

它在TestCase中正常工作,sess對象返回test_session()函數。

當我轉到正常代碼時,讀取器每次都返回相同的結果。然後,我在MyOp:Compute函數的開頭部分進行了一些調試打印。看起來在第一次運行後,sess.run(myop)根本不會調用MyOp:Compute函數。

然後我回到我的測試情況下,如果我更換一個tf.Session()代替self.test_session()會話對象,它沒有以同樣的方式。

任何人都有這個想法嗎?

分享更多的細節,這是我小的演示代碼: https://github.com/littleDing/mini_csv_reader

測試用例:

def testSimple(self): 
    input_data_schema, feas, batch_size = self.get_simple_format() 
    iter_op = ops.csv_iter('./sample_data.txt', input_data_schema, feas, batch_size=batch_size, label='label2') 
    with self.test_session() as sess: 
    label,sign = sess.run(iter_op) 
    print label 

    self.assertAllEqual(label.shape, [batch_size]) 
    self.assertAllEqual(sign.shape, [batch_size, len(feas)]) 
    self.assertAllEqual(sum(label), 2) 
    self.assertAllEqual(sign[0,:], [7,0,4,1,1,1,5,9,8]) 

    label,sign = sess.run(iter_op) 
    self.assertAllEqual(label.shape, [batch_size]) 
    self.assertAllEqual(sign.shape, [batch_size, len(feas)]) 
    self.assertAllEqual(sum(label), 1) 
    self.assertAllEqual(sign[0,:], [9,9,3,1,1,1,5,4,8]) 

正常的呼叫:

def testing_tf(): 
    path = './sample_data.txt' 
    input_data_schema, feas, batch_size = get_simple_format() 
    with tf.device('/cpu:0'): 
     n_data_op = tf.placeholder(dtype=tf.float32) 
     iter_op = ops.csv_iter(path, input_data_schema, feas, batch_size=batch_size, label='label2') 
     init_op = [tf.global_variables_initializer(), tf.local_variables_initializer() ] 

    with tf.Session() as sess: 
     sess.run(init_op) 
     n_data = 0 
     for batch_idx in range(3): 
     print '>>>>>>>>>>>>>> before run batch', batch_idx 
     ## it should be some debug printing here, but nothing come out when batch_idx>0 
     label,sign = sess.run(iter_op) 
     print '>>>>>>>>>>>>>> after run batch', batch_idx 
     ## the content of sign remain the same every time 
     print sign 
     if len(label) == 0: 
      break 
+1

你能分享特定的代碼來理解更加清晰 –

+0

請儘量分擔問題的最小工作示例。沒有能夠說明您的問題的代碼,我們不可能提供幫助。 – Engineero

+0

@Engineero我已更新最低工作代碼,您想要檢查嗎? –

回答

1

一看的implementationtf.test.TestCase.test_session()提供了一些線索,因爲它配置會話的方式與直接調用的方式有所不同。特別是,test_session()disables不斷摺疊優化。默認情況下,TensorFlow會將圖的無狀態部分轉換爲tf.constant()節點,因爲每次運行它們時都會產生相同的結果。

在你"CsvIter" OP的註冊,對SetIsStateful()註釋,所以TensorFlow會將其視爲無國籍,因此受到恆定的摺疊。然而,它的實現是非常有狀態的:一般來說,任何你希望用相同的輸入張量產生不同結果的操作,或者任何在成員變量中存儲可變狀態的操作符應該被標記爲有狀態。

的解決方案是一個修改一行到REGISTER_OP"CsvIter"

REGISTER_OP("CsvIter") 
    .Input("data_file: string") 
    .Output("labels: float32") 
    .Output("signs: int64") 
    .Attr("input_schema: list(string)") 
    .Attr("feas: list(string)") 
    .Attr("label: string = 'label' ") 
    .Attr("batch_size: int = 10000") 
    .SetIsStateful(); // Add this line. 
+0

解決問題,非常感謝! –

相關問題