2016-04-23 96 views
5

我正在使用keras 1.0.1我試圖在LSTM上添加關注層。這是我迄今爲止的,但它不起作用。Keras關注層超過LSTM

input_ = Input(shape=(input_length, input_dim)) 
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_) 
att = TimeDistributed(Dense(1)(lstm)) 
att = Reshape((-1, input_length))(att) 
att = Activation(activation="softmax")(att) 
att = RepeatVector(self.HID_DIM)(att) 
merge = Merge([att, lstm], "mul") 
hid = Merge("sum")(merge) 

last = Dense(self.HID_DIM, activation="relu")(hid) 

網絡應該在輸入序列上應用LSTM。然後,LSTM的每個隱藏狀態都應輸入一個完全連接的層,通過該層完成Softmax的應用。 softmax會針對每個隱藏維度進行復制,然後乘以LSTM隱藏狀態。然後得到的向量應該平均。

編輯:這個編譯,但我不知道它是否做我認爲應該做的。

input_ = Input(shape=(input_length, input_dim)) 
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_) 
att = TimeDistributed(Dense(1))(lstm) 
att = Flatten()(att) 
att = Activation(activation="softmax")(att) 
att = RepeatVector(self.HID_DIM)(att) 
att = Permute((2,1))(att) 
mer = merge([att, lstm], "mul") 
hid = AveragePooling1D(pool_length=input_length)(mer) 
hid = Flatten()(hid) 
+0

嗨@iamii與關注網絡有成功嗎?目前我正在嘗試同樣的事情。 – Nacho

+0

看看這個LSTM上的Attention實現:https://github.com/philipperemy/keras-attention-mechanism它應該適用於你的例子。 –

回答