2017-08-31 18 views
1

我想在圖中創建一個計數器。只要函數運行,它就會增加。在這裏,我現在正在嘗試。有任何想法嗎?在張量流圖中創建計數器

# init variable 
tf_i = tf.Variable(1, name='v', dtype=tf.int32, expected_shape=()) 

# init assign op 
tf_i_plus_one = tf.assign_add(tf_i, 1) 
sess.run(tf.global_variables_initializer()) 

# simple test fx 
def test(x): 
    v = tf.get_variable('v', shape=()) 
    return x + v 

# run test fx 
z = tf.placeholder(tf.float32, shape=(), name='z') 
out = test(z) 

print(sess.run(out, feed_dict={z: 4})) 
sess.run(tf_i_plus_one) 

print(sess.run(out, feed_dict={z: 4})) 

我現在得到:

Attempting to use uninitialized value v_1 

回答

0

如果你想要的是一個功能的計數器,請嘗試以下操作:

def wrap_with_counter(fn, counter): 
    def wrapped_fn(*args, **kwargs): 
     # control_dependencies forces the assign op to be run even if we don't use the result 
     with tf.control_dependencies([tf.assign_add(counter, 1)]): 
      return fn(*args, **kwargs) 
    return wrapped_fn 

用法示例:

dense_counter = tf.get_variable(
    dtype=tf.int32, shape=(), name='dense_counter', 
    initializer=tf.zeros_initializer()) 
wrapped_dense = wrap_with_counter(tf.layers.dense, dense_counter) 


batch_size = 4 
n_features = 8 
# dummy input 
x = tf.random_normal(
    shape=(batch_size, n_features), dtype=tf.float32, name='x') 

out = wrapped_dense(x, 1) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    c0 = sess.run(dense_counter) 
    print(c0) # 0 
    sess.run(out) 
    c1 = sess.run(dense_counter) 
    print(c1) # 1 
+0

謝謝。超級有用。但是,爲了理解tf如何處理圖中的變量賦值或更新,您對此有何看法? –

+0

作爲一個經驗法則,在實例化會話/執行變量初始化之前完全構建圖形。你得到的錯誤讓我感到驚訝 - 我會期待更多的東西,比如'變量v已經存在,不允許',但是如果你將圖的構造與會話分開,事情會變得更加可預測。 'tf.assign_add'返回表示更新結果的張量,但如果該結果不依賴於在會話中調用的任何內容,將不會運行,因此更新將不會運行。添加'tf.control_dependencies'是解決這個問題的一種方法。 – DomJack