您可以使用符號條件操作。 Theano有兩個:switch
和ifelse
。 switch
以元素方式執行,而ifelse
更像傳統條件。有關更多信息,請參閱documentation。
下面是一個僅在成本爲正值時更新參數的示例。
import numpy
import theano
import theano.tensor as tt
def compile(input_size, hidden_size, output_size, learning_rate):
w_h = theano.shared(numpy.random.standard_normal((input_size, hidden_size))
.astype(theano.config.floatX), name='w_h')
b_h = theano.shared(numpy.random.standard_normal((hidden_size,))
.astype(theano.config.floatX), name='b_h')
w_y = theano.shared(numpy.random.standard_normal((hidden_size, output_size))
.astype(theano.config.floatX), name='w_y')
b_y = theano.shared(numpy.random.standard_normal((output_size,))
.astype(theano.config.floatX), name='b_y')
x = tt.matrix()
z = tt.vector()
h = tt.tanh(theano.dot(x, w_h) + b_h)
y = theano.dot(h, w_y) + b_y
c = tt.sum(y - z)
updates = [(p, p - tt.switch(tt.gt(c, 0), learning_rate * tt.grad(cost=c, wrt=p), 0))
for p in (w_h, b_h, w_y, b_y)]
return theano.function([x, z], outputs=c, updates=updates)
def main():
f = compile(input_size=3, hidden_size=2, output_size=4, learning_rate=0.01)
main()
在這種情況下,無論是switch
或ifelse
可以使用,但switch
就是在這樣的情況下,一般可取的,因爲ifelse
沒有出現在整個Theano框架內,以及支持和需要特殊的進口。
感謝您的答案:)。真的很有幫助 –