2016-12-01 74 views
0

我定義了以下類:如何更改Theano中共享變量的值?

class test: 

    def __init__(self): 
     self.X = theano.tensor.dmatrix('x') 
     self.W = theano.shared(value=numpy.zeros((5, 2), dtype=theano.config.floatX), name='W', borrow=True) 
     self.out = theano.dot(self.X, self.W) 

    def eval(self, X): 
     _eval = theano.function([self.X], self.out) 
     return _eval(X) 

之後,我試圖改變W矩陣的值,並用新的值來計算。我這樣做以下列方式:

m = test() 
W = np.transpose(np.array([[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 2.0, 3.0, 3.0, 3.0]])) 
dn.W = theano.shared(value=W, name='W', borrow=True) 
dn.eval(X) 

,我得到對應於已經在__init__設置(所有元素都爲零)的的W價值的結果。

爲什麼類沒有看到我在初始化後明確設置的新值W

回答

1

您剛剛爲python變量dn.W創建了一個新的共享變量,但theano的內部計算圖仍然鏈接到舊的共享變量。

改變存儲在現有的共享變量值:

W = np.transpose(np.array([[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 2.0, 3.0, 3.0, 3.0]])) 
dn.W.set_value(W)) 

注意如果你想使用結果從一個函數調用來更新共享變量,更好的辦法是使用theano.functionupdates說法。如果共享變量存儲在GPU中,這將消除不必要的內存傳輸。