2017-07-31 70 views
1

是否可以創建一個input_fn,它可以隨機生成隨機數據,以便與Tensorflow中的Estimator API一起使用?在Tensorflow的input_fn中生成無限的隨機訓練數據

這基本上是我想什麼:

def create_input_fn(function_to_generate_one_sample_with_label): 
    def _input_fn(): 
     ### some code ### 
     return feature_cols, labels 

然後,我會想使用的功能與Estimator情況是這樣的:

def data_generator(): 
    features = ... generate a (random) feature vector ... 
    lablel = ... create suitable label ... 
    return features, labels 

input_fn = create_input_fn(data_generator) 
estimator.train(input_fn=input_fn, steps=ANY_NUMBER_OF_STEPS) 

的一點是一定要能夠爲訓練儘可能多的步驟,即時生成所需的培訓數據。這是爲了模型調整目的,能夠試驗不同複雜度的不同訓練數據,以便我能夠了解模型適合訓練數據的能力。


編輯 作爲JKM建議,我試圖用一個實際的發電機,就像這樣:

def create_input_fn(function, batch_size=100): 
    def create_generator(): 
     while True: 
      features = ... generate <batch_size> feature vectors ... 
      lablel = ... create <batch_size> labels ... 
      yield features, label 
    g = create_generator() 
    def _input_fn(): 
     return next(g) 
    return _input_fn 

我不得不添加一個批量大小,以使之運行。它現在運行,但input_fn只被調用一次,所以它不會生成任何新的數據。它只是訓練生成的第一批<batch_size>樣本。有什麼方法告訴估算人員使用提供的input_fn刷新數據?

回答

0

警告詞 - 我做過而不是與Tensorflow自己一起工作,我只是關閉API的文檔。

這就是說 - 如果在那裏沒有問題,你應該可以做你需要的。只要讓發電機一個,發電機(產量特徵和標籤,而不是返回它們),並使整個一代無限循環。例如:

def data_generator(): 
    while True: 
     #do generatey things here 
     yield feature, labels 

該函數將能夠重複調用,每次每次調用產生一次新的值。

+0

感謝您的意見。它有幫助,但它仍然沒有做我想做的。請參閱編輯的問題:) –

1

我認爲你可以使用TF最近的數據集API所期望的行爲,你需要tensorflow> = 1.2.0

# Define number of samples and input shape for each iteration 
# you can set minval or maxval as per you data distribution and label distributon requirements 
num_samples = [20000,] 
input_shape = [32, 32, 3] 
dataset = tf.contrib.data.Dataset.from_tensor_slices((tf.random_normal([num_examples+input_shape]), tf.random_uniform([num_samples], minval=0, maxval=5))) 
# Define batch_size 
batch_size = 128 
dataset = dataset.batch(batch_size) 
# Define iterator 
iterator = dataset.make_initializable_iterator() 
# Get one batch 
next_example, next_label = iterator.get_next() 
# calculate loss from the estimator fucntion you are using 
estimator_loss = some_estimator(next_example, next_label) 
# Set number of Epochs here 
num_epochs = 100 
for _ in range(num_epochs): 
    sess.run(iterator.initializer) 
    while True: 
     try: 
      _loss = sess.run(estimator_loss) 
     except tf.errors.OutOfRangeError: 
      break 
+0

我認爲每次使用'sess.run(iterator.initializer)'調用迭代器每個'num_samples'步驟可能會有意義,以獲得新的隨機值。 – bodokaiser

0

我可以問:你執行數據增強生成你的數據?如果是這樣,只要使用tensorflow框架中的隨機函數,被調用的函數應該生成無限數量的隨機樣本。 (類似tf.random_uniform而不是來自numpy的相應方法等)。這對我很有用。