2016-11-28 73 views
1

我已經意識到Tensorflow似乎在管理圖形的方式上有一些時髦的東西。由於構建(和重建)模型非常繁瑣,我決定將自定義模型包裝在一個類中,以便我可以在其他地方輕鬆地重新實例化它。Tensorflow如何管理圖形?

當我在訓練和測試代碼時(在原來的地方),它會工作的很好,但是在我加載圖形變量的代碼中,我會得到各種奇怪的錯誤 - 變量重定義和其他一切。這個(從我最後一個類似的問題)提示,一切都被稱爲兩次。

做了跟蹤TON後,它回到了我使用加載的代碼的方式。它正在從一個類,有一個結構,像這樣

class MyModelUser(object): 
    def forecast(self): 
     # .. build the model in the same way as in the training code 
     # load the model checkpoint 
     # call the "predict" function on the model 
     # manipulate the prediction and return it 

然後在一些代碼,使用MyModelUser我有

def test_the_model(self): 
    model_user = MyModelUser() 
    print(model_user.forecast()) # 1 
    print(model_user.forecast()) # 2 

和我(顯然)有望看到兩個預測這時候內使用被稱爲。相反,第一個預測被稱爲按預期工作,但第二個電話扔變量重用的TON ValueError異常的這些中的一個例子是:

ValueError: Variable weight_def/weights already exists, disallowed. Did you mean to set reuse=True in VarScope? 

我設法通過增加一系列平息錯誤試圖/使用get_variable創建變量的塊除外,然後在例外情況下,在範圍上調用reuse_variables,然後在名稱上調用get_variable。這帶來了一套新的嚴重的錯誤,其中之一就是:

tensorflow.python.framework.errors.NotFoundError: Tensor name "weight_def/weights/Adam_1" not found in checkpoint files 

一時心血來潮我說:「如果我的造型建築物代碼移到__init__所以其只內置了一次?」

我的新機型的用戶:

class MyModelUser(object): 
    def __init__(self): 
     # ... build the model in the same way as in the training code 
     # load the model checkpoint 


    def forecast(self): 
     # call the "predict" function on the model 
     # manipulate the prediction and return it 

現在:

def test_the_model(self): 
    model_user = MyModelUser() 
    print(model_user.forecast()) # 1 
    print(model_user.forecast()) # 2 

按預期工作,印花兩大預測沒有錯誤。這使我相信我也可以擺脫可變重用的東西。

我的問題是這樣的:

這是爲什麼解決它?從理論上講,應該在原始預測方法中每次都重新調整圖形,因此它不應該創建多個圖形。即使函數完成後,Tensorflow是否仍然保持圖形?這就是爲什麼將創建代碼移動到__init__工作?這讓我無望地感到困惑。

回答

2

默認情況下,TensorFlow使用首次調用TensorFlow API時創建的單個全局tf.Graph實例。如果您不明確創建tf.Graph,則將在該默認實例中創建所有操作,張量和變量。這意味着您在model_user.forecast()的代碼中的每個調用都會將操作添加到同一個全局圖中,這有點浪費。

有(至少)動作的兩種可能的課程在這裏:

  • 理想的行動是調整你的代碼,以便MyModelUser.__init__()構建整個tf.Graph所有進行預測所需要的操作,而MyModelUser.forecast()只需在現有圖上執行sess.run()調用。理想情況下,您也只能創建一個tf.Session,因爲TensorFlow會在會話中緩存關於圖形的信息,並且執行效率會更高。

  • 的創傷更小—但可能不太有效—變化將是創建一個新的tf.Graph每次調用MyModelUser.forecast()。這是由很多國家是如何在MyModelUser.__init__()方法創建的問題尚不清楚,但你可以不喜歡下面把兩個調用不同的圖表:

    def test_the_model(self): 
        with tf.Graph(): # Create a local graph 
        model_user_1 = MyModelUser() 
        print(model_user_1.forecast()) 
        with tf.Graph(): # Create another local graph 
        model_user_2 = MyModelUser() 
        print(model_user_2.forecast()) 
    
0

TF有一個默認圖表,新的操作等被添加到。當你調用你的函數兩次時,你會將同樣的東西兩次添加到同一個圖中。因此,無論是構建一次圖並多次評估(就像你已經完成的那樣,這也是「正常」方法),或者,如果你想改變一些東西,你可以使用reset_default_graph https://www.tensorflow.org/versions/r0.11/api_docs/python/framework.html#reset_default_graph來重置圖,以便擁有一個新鮮的狀態。