
Compute a subgraph only once in TensorFlow

I am building a deep learning model with TensorFlow. Before training, I do some computation, such as a back-propagation-like step, but it only needs to be computed once. Here is my pseudocode:

class residual_net():
    def pseudo_bp(self):
        # do something...
        self.bp = ...

    def build_net(self):
        # build a residual network...
        # uses the value computed in pseudo_bp
        self.output = func(self.bp)

def run():
    rn = residual_net()
    rn.pseudo_bp()
    rn.build_net()
    sess = tf.InteractiveSession()
    sess.run(tf.initialize_all_variables())
    for i in range(1000):
        err = tf.reduce_mean(rn.output, labels)
        train = tf.train.GradientDescentOptimizer(learning_rate).minimize(err)
        sess.run(train, feed_dict=train_feed_dict)

I am not sure whether pseudo_bp will run on every iteration. If so, how can I make it run only once? Thanks in advance!

EDIT: the latest error:

Traceback (most recent call last):
  File "run.py", line 124, in <module>
    sess.run(pseudo_bp, feed_dict=feed_dict)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 717, in run
    run_metadata_ptr)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 902, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 358, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/Users/yobichi/bigdata/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 178, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>

Do you have any ideas?

Answer


In TensorFlow, you start by constructing a tf.Graph. The graph consists of variables, operations, and placeholders. You then start a tf.Session(), inside which you execute operations and update variables.

In your case, I assume pseudo_bp ultimately requires you to compute some operation (such as a tf.matmul). sess acts like a handle: a tf.Operation is only executed when you call sess.run(op), and you supply inputs for the placeholders through feed_dict.
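
As a minimal sketch of that graph/session separation (not part of the original answer; the tensors x, w, and y below are invented for illustration and use the same TF 1.x API as the rest of this thread):

import tensorflow as tf

# Build the graph first: variables, operations, and placeholders.
x = tf.placeholder(tf.float32, shape=[None, 2])    # placeholder, filled via feed_dict
w = tf.Variable(tf.zeros([2, 1]))                  # variable, updated inside a session
y = tf.matmul(x, w)                                # operation; nothing is computed yet

# Only sess.run() actually executes ops in the graph.
sess = tf.Session()
sess.run(tf.initialize_all_variables())
result = sess.run(y, feed_dict={x: [[1.0, 2.0]]})  # y is evaluated here, exactly once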

Therefore, you would execute sess.run(op) only in the first iteration of the for loop. Here is the resulting code:

class residual_net():
    def pseudo_bp(self):
        # do something...
        return op

    def build_net(self):
        # build a residual network...
        self.output = sth

def run():
    rn = residual_net()
    operation = rn.pseudo_bp()
    rn.build_net()
    err = tf.reduce_mean(rn.output, labels)
    train = tf.train.GradientDescentOptimizer(learning_rate).minimize(err)
    # Graph has been built completely. Begin tf.Session()
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    for i in range(1000):
        # Carry out the training in each iteration
        # Note that train is an operation here
        sess.run(train, feed_dict=feed_dict)
        if i == 0:
            # Execute `operation` in the first iteration only
            result = sess.run(operation, feed_dict=feed_dict)

Thanks for your answer. Sorry, I left out a very important piece of information in my original question: the variable computed in 'pseudo_bp' is used in 'build_net', and 'build_net' is also wired into the 'train' op. So I'm wondering whether 'pseudo_bp' will still run in the following iterations? – southdoor


I followed your suggestion and got an error; I've updated the question with more of the log. Could you take a look? Thanks! – southdoor


What are you assigning to 'pseudo_bp'? You probably need to assign it to 'rn.bp' – martianwars
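
To make the situation in these comments concrete: if the tensor produced in pseudo_bp is wired directly into build_net, then every sess.run(train) re-executes that subgraph, because TensorFlow evaluates everything the fetched op depends on. One common workaround (a sketch only, not code from this thread; the shapes and ops below are invented for illustration) is to cache the one-time result in a non-trainable tf.Variable via an assign op that is run once before the training loop:

import tensorflow as tf

class residual_net():
    def pseudo_bp(self):
        # One-time computation (a stand-in for the real back-prop-like step).
        computed = tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))
        # Cache the result in a non-trainable variable so the training graph
        # depends on the variable, not on the expensive subgraph.
        self.bp = tf.Variable(tf.zeros([2, 2]), trainable=False)
        self.assign_bp = tf.assign(self.bp, computed)

    def build_net(self):
        # The network reads self.bp (a variable), so running `train` below
        # does not re-execute the subgraph built in pseudo_bp.
        self.w = tf.Variable(tf.ones([2, 2]))
        self.output = tf.matmul(self.bp, self.w)

def run():
    rn = residual_net()
    rn.pseudo_bp()
    rn.build_net()
    err = tf.reduce_mean(rn.output)
    train = tf.train.GradientDescentOptimizer(0.01).minimize(err)
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    sess.run(rn.assign_bp)      # execute the one-time computation exactly once
    for i in range(1000):
        sess.run(train)         # only the training subgraph runs here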
