2016-03-24 148 views
3

只有一個共享變量陣列的一部分欲執行以下操作:計算梯度爲在Theano

import theano, numpy, theano.tensor as T 

a = T.fvector('a') 

w = theano.shared(numpy.array([1, 2, 3, 4], dtype=theano.config.floatX)) 
w_sub = w[1] 

b = T.sum(a * w) 

grad = T.grad(b, w_sub) 

這裏,w_sub是例如W [1],但我不想要顯式寫出來b在函數w_sub。儘管經歷了http://deeplearning.net/software/theano/tutorial/faq_tutorial.html和其他相關問題,我無法解決它。

這只是爲了向你展示我的問題。其實,我真正想做的是與千層麪稀疏卷積。權重矩陣中的零條目不需要更新,因此不需要計算w的這些條目的梯度。

親切問候,並提前謝謝!

的Jeroen

PS:現在這是完整的錯誤消息:

Traceback (most recent call last): 
    File "D:/Jeroen/Project_Lasagne_General/test_script.py", line 9, in <module> 
    grad = T.grad(b, w_sub) 
    File "C:\Anaconda2\lib\site-packages\theano\gradient.py", line 545, in grad 
    handle_disconnected(elem) 
    File "C:\Anaconda2\lib\site-packages\theano\gradient.py", line 532, in handle_disconnected 
    raise DisconnectedInputError(message) 
theano.gradient.DisconnectedInputError: grad method was asked to compute the gradient with respect to a variable that is not part of the computational graph of the cost, or is used only by a non-differentiable operator: Subtensor{int64}.0 
Backtrace when the node is created: 
    File "D:/Jeroen/Project_Lasagne_General/test_script.py", line 6, in <module> 
    w_sub = w[1] 

回答

2

當theano編譯圖表,只看到變量在圖形如所明確定義。在您的示例中,w_sub未明確用於計算b,因此不是計算圖的一部分。

使用帶以下代碼的theano打印庫,您可以在此 graph vizualization上看到確實w_sub不是b圖的一部分。

import theano 
import theano.tensor as T 
import numpy 
import theano.d3viz as d3v 

a = T.fvector('a') 
w = theano.shared(numpy.array([1, 2, 3, 4], dtype=theano.config.floatX)) 
w_sub = w[1] 
b = T.sum(a * w) 

o = b, w_sub 

d3v.d3viz(o, 'b.html') 

爲了解決這個問題,就需要在b計算明確使用w_sub

然後你就能夠計算b WRT w_sub的梯度和更新共享變量的值,如下面的例子:

import theano 
import theano.tensor as T 
import numpy 


a = T.fvector('a') 
w = theano.shared(numpy.array([1, 2, 3, 4], dtype=theano.config.floatX)) 
w_sub = w[1] 
b = T.sum(a * w_sub) 
grad = T.grad(b, w_sub) 
updates = [(w, T.inc_subtensor(w_sub, -0.1*grad))] 

f = theano.function([a], b, updates=updates, allow_input_downcast=True) 

f(numpy.arange(10)) 
+0

上正在發生的事情一個很好的解釋。謝謝。 –