2016-06-12 72 views
5

我相當難以理解圖表如何在張量流中工作以及如何訪問它們。我的直覺是,'with graph:'下面的行會將圖形形成一個單獨的實體。因此,我決定創建一個類,該類將在實例化時構建一個圖並具有可以運行該圖的函數,如下所示;Tensorflow:在課堂上創建圖表並運行它外側

class Graph(object): 

    #To build the graph when instantiated 
    def __init__(self, parameters): 
     self.graph = tf.Graph() 
     with self.graph.as_default(): 
      ... 
      prediction = ... 
      cost  = ... 
      optimizer = ... 
      ... 
    # To launch the graph 
    def launchG(self, inputs): 
     with tf.Session(graph=self.graph) as sess: 
      ... 
      sess.run(optimizer, feed_dict) 
      loss = sess.run(cost, feed_dict) 
      ... 
     return variables 

接下來的步驟是創建將組裝參數傳遞給類,構建圖,然後運行它的主文件;

#Main file 
... 
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... } 

#Building graph 
G = Graph(parameters_dict) 
P = G.launchG(Input) 
... 

這對我來說是非常優雅的,但它並不完全工作(顯然)。事實上,似乎launchG函數無法訪問圖中定義的節點,這給了我錯誤,例如;

---> 26 sess.run(optimizer, feed_dict) 

NameError: name 'optimizer' is not defined 

也許是我的Python(和tensorflow)的理解是太有限了,但我奇怪的印象是,與創建的圖形(G),與該圖運行會話作爲參數應該給訪問到其中的節點,而不需要我給出明確的訪問權限。

任何啓示?

回答

7

節點predictioncost,並optimizer是在方法__init__創建局部變量,它們不能在方法launchG訪問。

最簡單的解決將是它們聲明爲您Graph類的屬性:

class Graph(object): 

    #To build the graph when instantiated 
    def __init__(self, parameters): 
     self.graph = tf.Graph() 
     with self.graph.as_default(): 
      ... 
      self.prediction = ... 
      self.cost  = ... 
      self.optimizer = ... 
      ... 
    # To launch the graph 
    def launchG(self, inputs): 
     with tf.Session(graph=self.graph) as sess: 
      ... 
      sess.run(self.optimizer, feed_dict) 
      loss = sess.run(self.cost, feed_dict) 
      ... 
     return variables 

您也可以使用他們的確切名稱與graph.get_tensor_by_namegraph.get_operation_by_name檢索圖的節點。