2016-11-25 29 views
1

描述here加載一些訓練圖像分批,即,基本上是這樣的:培訓VS測試與我使用的是設置隊列

def read_my_file_format(filename_queue): 
    # ... use a reader + a decoder 

def input_pipeline(filenames, batch_size, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(...) 
    example, label = read_my_file_format(filename_queue) 
    example_batch, label_batch = tf.train.shuffle_batch(
     [example, label], batch_size=batch_size, ...) 
    return example_batch, label_batch 

def build_net(): 
    batch, label = input_pipeline(...) 
    y = encoder(batch) # <- build network using the batch 

def train(): 
    with tf.Session() as sess: 
    # ... init vars 

    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    try: 
     while not coord.should_stop(): 
     # ... training step 

    except tf.errors.OutOfRangeError: 
     print('Done training -- epoch limit reached') 
    finally: 
     coord.request_stop() 

    coord.join(threads) 
    sess.close() 

這就是很好的訓練 - 但是,我怎麼沒看到我可以測試最終的網絡!什麼使我困惑:

  • input_pipeline返回的張量是網絡的一部分。爲了測試,我將不得不更換它?
  • 我想我可以創建另一個input_pipeline進行測試,即使用不同的文件名隊列。然後我可以使用tf.cond在不同的輸入批次之間切換,但是:如何確保一次只有一個隊列耗盡。我看不到如何訪問不同的隊列以及如何指定它們如何卸載。

基本上,這個問題歸結爲:什麼是測試網絡的規範方式使用tf.train.shuffle_batch方法構建。

回答

1

你是絕對正確的軌道創造了附加的輸入管道的想法上數據集評估。使用multiple input pipelines是推薦的方法之一,其將由兩個過程組成 - 一方面是訓練,另一方面是評估。檢查點將在訓練過程中使用,然後每千步驟,代碼可以嘗試針對訓練數據集和測試數據集兩者的模型eval

從文檔報價:

  • 訓練過程訓練讀取輸入數據,並定期與所有訓練的變量寫檢查點文件。
  • 評估過程將檢查點文件恢復爲讀取驗證輸入數據的推理模型。

即使在培訓完成/退出後也可以進行評估。 (see this example

另一個考慮是通過sharing variables train和eval可以在同一個過程中在同一個圖中操作,同時分享他們訓練過的變量!

關於您擁有的隊列耗盡問題,如果您使用tf.train.shuffle_batch*將num_threads設置爲大於1,它將同時從單個文件讀取(+比使用1個線程更快),而不是同時讀取N個文件(請參閱關於batching的部分)。

+0

聽起來不錯,我現在仔細看看這個 – fabian789

1

我的想法是使用一個字符串佔位符,即,假設你有多個輸入文件:

filenames_place = tf.placeholder(tf.string, shape=[None]) 
num_epochs_place = tf.placeholder(tf.int32) 
example_batch, label_batch = input_pipeline(filenames_place, batch_size, num_epochs_place) 
... 
try: 
    sess.run(train_op, feed_dict={filenames_place: ["train_data1", "train_data2"], num_epochs_place=5}) 

except tf.errors.OutOfRangeError: 
    print('Done training -- epoch limit reached') 

sess.run(eval_op, feed_dict={filenames_place: ["test_data"], num_epochs_place=1}) 
+0

這實際上工作嗎?我覺得'string_input_producer'創建後無法更改文件名,但不確定 – fabian789