2015-11-13 26 views
13

它主要是來自教程的網站上的複製粘貼。我得到一個錯誤:我從TensorFlow的csv閱讀器中遺漏了什麼?

Invalid argument: ConcatOp : Expected concatenating dimensions in the range [0, 0), but got 0 [[Node: concat = Concat[N=4, T=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](concat/concat_dim, DecodeCSV, DecodeCSV:1, DecodeCSV:2, DecodeCSV:3)]]

我的CSV文件的內容是:

3,4,1,8,4

import tensorflow as tf 


filename_queue = tf.train.string_input_producer(["test2.csv"]) 

reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 

# Default values, in case of empty columns. Also specifies the type of the 
# decoded result. 
record_defaults = [[1], [1], [1], [1], [1]] 
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults) 
# print tf.shape(col1) 

features = tf.concat(0, [col1, col2, col3, col4]) 
with tf.Session() as sess: 
    # Start populating the filename queue. 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(1200): 
    # Retrieve a single instance: 
    example, label = sess.run([features, col5]) 

    coord.request_stop() 
    coord.join(threads) 
+0

不幸的是,示例代碼(http://tensorflow.org/how_tos/reading_data/index.md)包含此錯誤 –

回答

13

該問題是由於程序中張量的形狀造成的。 TL; DR相反的tf.concat()你應該使用tf.pack(),這將改變四個標col張量爲長度爲4

在我們開始之前的1-d張量,注意,你可以使用get_shape()方法任何Tensor對象來獲取關於該張量的靜態形狀信息。例如,在代碼中註釋掉的行可能是:

print col1.get_shape() 
# ==> 'TensorShape([])' - i.e. `col1` is a scalar. 

通過reader.read()返回的value張量是一個標量字符串。 tf.decode_csv(value, record_defaults=[...])對於record_defaults的每個元素產生與value相同形狀的張量,即在這種情況下的標量。標量是具有單個元素的0維張量。 tf.concat(i, xs)未在標量上定義:它將N維張量列表(xs)連接成一個新的N維張量,沿着維度i,其中0 <= i < N,並且如果N = 0沒有有效的i

tf.pack(xs)運營商旨在簡單地解決這個問題。它列出了k N維張量(形狀相同),並將它們打包成第0維尺寸爲k的N + 1維張量。如果您在使用tf.concat()替換tf.pack(),你的程序將工作:

# features = tf.concat(0, [col1, col2, col3, col4]) 
features = tf.pack([col1, col2, col3, col4]) 

with tf.Session() as sess: 
    # Start populating the filename queue. 
    # ... 
+0

的CSV節需要澄清和擴大,這是肯定的。我用了大約15分鐘的時間來學習足夠的代碼,以使用tensorflow框架對我爲特定問題設計的神經網絡進行編碼。到目前爲止,花了半天時間才弄清楚如何閱讀一些非常簡單的csv數據。 – demongolem

+0

感謝您的反饋!如果您希望我們澄清一些具體的事情,請提出GitHub問題:https://github.com/tensorflow/tensorflow/issues – mrry

1

我也堅持這一tutorial。我能爲交換一個問題,當我改變你的with tf.Session()爲:

sess = tf.Session() 
coord = tf.train.Coordinator() 
threads = tf.train.start_queue_runners(coord=coord) 

for i in range(2): 
    #print i 
    example, label = sess.run([features, col5]) 

coord.request_stop() 
coord.join(threads) 

sess.close() 

錯誤消失,TF拔腿就跑,但它看起來像被卡住。如果您取消註釋# print,您將看到只有一個迭代運行。很可能這不是真的有用(因爲我爲無限執行交易錯誤)。