2017-08-04 106 views
0

這個張量流代碼沒有響應,我找不出原因。請幫忙!Tensorflow以無限循環結束

import tensorflow as tf 
#reading the file 
with tf.name_scope ('File_reading') as scope: 
    filename_queue = tf.train.string_input_producer(["forestfires.csv.digested"]) 
    reader = tf.TextLineReader() 
    key, value = reader.read(filename_queue) 
    record_defaults = [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.0]] 
    #13 decoded 
    col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, col13 = tf.decode_csv(
     value, record_defaults=record_defaults) 


    #12 is feture, and the 13th is the training data 
    features = tf.stack([col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12],name='data_input') 

    with tf.Session() as sess: 
     # Start populating the filename queue. 
     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 
     for i in range(517): 
      # Retrieve a single instance: 
      example, label = sess.run([features, col13]) 

     coord.request_stop() 
     coord.join(threads) 
with tf.name_scope ('network') as scope: 
    W1=tf.Variable(tf.zeros([12, 8]), name='W1') 
    b1=tf.Variable(tf.zeros([8]), name="b1") 
    h1=tf.add(tf.matmul(tf.expand_dims(features,0), W1),b1, name='hidden_layer') 
    W2=tf.Variable(tf.zeros([8, 1]), name='W2') 
    b2=tf.Variable(tf.zeros([1]), name="b2") 
    output=tf.add(tf.matmul(h1, W2),b2, name='output_layer') 
error=tf.add(output,-col13, name='error') 
#training 
train_step = tf.train.AdamOptimizer(1e-4).minimize(error) 
#graphing the output 
file_writer = tf.summary.FileWriter('some directory', sess.graph) 
with tf.Session() as sess: 
    #init 
    tf.global_variables_initializer().run() 
    print ("\n\n\n\n\n\nTRAINING STARTED\n\n\n\n\n\n") 
    print('test1') 
    sess.run(error) #this statement causes an infinite loop 
    print ('test2') 
file_writer.close() 

該代碼運行並打印'test1',但它什麼都不做,甚至沒有響應ctrl + c。我試圖查找問題,但是我的谷歌技能不夠好,或者它不在互聯網上。 system:win10 geforce 960M python 3.5.2

回答

0

您構建網絡的方式在智力上並不會使敏感。如果您需要從TextLineReader讀取517個步驟,請使用函數read_up_to並提供值517,而不是使用單獨的會話。按照您構建圖表的方式,輸入閱讀器與圖形的其餘部分之間似乎沒有一個簡潔的連接。

我的建議:

# define graph which includes the input queue 
def model(...): 
... 
    return error, metrics 

with tf.Graph.as_default(): 
    error, metrics = model(...) 

    with tf.Session(): 
    # Start Coordinator 
    # Initialise global vars 
    # Start queue runners 
    # model_error, model_metrics = sess.run([error, metrics]) 
0

解決它(這個錯誤),這不是一個無限循環,它只是等待輸入數據。出於某種原因,如果我將上面的'with tf.Session()as sess:'塊(不帶with部分)粘貼到塊的頂部,它會很好地運行。 (也許有可能,還有一些其他編碼錯誤,因爲自那以後我改變了一些東西。)