以下代碼使用tensorflow庫,與numpy庫相比運行速度非常慢。我知道我正在調用一個函數,它使用python for循環中的tensorflow庫(我將稍後與python多處理進行並行化),但代碼的運行速度非常慢。tensorflow在python for循環內運行速度極慢
有人可以請幫助我如何讓這段代碼運行得更快嗎?謝謝。
from math import *
import numpy as np
import sys
from multiprocessing import Pool
import tensorflow as tf
def Trajectory_Fun(tspan, a, b, session=None, server=None):
# Open tensorflow session
if session==None:
if server==None:
sess = tf.Session()
else:
sess = tf.Session(server.target)
else:
sess = session
B = np.zeros(np.size(tspan), dtype=np.float64)
B[0] = b
for i, t in enumerate(tspan):
r = np.random.rand(1)
if r>a:
c = sess.run(tf.trace(tf.random_normal((4, 4), r, 1.0)))
else:
c = 0.0 # sess.run(tf.trace(tf.random_normal((4, 4), 0.0, 1.0)))
B[i] = c
# Close tensorflow session
if session==None:
sess.close()
return B
def main(argv):
# Parameters
tspan = np.arange(0.0, 1000.0)
a = 0.1
b = 0.0
# Run test program
B = Trajectory_Fun(tspan, a, b, None, None)
print 'Done!'
if __name__ == "__main__":
main(sys.argv[1:])
您正在緩慢地調整session.run調用之間的Graph對象。你可以在第一個'sess.run'前添加所有的操作並調用'tf.get_default_graph()。finalize()' –
@YaroslavBulatov感謝您的快速響應。正如你可能已經注意到的那樣,我需要每個時間步長的變量c的值。請您再澄清一下,我可以如何將您的建議納入我的上述代碼中?我會很感激。謝謝。 – QED
在循環開始之前做'a = tf.random_normal((4,4),0.0,1.0)',然後執行'sess.run(a)' –