我需要多次執行語句sess.run()
。 我在代碼的開頭創建了一次sess
。但是,在我的CPU機器上,每個sess.run()
語句需要將近0.5-0.8秒。有什麼辦法可以優化這個嗎?由於Tensorflow做延遲加載,有沒有什麼辦法可以讓它不做,並使其更快?爲多次運行優化Tensorflow
我使用圖像分類示例中的Inception模型。
def load_network():
with gfile.FastGFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
data = f.read()
graph_def.ParseFromString(data)
png_data = tf.placeholder(tf.string, shape=[])
decoded_png = tf.image.decode_png(png_data, channels=3)
_ = tf.import_graph_def(graph_def, name=input_map={'DecodeJpeg': decoded_png})
return png_data
def get_pool3(sess, png_data, imgBuffer):
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
pool3Vector = sess.run(pool3, {png_data: imgBuffer.getvalue()})
return pool3Vector
def main():
sess = getTensorSession()
png_data = load_network()
# The below line needs to be called multiple times, which is what takes
# nearly 0.5-0.8 seconds.
# imgBuffer contains the stored value of the image.
pool3 = get_pool3(sess, png_data, imgBuffer)
你使用佔位符?使用隊列更高效 –
我添加了一個佔位符,其餘部分是tensorflow給出的初始模型。 – n00b