2017-11-25 85 views
2

最近我想實現GAN模型,並使用tf.Dataset和Iterator來讀取人臉圖像作爲訓練數據。如何在tensorflow中運行Iteration.get_next()後獲得批處理批處理?

數據集和迭代器對象的代碼是:

self.dataset = tf.data.Dataset.from_tensor_slices(convert_to_tensor(self.data_ob.train_data_list, dtype=tf.string)) 
self.dataset = self.dataset.map(self._parse_function) 
#self.dataset = self.dataset.shuffle(buffer_size=10000) 
self.dataset = self.dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size)) 

self.iterator = tf.data.Iterator.from_structure(self.dataset.output_types, self.dataset.output_shapes) 
self.next_x = self.iterator.get_next() 

我的新GAN模式是:

self.z_mean, self.z_sigm = self.Encode(self.next_x) 
self.z_x = tf.add(self.z_mean, tf.sqrt(tf.exp(self.z_sigm))*self.ep) 
self.x_tilde = self.generate(self.z_x, reuse=False) 
#the feature 
self.l_x_tilde, self.De_pro_tilde = self.discriminate(self.x_tilde) 

#for Gan generator 
self.x_p = self.generate(self.zp, reuse=True) 
# the loss of dis network 
self.l_x, self.D_pro_logits = self.discriminate(self.next_x, True) 

那麼,問題是,我用的是self.next_x作爲輸入兩次張。每次的數據集都是不同的。那麼,如何解決這個問題以保持第一批重用呢?

回答

1

我在代碼中使用什麼是以下,其中Xy_true是佔位符。不確定是否有更有效的實施。

images, labels = session.run(next_element) 
batch_accuracy = session.run(accuracy, feed_dict={x: images, y_true: labels, keep_prop: 1.0}) 
batch_predicted_probabilities = session.run(y_pred, feed_dict={x: images, y_true: labels, keep_prop: 1.0}) 

我目前正在嘗試使用tf.placeholder_with_default,而不是x和y_true正常佔位符的檢查,如果它給我的項目更好的性能。如果我設法儘快得到任何結果,請編輯我的答案讓你知道。

編輯: 我切換到placeholder_with_default並沒有給出明顯的速度提升每批至少在我測量的方式。

+0

謝謝,我認爲這個工具是正確的。 – zhangqianhui

+0

隨時。請執行它,如果它工作並沒有給出更好的答案,請將我的答案標記爲正確! –

+0

你的情況對你有好處或壞處? 我切換到了placeholder_with_default,並且每批次的速度都沒有明顯提高,至少在我測量它的方式上。 –