2016-10-25 142 views
0

以下代碼使用tensorflow庫,與numpy庫相比運行速度非常慢。我知道我正在調用一個函數,它使用python for循環中的tensorflow庫(我將稍後與python多處理進行並行化),但代碼的運行速度非常慢。tensorflow在python for循環內運行速度極慢

有人可以請幫助我如何讓這段代碼運行得更快嗎?謝謝。


from math import * 
import numpy as np 
import sys 
from multiprocessing import Pool 
import tensorflow as tf 

def Trajectory_Fun(tspan, a, b, session=None, server=None): 
    # Open tensorflow session 
    if session==None: 
     if server==None: 
      sess = tf.Session() 
     else: 
      sess = tf.Session(server.target)  
    else: 
     sess = session 
    B = np.zeros(np.size(tspan), dtype=np.float64) 
    B[0] = b 
    for i, t in enumerate(tspan): 
     r = np.random.rand(1) 
     if r>a: 
      c = sess.run(tf.trace(tf.random_normal((4, 4), r, 1.0))) 
     else: 
      c = 0.0 # sess.run(tf.trace(tf.random_normal((4, 4), 0.0, 1.0))) 
     B[i] = c 
    # Close tensorflow session 
    if session==None: 
     sess.close() 
    return B 

def main(argv): 
    # Parameters 
    tspan = np.arange(0.0, 1000.0) 
    a = 0.1 
    b = 0.0 
    # Run test program 
    B = Trajectory_Fun(tspan, a, b, None, None) 
    print 'Done!' 

if __name__ == "__main__": 
    main(sys.argv[1:]) 
+0

您正在緩慢地調整session.run調用之間的Graph對象。你可以在第一個'sess.run'前添加所有的操作並調用'tf.get_default_graph()。finalize()' –

+0

@YaroslavBulatov感謝您的快速響應。正如你可能已經注意到的那樣,我需要每個時間步長的變量c的值。請您再澄清一下,我可以如何將您的建議納入我的上述代碼中?我會很感激。謝謝。 – QED

+0

在循環開始之前做'a = tf.random_normal((4,4),0.0,1.0)',然後執行'sess.run(a)' –

回答

2

正如你的問題說,因爲它創造了每運行幾個新TensorFlow圖節點這一計劃將給予表現不佳。 TensorFlow中的基本假設是(大約)您將構建一次圖形,然後多次調用sess.run()(的各個部分)。您第一次運行圖形相對昂貴,因爲TensorFlow必須構建各種數據結構並優化跨多個設備的圖形執行。然而,TensorFlow緩存了這項工作,所以後續使用便宜得多。

通過構建一次圖並使用(例如)tf.placeholder() op來提供每次迭代中更改的值,可以使該程序快得多。例如,下面應該做的伎倆:

B = np.zeros(np.size(tspan), dtype=np.float64) 
B[0] = b 

# Define the TensorFlow graph once and reuse it in each iteration of the for loop. 
r_placeholder = tf.placeholder(tf.float32, shape=[]) 
out_t = tf.trace(tf.random_normal((4, 4), r_placeholder, 1.0)) 

with tf.Session() as sess: 
    for i, t in enumerate(tspan): 
    r = np.random.rand(1) 
    if r > a: 
     c = sess.run(out_t, feed_dict={r_placeholder: r}) 
    else: 
     c = 0.0 
    B[i] = c 
    return B 

你可能使這更有效的利用TensorFlow循環和sess.run()使得更少的調用,但總的原則是一樣的:重複使用相同的圖形多次獲得TensorFlow的好處。