2017-10-09 51 views
1

我正在實施一個損失函數,該函數將使用由0s and 1s組成的掩模張量(M)來消除一些損失值,給出預測(P)和地面實況(G)張量。哪一個更有效率:tf.where還是element-wise乘法?

所以,我有2種可能的方式:

逐元素乘法:

loss = K.sum(M * K.binary_crossentropy(G, P))

條件選擇:

bin_ce = K.binary_crossentropy(G, P) 
loss = K.sum(tf.where(tf.equal(M, 1), bin_ce, 0)) 

所以,這將是更在運行時間方面效率高嗎?

+1

你自己運行任何基準? –

+0

我正在運行一個基準測試,但尚未完成。我事先徵求您的意見。 – mkocabas

+1

我相當相信乘法情況會更好....等待您的測試結果。我無法想象使用少於2步的第二種情況。 –

回答

1

我做了基準,很明顯乘法比條件選擇好得多。

下面是結果:

A chart is worth a thousand words..

圖表是勝過千言萬語。

基準代碼:

import keras.backend as K 
import tensorflow as tf 
import numpy as np 
import sys 
import time 
import matplotlib.pyplot as plt 


def elm(G, P, M): 
     return K.sum(M * K.binary_crossentropy(G, P)) 

def cond(G, P, M, t): 
     C = K.variable(np.zeros((t, t))) 
     bin_ce = K.binary_crossentropy(G, P) 
     return K.sum(tf.where(tf.equal(M, 1), bin_ce, C)) 


s = [100, 1000, 10000, 100000] 
elms = [] 
conds = [] 

for t in s: 
     print t 
     t = int(t) 
     # number of 1s in mask 
     n = int(t/2) 

     M = np.zeros((t,t)) 
     P = np.random.rand(t, t) 
     G = np.random.rand(t, t) 

     for i in range(n): 
       r = np.random.randint(0, t) 
       c = np.random.randint(0, t) 
       M[r,c] = 1 

     M = K.variable(M) 
     P = K.variable(P) 
     G = K.variable(G) 

     start_time = time.time() 
     elm(G, P, M) 
     elms.append(time.time() - start_time) 

     start_time = time.time() 
     cond(G, P, M, t) 
     conds.append(time.time() - start_time) 

print elms 
print conds 

# create plot 
fig, ax = plt.subplots() 
index = np.arange(n_groups) 
bar_width = 0.35 
opacity = 0.8 

rects1 = plt.bar(index, elms, bar_width, 
       alpha=opacity, 
       color='b', 
       label='Element-wise') 

rects2 = plt.bar(index + bar_width, conds, bar_width, 
       alpha=opacity, 
       color='g', 
       label='Conditional') 

plt.xlabel('Input tensor size') 
plt.ylabel('Execution time (s)') 
plt.title('') 
plt.xticks(index + bar_width, ('100', '10e3', '10e4', '10e5')) 
plt.legend() 

plt.tight_layout() 
plt.show() 
相關問題