我創建了一個tfrecords
文件,我通過tf.TFRecordReader
讀取了這個文件,它非常適合培訓網絡。但是,我不知道如何動態地減少批量進行生產,也沒有如何餵養,並覆蓋有tf.train.import_meta_graph
如何使依賴於tf.train.shuffle_batch準備生產的網絡
reader = tf.TFRecordReader()
data = tf.train.shuffle_batch(...)
# batch_size 100
IS_TRAINING = tf.placeholder(tf.bool, shape=(), name="is_training")
# tried constant, variable and placeholder with no luck
custom_data = tf.Variable(...)
_data = tf.cond(
IS_TRAINING,
lambda: data,
lambda: custom_data,
name="condition"
)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord, sess=sess)
# network graph
coord.request_stop()
coord.join(threads)
sess.close()
我試着加載圖形與tf.train.import_meta_graph
導入培訓的圖形時,一些變量,使用feed_dict
,嘗試覆蓋IS_TRAINING
,所以圖表使用的數據也是我通過feed_dict
提供的。但迄今爲止沒有任何工作。
例如
sess.run([variable], feed_dict={IS_TRAINING:False, custom_data:data})
導入元圖只是爲您提供了模型的圖形定義(類似於解析和導入PB文件的效果),以及其他一些東西,比如您的保存。我相信meta_graph的導入與你的tfrecord管道相當不同。您可以考慮使用數據集類來讀取您的TFRecord文件,因爲它們將並行讀取您的數據並默認進行隨機播放,從而爲您提供更清晰的代碼。你可以考慮構建一個簡單地用「真」或「假」來改變模型作爲評估或訓練模型的模型。 – kwotsin