2017-07-19 18 views
1

我正在努力將我的(雜亂)代碼從tensorflow核心傳遞到Estimator範例,特別是使用Experiments - 與learn_runner.run。但我實際上有問題將數據提供給我的神經網絡。TensorFlow實驗:如何避免用input_fn加載內存中的所有數據?

我想要實現的東西其實非常接近TensorFlow和tf.TextLineReader的所有例子,例如, https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/census/customestimator/trainer/model.py#L297,雖然我不是從磁盤上的文件加載數據,而是使用網絡服務。

從我的理解(並在看代碼tensorflow.python.estimator._train_model()input_fn只被調用一次,而不是在每次迭代。我可以輕鬆地將我的所有數據,然後做一些事情,如:

def input_fn(): 
    data = # all data in memory 
    batch = tf.train.input_producer(tf.constant(data)) 
    return batch.dequeue_many(batch_size) 

但我的數據不會裝入內存,這是不可持續的。我試圖做類似的事情:

1. load first piece of data (say N lines) 
2. consume it by batches in a queue just like the input_fn above 
2'. feed this queue asynchronously with new data when it's almost empty 

我知道如何在「純」的tf, How to prefetch data using a custom python function in tensorflowTensorflow: custom data load + asynchronous computation,但我發現很難將其轉置爲Experiment範例,因爲我無法訪問會話來自行加載內容,也無法在內部添加操作。

編輯

我設法使用tf.py_func(),像這樣做:

class Reader(object): 
    # a Python object that can load data and have some intelligence, not related to TF, initialized with batch_sized 

    def read_up_to(self): 
     """Reads up to batch_size elements loaded in Python""" 

def input_fn(): 
    reader = Reader() # instantiated once 
    return tf.py_func(reader.read_up_to, inp=[], Tout=...) 

我工作得很好,雖然這是一個有點慢(如預期有一個解決方法從C++執行到Python導致約50%的延遲)。我試圖通過在讀取器中異步讀入特定的TensorFlow隊列來解決這個問題,這樣就可以在不將數據從Python傳遞到C++的情況下完成加載(就像上面的兩個鏈接一樣)。

回答

相關問題