2016-11-14 79 views
2

蟒蛇太慢了,我想在創建Tensorflow,函數,對於給定的數據X的每一行,正在申請的SOFTMAX功能僅適用於部分採樣的班,可以說2 ,在K個總類中,返回一個矩陣S,其中S.shape = (N,K)(N:給定數據的行數和K總類數)。Tensorflow是在for循環

矩陣S最終將包含零和由樣本類爲每一行定義的索引中的非零值。

在簡單的蟒蛇我使用高級索引,但在Tensorflow我無法弄清楚如何做到這一點。我最初的問題是this, where I present the numpy code

所以我試圖找到一個解決方案Tensorflow和主要想法是不使用S作爲二維矩陣,但作爲一維數組。代碼如下所示:

num_samps = 2 
S = tf.Variable(tf.zeros(shape=(N*K))) 
W = tf.Variable(tf.random_uniform((K,D))) 
tfx = tf.placeholder(tf.float32,shape=(None,D)) 
sampled_ind = tf.random_uniform(dtype=tf.int32, minval=0, maxval=K-1, shape=[num_samps]) 
ar_to_sof = tf.matmul(tfx,tf.gather(W,sampled_ind),transpose_b=True) 
updates = tf.reshape(tf.nn.softmax(ar_to_sof),shape=(num_samps,)) 
init = tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init) 
for line in range(N): 
    inds_new = sampled_ind + line*K 
    sess.run(tf.scatter_update(S,inds_new,updates), feed_dict={tfx: X[line:line+1]}) 

S = tf.reshape(S,shape=(N,K)) 

這是行得通的,結果是預期的。但它運行的非常緩慢。爲什麼會發生?我怎樣才能更快地完成這項工作?

回答

6

在張量流程中編程時,瞭解定義操作和執行它們之間的區別是至關重要的。大部分以tf.開頭的函數,當您在python 中運行時,將對計算圖添加操作。

例如,當你這樣做:

tf.scatter_update(S,inds_new,updates) 

還有:

inds_new = sampled_ind + line*K 

多次,你的計算圖形增長超出所需,填補所有的內存和巨大的速度變慢。

你應該做的,而不是對定義計算一次,在循環之前:

init = tf.initialize_all_variables() 
inds_new = sampled_ind + line*K 
update_op = tf.scatter_update(S, inds_new, updates) 
sess = tf.Session() 
sess.run(init) 
for line in range(N): 
    sess.run(update_op, feed_dict={tfx: X[line:line+1]}) 

這樣,你的計算圖只包含一個inds_newupdate_op的副本。請注意,執行update_op時,inds_new也會隱式執行,因爲它是計算圖中的父項。

你也應該知道,update_op可能會有不同的結果,每次運行時,它是罰款和預期。

順便說一句,調試這種問題的一個好方法是使用張量板來可視化計算圖。在代碼中添加:

summary_writer = tf.train.SummaryWriter('some_logdir', sess.graph_def) 

,然後在控制檯中運行:

tensorboard --logdir=some_logdir 

所服務的html頁面上會有計算圖,在那裏你可以檢查你的張量的圖片。

+0

非常感謝,這是工作和答案完全是我的問題!但問題仍然存在,創建矩陣S的numpy代碼仍然比這更快。而我只用張力低的功能......你知道爲什麼會發生這種情況嗎?我應該用C++創建一個新的操作來獲得加速嗎? –

+0

加快你的意思是提高20%,或者快20倍? CPU上的張量流速慢20%是預期的行爲。你有一個好的,支持CUDA的GPU(和使用它的張量流安裝)? Tensorflow適用於使用GPU /集羣的情況。 – sygi

0

請記住,tf.scatter_update將返回張量S,這意味着會話運行中的大內存副本,甚至分佈式環境中的網絡副本。解決的辦法是,根據@ sygi的回答是:

update_op = tf.scatter_update(S, inds_new, updates) 
update_op_op = update_op.op 

然後在會話中運行,你這樣做

sess.run(update_op_op) 

這將避免複製大張量S.