2016-06-27 42 views
0

我需要多次執行語句sess.run()。 我在代碼的開頭創建了一次sess。但是,在我的CPU機器上,每個sess.run()語句需要將近0.5-0.8秒。有什麼辦法可以優化這個嗎?由於Tensorflow做延遲加載,有沒有什麼辦法可以讓它不做,並使其更快?爲多次運行優化Tensorflow

我使用圖像分類示例中的Inception模型。

def load_network(): 
    with gfile.FastGFile('model.pb', 'rb') as f: 
     graph_def = tf.GraphDef() 
     data = f.read() 
     graph_def.ParseFromString(data) 
     png_data = tf.placeholder(tf.string, shape=[]) 
     decoded_png = tf.image.decode_png(png_data, channels=3) 
     _ = tf.import_graph_def(graph_def, name=input_map={'DecodeJpeg': decoded_png}) 
     return png_data 

def get_pool3(sess, png_data, imgBuffer): 
    pool3 = sess.graph.get_tensor_by_name('pool_3:0') 
    pool3Vector = sess.run(pool3, {png_data: imgBuffer.getvalue()}) 
    return pool3Vector 

def main(): 
    sess = getTensorSession() 
    png_data = load_network() 

    # The below line needs to be called multiple times, which is what takes 
    # nearly 0.5-0.8 seconds. 
    # imgBuffer contains the stored value of the image. 
    pool3 = get_pool3(sess, png_data, imgBuffer) 
+0

你使用佔位符?使用隊列更高效 –

+0

我添加了一個佔位符,其餘部分是tensorflow給出的初始模型。 – n00b

回答

2

Tensorflow懶惰地運行操作---在調用sess.run()之前實際上沒有任何操作計算。當您調用sess.run()時,Tensorflow會執行您的計算圖中的所有操作。所以如果sess.run()需要0.5-0.8秒,那麼很可能你的計算本身需要0.5-0.8秒。

(有一些開銷sess.run(),但它不應該是接近半秒的順序的任何地方。)

希望幫助!

補充:

這裏有一些東西,你可能會考慮來加速計算起來:

  • 使用Tensorflow的分析工具來看看你的計算的一部分被抽空。他們還沒有記載,但你可以找到關於他們在這個問題GitHub的一些信息:https://github.com/tensorflow/tensorflow/issues/1824

  • 讓你的計算便宜---降低模型的複雜性,使用更小的圖像等

  • 在GPU上運行計算而不是CPU。

+0

嗨,我也加了我的代碼片段,如果你可以請檢查一次! – n00b

+0

是的,我認爲我的回答是正確的。當你調用sess.run()時,這實際上導致圖形直到操作符'pool3'執行。計算本身需要0.5s。 –

+0

對不起,請您詳細說明一下嗎?我應該做什麼改變? – n00b