2016-10-13 73 views
1

我想用map_fn替換for循環,因爲後者似乎有助於提高循環效率。在TensorFlow中,如果在map_fn中調用get_variable,會發生什麼共享變量

問題是,如果map_fn中的fn調用get_variable()來創建一個新變量,那麼如何在循環的其餘部分將重用設置爲True?或者get_variable()只在map_fn中調用過一次?

def fn(x): 
    y = tf.get_variable('y', []) 
    return x * x 

squares = tf.map_fn(fn, np.array([1, 2, 3, 4 ,5 ,6])) 

# Out: [array([ 1, 4, 9, 16, 25, 36])] 
sess.run([squares]) 
+0

'tf.all_variables()'只顯示一個變量,因此它看起來只會被調用一次 –

+0

謝謝@YaroslavBulatov!我想你是對的。 – user2309694

回答

0
In [2]: def fn(x):           
      y = tf.get_variable('y', []) 
      print(y.name) 
      return x * x 
In [4]: import numpy as np 
In [5]: squares = tf.map_fn(fn, np.array([1, 2, 3, 4 ,5 ,6])) 
y:0 

正如我們可以看到,如果打印插入FN,當它被稱爲只會打印一次。

相關問題