2017-05-03 55 views
2

大多數教程重點放在整個訓練數據集裝入內存的情況下input_fn`。但是,我有一個迭代器,它可以作爲(特徵,標籤)的無限流--tuples(在運行中便宜地創建它們)。創建`從iterator

在實現input_fn爲tensorflows estimator,我可以從迭代器返回實例作爲

def input_fn(): 
    (feature_batch, label_batch) = next(it) 
    return tf.constant(feature_batch), tf.constant(label_batch) 

或不input_fn必須返回相同的(功能,標籤)元組在每次調用?

而且是這個函數在訓練中多次調用,因爲我希望它像以下僞:

for i in range(max_iter): 
    learn_op(input_fn()) 

回答

2

input_fn的參數在整個訓練使用,但功能本身被調用一次。因此,創建一個複雜的input_fn不僅僅如tutorial中所解釋的那樣返回一個常量數組並不那麼簡單。

Tensorflow提出numpypanda陣列,非平凡input_fn的兩個例子,但他們從內存中的數組開始,所以這並不能幫助你解決問題。

你也可以通過上面的鏈接看看他們的代碼,看看他們如何實現一個有效的非平凡的input_fn,但你可能會發現它需要更多的代碼,你想要的。

如果你願意用Tensorflow的高層次少接口,事情恕我直言,更簡單,更靈活。有一個覆蓋大多數需求的tutorial,所提出的解決方案很容易實現。

特別是,如果你已經有了,你在你的問題中所述,(在以前的鏈接部分「餵養」)使用佔位符,返回數據的迭代器應該是簡單的。

+1

我會想到供給從一個迭代的網絡/迭代是標準的用例,沒有例外。 –

2

我發現,其中轉換的generatorinput_fn pull請求: https://github.com/tensorflow/tensorflow/pull/7045/files

相關部分是

def _generator_input_fn(): 
    """generator input function.""" 
    queue = feeding_functions.enqueue_data(
     x, 
     queue_capacity, 
     shuffle=shuffle, 
     num_threads=num_threads, 
     enqueue_size=batch_size, 
     num_epochs=num_epochs) 

    features = (queue.dequeue_many(batch_size) if num_epochs is None 
       else queue.dequeue_up_to(batch_size)) 
    if not isinstance(features, list): 
     features = [features] 
    features = dict(zip(input_keys, features)) 
    if target_key is not None: 
     if len(target_key) > 1: 
     target = {key: features.pop(key) for key in target_key} 
     else: 
     target = features.pop(target_key[0]) 
     return features, target 
    return features 
    return _generator_input_fn