2017-06-16

When I use the code below to decode a `csv` file with TensorFlow, how do I apply `tf.map_fn` to a SparseTensor?

import tensorflow as tf 

# def input_pipeline(filenames, batch_size): 
#  # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data. 
#  dataset = (tf.contrib.data.TextLineDataset(filenames) 
#    .map(lambda line: tf.decode_csv(
#      line, record_defaults=[['1'], ['1'], ['1']], field_delim='-')) 
#    .shuffle(buffer_size=10) # Equivalent to min_after_dequeue=10. 
#    .batch(batch_size)) 

#  # Return an *initializable* iterator over the dataset, which will allow us to 
#  # re-initialize it at the beginning of each epoch. 
#  return dataset.make_initializable_iterator() 

def decode_func(line): 
    record_defaults = [['1'],['1'],['1']] 
    line = tf.decode_csv(line, record_defaults=record_defaults, field_delim='-') 
    str_to_int = lambda r: tf.string_to_number(r, tf.int32) 
    query = tf.string_split(line[:1], ",").values 
    title = tf.string_split(line[1:2], ",").values 
    query = tf.map_fn(str_to_int, query, dtype=tf.int32) 
    title = tf.map_fn(str_to_int, title, dtype=tf.int32) 
    label = line[2] 
    return query, title, label 

def input_pipeline(filenames, batch_size): 
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data. 
    dataset = tf.contrib.data.TextLineDataset(filenames) 
    dataset = dataset.map(decode_func) 
    dataset = dataset.shuffle(buffer_size=10) # Equivalent to min_after_dequeue=10. 
    dataset = dataset.batch(batch_size) 

    # Return an *initializable* iterator over the dataset, which will allow us to 
    # re-initialize it at the beginning of each epoch. 
    return dataset.make_initializable_iterator() 


filenames=['2.txt'] 
batch_size = 3 
num_epochs = 10 
iterator = input_pipeline(filenames, batch_size) 

# `a1`, `a2`, and `a3` represent the next element to be retrieved from the iterator.  
a1, a2, a3 = iterator.get_next() 

with tf.Session() as sess: 
    for _ in range(num_epochs): 
        print(_) 
        # Resets the iterator at the beginning of an epoch. 
        sess.run(iterator.initializer) 
        try: 
            while True: 
                a, b, c = sess.run([a1, a2, a3]) 
                print(type(a[0]), b, c) 
        except tf.errors.OutOfRangeError: 
            # This will be raised when you reach the end of an epoch (i.e. the 
            # iterator has no more elements). 
            print('stop') 

        # Perform any end-of-epoch computation here. 
        print('Done training, epoch reached') 

The script crashes without returning any result; it hangs when it reaches `a, b, c = sess.run([a1, a2, a3])`. However, when I comment out

query = tf.map_fn(str_to_int, query, dtype=tf.int32) 
title = tf.map_fn(str_to_int, title, dtype=tf.int32) 

it works and returns results.

The data in `2.txt` is formatted like this:

1,2,3-4,5-0 
1-2,3,4-1 
4,5,6,7,8-9-0 

Also, why are the returned results byte-like objects rather than `str`?

Answer


I took a look at it, and it seems that if you replace:

query = tf.map_fn(str_to_int, query, dtype=tf.int32) 
title = tf.map_fn(str_to_int, title, dtype=tf.int32) 
label = line[2] 

with

query = tf.string_to_number(query, out_type=tf.int32) 
title = tf.string_to_number(title, out_type=tf.int32) 
label = tf.string_to_number(line[2], out_type=tf.int32) 

it works fine.

It looks like nesting two TensorFlow mapping constructs (`tf.map_fn` inside `Dataset.map`) does not work. Fortunately, the original code was overly complicated anyway.
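To see what the simplified `decode_func` has to produce per line, the parsing can be sketched in plain Python (no TensorFlow), assuming the same `-` field delimiter and `,` value delimiter used in `2.txt`; `parse_line` is a hypothetical helper name:

```python
def parse_line(line):
    """Split one line of 2.txt into query ids, title ids, and a label.

    Fields are separated by '-', values within a field by ','.
    """
    query_str, title_str, label_str = line.strip().split('-')
    query = [int(v) for v in query_str.split(',')]
    title = [int(v) for v in title_str.split(',')]
    label = int(label_str)
    return query, title, label

print(parse_line('1,2,3-4,5-0'))  # ([1, 2, 3], [4, 5], 0)
```

This is exactly the shape of output the TensorFlow pipeline is expected to emit for each record.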

Regarding your second question, I get this as output:

[(array([4, 5, 6, 7, 8], dtype=int32), array([9], dtype=int32), 0)] 
<type 'numpy.ndarray'> 
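On the byte-like objects: in Python 3, TensorFlow returns string tensor values as `bytes`, so they need an explicit decode to become `str`. A minimal plain-Python illustration of the conversion:

```python
raw = b'9'                    # what sess.run returns for a string tensor value
text = raw.decode('utf-8')    # convert bytes to str
print(type(raw).__name__, type(text).__name__)  # bytes str
```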

I will edit my answer to your previous question to reflect this. – npf


Thanks a lot! It works, but calling `dataset = dataset.batch(batch_size)` then causes a shape error, and I have to set `batch_size = 1`. So we need to pad the sequences in an earlier step, so that the file is structured and can be decoded easily; the `tf.string_to_number` code can then be removed. Alas... @Nicolas – danche
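The padding mentioned in the comment above can be sketched in plain Python, assuming sequences are padded with 0 up to the longest length in the batch (in TensorFlow itself, `Dataset.padded_batch` handles this); `pad_batch` is a hypothetical helper name:

```python
def pad_batch(sequences, pad_value=0):
    """Pad variable-length integer sequences to the length of the longest one."""
    max_len = max(len(s) for s in sequences)
    return [s + [pad_value] * (max_len - len(s)) for s in sequences]

queries = [[1, 2, 3], [1], [4, 5, 6, 7, 8]]
print(pad_batch(queries))  # [[1, 2, 3, 0, 0], [1, 0, 0, 0, 0], [4, 5, 6, 7, 8]]
```

Once all sequences in a batch share one length, `dataset.batch(batch_size)` no longer raises a shape error.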
