2016-09-18 153 views
11

我非常喜歡使用Keras API的簡單構建強化學習模型。不幸的是,我無法提取關於權重的輸出的梯度(不是錯誤)。我發現下面的代碼,執行類似功能(Saliency maps of neural networks (using Keras)使用Keras獲得模型輸出w.r.t權重的梯度

get_output = theano.function([model.layers[0].input],model.layers[-1].output,allow_input_downcast=True) 
fx = theano.function([model.layers[0].input] ,T.jacobian(model.layers[-1].output.flatten(),model.layers[0].input), allow_input_downcast=True) 
grad = fx([trainingData]) 

關於如何相對於權重爲每個層計算模型輸出的梯度任何想法,將不勝感激。

+0

你有任何進展嗎?我收到以下錯誤使用類似的顯着功能:https://github.com/fchollet/keras/issues/1777#issuecomment-250040309 – ssierral

+0

我還沒有取得任何Keras成功。但是,我已經能夠使用tensorflow來做到這一點。 –

+0

https://github.com/yanpanlau/DDPG-Keras-Torcs CriticNetwork.py使用tensorflow後端來計算梯度,同時使用Keras來實際構建網絡體系結構 –

回答

14

要使用Keras獲得模型輸出相對於權重的梯度,您必須使用Keras後端模塊。我創建了這個簡單的例子來說明如何操作:

from keras.models import Sequential 
from keras.layers import Dense, Activation 
from keras import backend as k 


model = Sequential() 
model.add(Dense(12, input_dim=8, init='uniform', activation='relu')) 
model.add(Dense(8, init='uniform', activation='relu')) 
model.add(Dense(1, init='uniform', activation='sigmoid')) 
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) 

要計算梯度,我們首先需要找到輸出張量。對於模型的輸出(我最初提出的問題),我們只需調用model.output。我們也可以通過調用model.layers [指數] .OUTPUT

outputTensor = model.output #Or model.layers[index].output 

然後,我們需要選擇在對於梯度變量找到的其他層輸出的梯度。

listOfVariableTensors = model.trainable_weights 
    #or variableTensors = model.trainable_weights[0] 

我們現在可以計算出梯度。它很容易如下:

gradients = k.gradients(outputTensor, listOfVariableTensors) 

要真正運行給定輸入的梯度,我們需要使用一點Tensorflow。

trainingExample = np.random.random((1,8)) 
sess = tf.InteractiveSession() 
sess.run(tf.initialize_all_variables()) 
evaluated_gradients = sess.run(gradients,feed_dict={model.input:trainingExample}) 

而那就是它!

+2

我運行了這段代碼(以theano作爲後端)並出現以下錯誤:「TypeError:成本必須是標量。」我想知道,這可以通過後端不可知的方法來實現嗎? –

+0

Matt S,如何計算漸變而不指定sess.run中的標籤? –

+0

我正在使用漸變w.r.t輸入。如果你想漸變w.r.t損失,那麼你需要定義損失函數,用loss_fn替換k.gradients中的outputTensor,然後將標籤傳遞給feed字典。 –