我正在努力將我的(雜亂)代碼從tensorflow核心傳遞到Estimator
範例,特別是使用Experiments
- 與learn_runner.run
。但我實際上有問題將數據提供給我的神經網絡。TensorFlow實驗:如何避免用input_fn加載內存中的所有數據?
我想要實現的東西其實非常接近TensorFlow和tf.TextLineReader
的所有例子,例如, https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/census/customestimator/trainer/model.py#L297,雖然我不是從磁盤上的文件加載數據,而是使用網絡服務。
從我的理解(並在看代碼tensorflow.python.estimator._train_model()
)input_fn
只被調用一次,而不是在每次迭代。我可以輕鬆地將我的所有數據,然後做一些事情,如:
def input_fn():
data = # all data in memory
batch = tf.train.input_producer(tf.constant(data))
return batch.dequeue_many(batch_size)
但我的數據不會裝入內存,這是不可持續的。我試圖做類似的事情:
1. load first piece of data (say N lines)
2. consume it by batches in a queue just like the input_fn above
2'. feed this queue asynchronously with new data when it's almost empty
我知道如何在「純」的tf, How to prefetch data using a custom python function in tensorflow或Tensorflow: custom data load + asynchronous computation,但我發現很難將其轉置爲Experiment
範例,因爲我無法訪問會話來自行加載內容,也無法在內部添加操作。
編輯
我設法使用tf.py_func()
,像這樣做:
class Reader(object):
# a Python object that can load data and have some intelligence, not related to TF, initialized with batch_sized
def read_up_to(self):
"""Reads up to batch_size elements loaded in Python"""
def input_fn():
reader = Reader() # instantiated once
return tf.py_func(reader.read_up_to, inp=[], Tout=...)
我工作得很好,雖然這是一個有點慢(如預期有一個解決方法從C++執行到Python導致約50%的延遲)。我試圖通過在讀取器中異步讀入特定的TensorFlow隊列來解決這個問題,這樣就可以在不將數據從Python傳遞到C++的情況下完成加載(就像上面的兩個鏈接一樣)。