2016-03-24 35 views

Tensorflow: stopping threads through the coordinator doesn't seem to work?

Following the standard prefetching queue, I extended the example above with some additional validation code, see the attached code. That is, every i-th training step the learned model is evaluated on a validation set (several of them, in my case). The validation sets cannot be fed through the queue, so one possible idea is to build an additional inference graph that shares the variables.

This somehow works, but after training is done the program hangs (at coord.join()) and eventually throws an exception: Coordinator stopped with threads still running: ..., and then the asynchronous loading thread also raises an exception. The coordinator exception can be handled with a try/except clause (see the code below), but the async thread still throws an exception (it does not hamper the main program, but in my opinion it should not happen: the thread has a while loop that should tell it to stop).

Interestingly, if training finishes without any of the evaluation code having run (i.e. with the code block after if (it+1)%ith == 0: commented out), then coord.join() does not hang at all.

My question: what am I doing wrong here? It seems as if .request_stop() is not doing what I hope it should do?

import threading

import numpy as np
import tensorflow as tf

# some parameters 
btsz = 100 # batch size 
some_shape = 20 # size of one input (no of dims) 
iters = 1000 # that many single training steps 
ith = 10 # run validation sets every so often 
# datastores (sort of complex backends, SQL like) 
ds_train = ... # the one for training 
ds_val1, ds_val2, ds_val3 = ... # having the validation data 

def async_load(coord, session, queue, datastore, 
       tf_input, tf_target): 
    """ 
    Feed queue in async way. Inputs can be extracted 
    from datastore only one row at a time. 
    """ 
    while not coord.should_stop(): 
     input = extract_one_input_as_numpy(datastore) 
     target = extract_numpy_from(datastore) # either 0 or 1 
     session.run(queue, feed_dict={tf_input: input, tf_target: target}) 

def evaluate(sess, datastore, tf_input, tf_target, tf_loss, btsz): 
    """ 
    Evaluate current model (represented as tf_loss) on a datastore. 
    """ 
    loss = [] 
    for i in xrange(something): 
     input_batch = collect_btsz_many_single_examples(datastore) 
     target_batch = same_for_targets(datastore) 
     tmp, = sess.run([tf_loss], feed_dict={tf_input:input_batch, tf_target:target_batch}) 
     loss.append(tmp) 
    return np.mean(loss) 

def log_reg(input, target, W, b): 
    """ 
    Simple logistic regression model. 
    """ 
    y = tf.matmul(input, W) + b 
    y_bin = tf.to_int32(y > 0) 

    t_bin = tf.to_int32(target > 0) 

    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y, targets=target)) 
    correct_prediction = tf.equal(y_bin, t_bin) 
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 
    return y, loss, accuracy 

with tf.Session() as sess: 
    # Placeholders to represent one input/target pair from a data store. 
    ds_inpt = tf.placeholder(tf.float32, shape=[some_shape]) 
    ds_trgt = tf.placeholder(tf.float32, shape=[]) 

    queue = tf.FIFOQueue(capacity=10000, dtypes=[tf.float32, tf.float32], 
        shapes=[[], [some_shape]], shared_name="FIFO", name="FIFO") 

    # enqueuing, this will be used in the async loading. 
    enqueue_op = queue.enqueue([ds_trgt, ds_inpt]) 

    # dequeue from queue q, with batch size btsz 
    q_trgt, q_inpt = queue.dequeue_many(btsz) 

    # Paramters for Logistic Regression 
    # two functions that build shared variables and initialize these 
    W = weight_variable([some_shape, 1]) 
    b = bias_variable([1]) 

    # training model, feed from dequeuing the async queue 
    y, loss, accuracy = log_reg(input=q_inpt, target=q_trgt, W=W, b=b) 

    train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss) 

    # inputs for validation models 
    val_inpt = tf.placeholder(tf.float32, shape=[btsz, some_shape]) 
    val_trgt = tf.placeholder(tf.float32, shape=[btsz]) 
    # validation model 
    val_y, val_loss, val_accuracy = log_reg(input=val_inpt, target=val_trgt, W=W, b=b) 

    sess.run(tf.initialize_all_variables()) 
    try: 
     coord = tf.train.Coordinator() 
     # Start a thread to enqueue data asynchronously, and hide I/O latency. 
     t = threading.Thread(target=async_load, 
           args=(coord, sess, enqueue_op, ds_train, 
            ds_inpt, ds_trgt)) 
     t.start() 

     # collect loss/accuracy for training 
     # and losses for validation/test sets. 
     tr_loss = [] 
     tr_acc = [] 
     v_loss = [] 

     for it in xrange(iters): 
      _, _loss, _acc = sess.run([train_step, loss, accuracy]) 
      tr_loss.append(_loss) 
      tr_acc.append(_acc) 
      if (it+1)%ith == 0: 
       # run trained model on validation set 1 
       tmp = evaluate(sess=sess, datastore=ds_val1, 
           tf_input=val_inpt, tf_target=val_trgt, 
           tf_loss=val_loss, btsz=btsz) 
       v_loss.append(tmp) 
       # run trained model on validation set 2 
       tmp = evaluate(sess=sess, datastore=ds_val2, 
           tf_input=val_inpt, tf_target=val_trgt, 
           tf_loss=val_loss, btsz=btsz) 
       v_loss.append(tmp) 
       # run trained model on validation set 3 
       tmp = evaluate(sess=sess, datastore=ds_val3, 
           tf_input=val_inpt, tf_target=val_trgt, 
           tf_loss=val_loss, btsz=btsz) 
       v_loss.append(tmp) 
     coord.request_stop() 
     coord.join([t]) 
    except RuntimeError as rte: 
     print("Caught {}".format(rte)) 
# Clear everything! 
tf.reset_default_graph() 

Answer


There is a race condition in your code. The thread running async_load() will block forever if the following sequence of events happens:

  1. async_load() calls coord.should_stop(), which returns False.
  2. async_load() calls session.run(queue, ...), but the queue is full, so the call blocks indefinitely.
  3. The main thread calls coord.request_stop().
  4. The main thread calls coord.join([t]), which blocks forever because of (2).

One way to avoid this is to create a queue.close(cancel_pending_enqueues=True) op, and run it in the main thread before calling coord.request_stop(). This will unblock the async_load() thread, and enable coord.join([t]) to return.
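The race can be reproduced without TensorFlow at all, using only the standard library. The sketch below uses hypothetical names, with a bounded `queue.Queue` standing in for the `FIFOQueue` and a `threading.Event` for the Coordinator. It shows why a producer blocked on a full queue never observes the stop request, and how making the enqueue cancellable (which is what `queue.close(cancel_pending_enqueues=True)` achieves in TensorFlow) lets the thread exit:

```python
import queue
import threading
import time

def producer(q, stop, cancellable):
    """Stands in for async_load(): keeps enqueuing until told to stop."""
    while not stop.is_set():
        try:
            if cancellable:
                # Analogue of a cancellable enqueue: the put can time out,
                # so the loop gets a chance to re-check the stop flag.
                q.put(0, timeout=0.1)
            else:
                # Analogue of the buggy code: once the queue is full, this
                # blocks forever and never re-checks the stop flag.
                q.put(0)
        except queue.Full:
            pass  # put timed out; loop around and check stop again

def demo(cancellable):
    q = queue.Queue(maxsize=2)  # tiny capacity so it fills right away
    stop = threading.Event()    # stands in for the Coordinator
    t = threading.Thread(target=producer, args=(q, stop, cancellable),
                         daemon=True)
    t.start()
    while not q.full():         # wait until the producer has filled the queue
        time.sleep(0.01)
    time.sleep(0.2)             # by now the producer is blocked in put()
    stop.set()                  # analogue of coord.request_stop()
    t.join(timeout=1.0)         # analogue of coord.join([t])
    return not t.is_alive()     # True iff the producer actually stopped

print(demo(True))   # cancellable put: thread stops -> True
print(demo(False))  # plain blocking put: join times out -> False
```

In the cancellable case the blocked `put` eventually raises `queue.Full`, the loop re-checks the stop flag, and the thread exits; in the blocking case the thread is stuck in `put` and `join` can only time out, which is exactly the hang described in the question.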


This causes a `CancelledError` (or `AbortedError`) in the loading thread, which is expected and simply needs to be caught, correct? – osdf


Not strictly related to the original question, but related to the code: could `q_trgt`, `q_inpt` be fed directly to the evaluation part? That is, if they are supplied in the feed dict, the queue would not be polled, correct? This would avoid the extra `val_trgt`, `val_inpt` placeholders. – osdf


Thank you for the answer! :) – osdf