2017-08-17 25 views
1

我有一個虛擬的csv文件(y=-x+1張量流的csv文件讀取器中的num_epochs是否有限制string_input_producer()?

x,y 
1,0 
2,-1 
3,-2 

我試喂的是成線性迴歸模型。由於我只有這麼幾個例子,我想在該文件上迭代1000次訓練,所以我設置了num_epochs=1000

但是,似乎Tensorflow限制了這個數字。它工作正常,如果我使用num_epochs = 5或10,但超過33它被限制在33個時代。這是真的還是我做錯了什麼?

# model = W*x+b 
... 
optimizer = tf.train.GradientDescentOptimizer(0.01) 
train = optimizer.minimize(loss) 

# reading input from csv 
filename_queue = tf.train.string_input_producer(["/tmp/testinput.csv"], num_epochs=1000) 
reader = tf.TextLineReader(skip_header_lines=1) 
... 
col_x, col_label = tf.decode_csv(csv_row, record_defaults=record_defaults) 

with tf.Session() as sess: 
    sess.run(tf.local_variables_initializer()) 
    sess.run(tf.global_variables_initializer()) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    while True: 
    try: 
     input_x, input_y = sess.run([col_x, col_label]) 
     sess.run(train, feed_dict={x:input_x, y:input_y}) 
... 

方的問題,做我需要做的:

input_x, input_y = sess.run([col_x, col_label]) 
sess.run(train, feed_dict={x:input_x, y:input_y}) 

我試圖sess.run(train, feed_dict={x:col_x, y:col_y})直接避免了摩擦,但它不工作(他們是節點,feed_dict預計常規數據)

回答

0

下面演示作品完美(與輸入):

import tensorflow as tf 


filename_queue = tf.train.string_input_producer(["/tmp/input.csv"], num_epochs=1000) 
reader = tf.TextLineReader(skip_header_lines=1) 
_, csv_row = reader.read(filename_queue) 
col_x, col_label = tf.decode_csv(csv_row, record_defaults=[[0], [0]]) 

with tf.Session() as sess: 
    sess.run(tf.local_variables_initializer()) 
    sess.run(tf.global_variables_initializer()) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    num = 0 
    try: 
     while True: 
     sess.run([col_x, col_label]) 
     num += 1 
    except: 
    print(num) 

這給下面的輸出:

[email protected]:/tmp$ python csv.py 
3000 
+0

對我來說不num_epochs更大的值(例如100.000)或更長的csv文件的工作(例如,相同線路重複1000次) – Thomas

+0

我只是重複的測試上面的代碼有100.000個紀元和甚至1.000.000個紀元在我的結尾都給出了相同的結果......也許在快速CPU和慢IO之間會出現一些競爭狀況? –

+0

謝謝你嘗試自己,給了我力量繼續(掙扎的小時)。所以罪魁禍首是:'嘗試:除了tf.errors.OutOfRangeError:最後:coord.request_stop()'(最後......破壞它)和'coord.join(線程)'。我在互聯網上的很多地方看到,我沒有考慮刪除它們,但沒有更好(或更好:以錯誤的順序將它們疊加)。謝謝! – Thomas

相關問題