2017-10-13 32 views
1

我跟隨this tutorial關於如何使用tf.scan,我寫了一個最小的工作示例(請參閱下面的代碼)。但是每次調用函數Model._step()時,是不是會創建計算圖的另一個副本?如果不是,爲什麼不呢?這個函數每次創建一個新的TensorFlow圖表嗎?

import tensorflow as tf 
import os 
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' # to avoid TF suggesting SSE4.2, AVX etc... 

class Model(): 
    def __init__(self): 
     self._inputs = tf.placeholder(shape=[None], dtype=tf.float32) 
     self._predictions = self._compute_predictions() 

    def _step(self, old_state, new_input): 
     # ---- In here I will write a much more complex graph ---- 
     return old_state + new_input 

    def _compute_predictions(self): 
     return tf.scan(self._step, self._inputs, initializer = tf.Variable(0.0)) 

    @property 
    def predictions(self): 
     return self._predictions 

    @property 
    def inputs(self): 
     return self._inputs 

def test(sess, model): 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run(model.predictions, {model.inputs: [1.0, 2.0, 3.0, 4.0]})) 

test(tf.Session(), Model()) 

我在問,因爲這當然是一個最小的例子,在我的情況下,我需要一個更復雜的圖。

回答

1

Model._step()方法只會被調用一次,每構建一個對象Modeltf.scan()函數,就像它包裝的tf.while_loop()函數一樣,只會調用它的給定函數一次來構建一個帶有循環的圖形,然後每個迭代循環都會使用同一個圖形。

(請注意,如果您構建多Model對象,你將最終獲得相同數量的圖的副本,你有Model對象。)

+0

謝謝。我以爲'tf.scan()'會在'Model.inputs'的每個元素上調用'Model._step()'。 – Ziofil

相關問題