2016-09-17 30 views
0

調用predict功能需要的內存這是不是在我的GPU提供10GB:如何創建一個利用batch_size的輸入函數?

estimator = tf.contrib.learn.Estimator(model_fn=model_fn, model_dir=model_dir) 
probs = estimator.predict(input_fn=lambda: my_input_fn(valid_records)) 

predict功能有batch_size參數使用input_fn時不可用。看來,我有兩個選擇(讓我知道如果有另一種):

  1. 更換input_fnx參數,然後利用batch_size PARAM。目前,我不知道該怎麼做!
  2. 修改我的輸入函數以不同批次返回數據。我不知道該怎麼做!

回答

0

利用tf.train.batch函數批量傳遞數據。將input_fn函數傳遞給預測方法時,請勿忘記設置as_iterable=True